data_augmentation.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import numpy as np
  2. import math
  3. import cv2
  class LetterBox(object):
      """Aspect-preserving resize of images (and an optional label map) onto a padded canvas.

      Every input is scaled by the largest ratio that still fits inside
      ``size`` and is then centred on a canvas filled with a constant value.
      With ``auto=False`` the canvas is exactly ``size``; with ``auto=True``
      the canvas is the resized image with each side rounded up to the next
      multiple of ``stride``.
      """

      # Value written into the padding border of images.
      IMG_PAD_VALUE = 114
      # Value written into the padding border of label maps.
      # NOTE(review): the original comment said the label border should be 255
      # (ignored in the loss), but the code has always filled with 0. The value
      # is kept at 0 here -- confirm which one is intended.
      LABEL_PAD_VALUE = 0

      def __init__(self, size={'width': 640, 'height': 640}, auto=False, stride=32, *args, **kwargs):
          """Store the target size and the padding mode.

          Args:
              size: Mapping with ``"width"`` and ``"height"`` entries giving the
                  target (maximum) canvas size. The default dict is shared
                  between instances; it is only read here, never modified.
              auto: When true, pad only up to the next multiple of ``stride``
                  instead of padding to the full ``size``.
              stride: Stride used for rounding when ``auto`` is true.
              *args: Accepted and ignored.
              **kwargs: Accepted and ignored.
          """
          self.size = size
          self.h = size["height"]
          self.w = size["width"]
          self.auto = auto
          self.stride = stride

      def __call__(self, im_lb):
          """Letterbox all images in ``im_lb['imgList']`` and the label ``im_lb['lb']``.

          Args:
              im_lb: Dict with the keys ``"imgList"`` (sequence of 3-D image
                  arrays) and ``"lb"`` (2-D label array, or ``None``).

          Returns:
              ``dict(imgList=<list of letterboxed images>, lb=<letterboxed label or None>)``.

          Raises:
              AssertionError: If a label is given and its height/width differ
                  from those of the first image.
          """
          imgList = im_lb['imgList']
          lb = im_lb['lb']
          if lb is not None:
              # Only the first image is compared against the label.
              assert imgList[0].shape[:2] == lb.shape[:2]
          ans_imgList = self.handle_imgList(imgList)
          ans_lb = self.handle_single_label(lb) if lb is not None else None
          return dict(imgList=ans_imgList, lb=ans_lb)

      def handle_imgList(self, imgList):
          """Return a new list holding the letterboxed version of every image in ``imgList``."""
          return [self.handle_single_img(per_img) for per_img in imgList]

      @staticmethod
      def _fit_size(src_h, src_w, dst_h, dst_w):
          """Return ``(ratio, new_h, new_w)`` for the largest aspect-preserving fit of src inside dst."""
          r = min(dst_h / src_h, dst_w / src_w)  # ratio of new/old
          return r, int(round(src_h * r)), int(round(src_w * r))

      def _letterbox_layout(self, imh, imw):
          """Compute the resize and placement geometry for an ``imh`` x ``imw`` input.

          Returns:
              Tuple ``(h, w, hs, ws, top, left)``: resized height/width, canvas
              height/width, and the row/column at which the resized input is
              pasted into the canvas.
          """
          _, h, w = self._fit_size(imh, imw, self.h, self.w)
          if self.auto:
              # Smallest canvas that is a multiple of the stride on both sides.
              hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w))
          else:
              hs, ws = self.h, self.w
          # The "- 0.1" makes an odd amount of padding put the smaller half
          # first (top/left); even padding is split exactly in two.
          top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
          return h, w, hs, ws, top, left

      def get_offset(self, originImgSize={'width': 4096, 'height': 7000}):
          """Return where an image of ``originImgSize`` lands on the fixed ``size`` canvas.

          The result always describes the fixed-canvas layout; ``self.auto`` is
          not taken into account here.

          Args:
              originImgSize: Mapping with ``"width"`` and ``"height"`` of the
                  source image.

          Returns:
              Dict with ``x``/``y`` (left/top padding), ``width``/``height``
              (resized image size) and ``ratio`` (scale factor new/old).

          Raises:
              AssertionError: If the computed padding is negative or is
                  non-zero on both axes.
          """
          # _240429_1543_
          # NOTE (from the original author): ResizeBeforeLetterbox overrides this logic.
          originH = originImgSize['height']
          originW = originImgSize['width']
          r, resize_h, resize_w = self._fit_size(originH, originW, self.h, self.w)
          total_pad_h = int(self.h - resize_h)
          total_pad_w = int(self.w - resize_w)
          assert total_pad_h >= 0, "total_pad_h 必须大于等于0"
          assert total_pad_w >= 0, "total_pad_w 必须大于等于0"
          assert total_pad_h == 0 or total_pad_w == 0, "total_pad_h 和 total_pad_w中必须有一个为0"
          return {
              'x': int(total_pad_w // 2),
              'y': int(total_pad_h // 2),
              'width': resize_w,
              'height': resize_h,
              'ratio': r,
          }

      def handle_single_img(self, im):
          """Letterbox one image.

          Args:
              im: 3-D array whose last axis has length 1 or 3.

          Returns:
              New array with the same dtype and channel count as ``im``: the
              input resized with bilinear interpolation and pasted onto a
              canvas filled with ``IMG_PAD_VALUE``.

          Raises:
              AssertionError: If ``im`` is not 3-D or has a channel count other
                  than 1 or 3.
              ValueError: Same channel check, for runs with assertions disabled.
          """
          assert len(im.shape) == 3, "im 必须是3维的"
          assert (im.shape[2] == 1) or (im.shape[2] == 3), "im 的通道数必须是一个通道或者三个通道"
          channels = im.shape[2]
          if channels not in (1, 3):
              # Only reachable when assertions are stripped (python -O).
              raise ValueError("图片的通道数异常")
          imh, imw = im.shape[:2]
          h, w, hs, ws, top, left = self._letterbox_layout(imh, imw)
          im_out = np.full((hs, ws, channels), self.IMG_PAD_VALUE, dtype=im.dtype)
          # Single-channel input is handed to cv2.resize as a 2-D array; the
          # channel axis is put back afterwards if the result came out 2-D.
          src = np.squeeze(im, axis=-1) if channels == 1 else im
          resized = cv2.resize(src, (w, h), interpolation=cv2.INTER_LINEAR)
          if len(resized.shape) == 2:
              resized = np.expand_dims(resized, axis=-1)
          im_out[top:top + h, left:left + w] = resized
          return im_out

      def handle_single_label(self, im):
          """Letterbox one label map.

          Args:
              im: 2-D label array.

          Returns:
              New 2-D array with the dtype of ``im``: the input resized with
              nearest-neighbour interpolation and pasted onto a canvas filled
              with ``LABEL_PAD_VALUE``.

          Raises:
              AssertionError: If ``im`` is not 2-D.
          """
          assert len(im.shape) == 2, "label 必须是2维的"
          imh, imw = im.shape[:2]
          h, w, hs, ws, top, left = self._letterbox_layout(imh, imw)
          im_out = np.full((hs, ws), self.LABEL_PAD_VALUE, dtype=im.dtype)
          im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_NEAREST)
          return im_out