data_augmentation.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import copy
  2. import random
  3. import numpy as np
  4. import math
  5. import cv2
  6. import numpy as np
  7. import math
  8. import cv2
  9. class LetterBox(object):
  10. def __init__(self, size={'width':640,'height':640}, auto=False, stride=32,*args, **kwargs):
  11. # 需要调整的额size
  12. self.size = size
  13. self.h = size["height"]
  14. self.w = size["width"]
  15. self.auto = auto # pass max size integer, automatically solve for short side using stride
  16. self.stride = stride # used with auto
  17. def __call__(self, im_lb):
  18. imgList = im_lb['imgList']
  19. lb = im_lb['lb']
  20. if lb is not None:
  21. assert imgList[0].shape[:2] == lb.shape[:2]
  22. ans_imgList = self.handle_imgList(imgList)
  23. # 处理label
  24. # 处理label
  25. if lb is not None:
  26. ans_lb = self.handle_single_label(lb)
  27. else:
  28. ans_lb = None
  29. returnObj = dict(imgList=ans_imgList, lb=ans_lb)
  30. return returnObj
  31. def handle_imgList(self,imgList):
  32. # 处理图片
  33. ans_imgList = []
  34. for per_img in imgList:
  35. ans_img = self.handle_single_img(per_img)
  36. ans_imgList.append(ans_img)
  37. return ans_imgList
  38. def get_offset(self,originImgSize={'width':4096,'height':7000}):
  39. # _240429_1543_
  40. # [特别注意]:ResizeBeforeLetterbox中重写了这个逻辑
  41. originH = originImgSize['height']
  42. originW = originImgSize['width']
  43. dstH = self.h
  44. dstW = self.w
  45. def fry_resize_realParams(originH, originW, dstH, dstW):
  46. r = min(dstH / originH, dstW / originW) # ratio of new/old
  47. resize_h, resize_w = int(round(originH * r)), int(round(originW * r)) # resized image
  48. total_pad_h = int(dstH - resize_h)
  49. total_pad_w = int(dstW - resize_w)
  50. assert total_pad_h >= 0, "total_pad_h 必须大于等于0"
  51. assert total_pad_w >= 0, "total_pad_w 必须大于等于0"
  52. assert total_pad_h == 0 or total_pad_w == 0, "total_pad_h 和 total_pad_w中必须有一个为0"
  53. pad_left = int(total_pad_w // 2)
  54. pad_right = total_pad_w - pad_left
  55. pad_top = int(total_pad_h // 2)
  56. pad_bottom = total_pad_h - pad_top
  57. before_letterbox_dict = {}
  58. before_letterbox_dict['ratio'] = r
  59. before_letterbox_dict['resize_h'] = resize_h
  60. before_letterbox_dict['resize_w'] = resize_w
  61. before_letterbox_dict['total_pad_h'] = total_pad_h
  62. before_letterbox_dict['total_pad_w'] = total_pad_w
  63. before_letterbox_dict['pad_left'] = pad_left
  64. before_letterbox_dict['pad_right'] = pad_right
  65. before_letterbox_dict['pad_top'] = pad_top
  66. before_letterbox_dict['pad_bottom'] = pad_bottom
  67. return before_letterbox_dict
  68. before_letterbox_dict = fry_resize_realParams(originH,originW,dstH,dstW)
  69. rect_dict = {}
  70. rect_dict['x'] = before_letterbox_dict['pad_left']
  71. rect_dict['y'] = before_letterbox_dict['pad_top']
  72. rect_dict['width'] = before_letterbox_dict['resize_w']
  73. rect_dict['height'] = before_letterbox_dict['resize_h']
  74. rect_dict['ratio'] = before_letterbox_dict['ratio']
  75. return rect_dict
  76. def handle_single_img(self, im):
  77. assert len(im.shape) == 3, "im 必须是3维的"
  78. assert (im.shape[2] == 1) or (im.shape[2] == 3), "im 的通道数必须是一个通道或者三个通道"
  79. imh, imw = im.shape[:2]
  80. r = min(self.h / imh, self.w / imw) # ratio of new/old
  81. h, w = round(imh * r), round(imw * r) # resized image
  82. hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
  83. top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
  84. # 这里弄成0没有关系,因为均值是0方差是1
  85. # 还是都弄成114吧
  86. if im.shape[2]==3:
  87. im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
  88. elif im.shape[2]==1:
  89. im_out = np.full((self.h, self.w, 1), 114, dtype=im.dtype)
  90. else:
  91. raise ValueError("图片的通道数异常")
  92. if im.shape[2]==1:
  93. gray_image_hw1 = im
  94. gray_image_hw = np.squeeze(gray_image_hw1, axis=-1)
  95. singleImg = gray_image_hw
  96. else:
  97. singleImg = im
  98. originImg_resized = cv2.resize(singleImg, (w, h), interpolation=cv2.INTER_LINEAR)
  99. if len(originImg_resized.shape)==2:
  100. newSingleImg2D = originImg_resized
  101. newSingleImg3D = np.expand_dims(newSingleImg2D, axis=-1)
  102. newSingleImg = newSingleImg3D
  103. else:
  104. newSingleImg = originImg_resized
  105. im_out[top:top + h, left:left + w] = newSingleImg
  106. return im_out
  107. def handle_single_label(self, im):
  108. assert len(im.shape) == 2, "label 必须是2维的"
  109. imh, imw = im.shape[:2]
  110. r = min(self.h / imh, self.w / imw) # ratio of new/old
  111. h, w = round(imh * r), round(imw * r) # resized image
  112. hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
  113. top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
  114. # label 直接弄成255不参与计算
  115. im_out = np.full((self.h, self.w), 0, dtype=im.dtype)
  116. im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_NEAREST)
  117. return im_out