CardDefectAggregator.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. import os
  2. import cv2
  3. import json
  4. import numpy as np
  5. from typing import Dict, List, Tuple, Any, Union
  6. from shapely.geometry import Polygon, MultiPolygon
  7. import shutil
  8. def log_print(level_str: str, info_str: str):
  9. print(f"[{level_str}] : {info_str}")
  10. def fry_cv2_imread(filename, flags=cv2.IMREAD_COLOR):
  11. try:
  12. with open(filename, 'rb') as f:
  13. chunk = f.read()
  14. chunk_arr = np.frombuffer(chunk, dtype=np.uint8)
  15. img = cv2.imdecode(chunk_arr, flags)
  16. return img
  17. except IOError:
  18. return None
  19. def fry_cv2_imwrite(filename, img, params=None):
  20. try:
  21. ext = os.path.splitext(filename)[1].lower()
  22. if not ext: ext = ".jpg" # 默认格式
  23. result, encoded_img = cv2.imencode(ext, img, params)
  24. if result:
  25. with open(filename, 'wb') as f:
  26. encoded_img.tofile(f)
  27. return True
  28. else:
  29. return False
  30. except Exception as e:
  31. log_print("错误", f"保存图片失败: {filename}, 错误: {e}")
  32. return False
  33. def to_json_serializable(obj):
  34. if isinstance(obj, (np.ndarray,)):
  35. return obj.tolist()
  36. if isinstance(obj, (np.integer,)):
  37. return int(obj)
  38. if isinstance(obj, (np.floating,)):
  39. return float(obj)
  40. if isinstance(obj, (Polygon, MultiPolygon)):
  41. return str(obj) # Or any other serializable representation
  42. return obj
  43. class CardDefectAggregator:
  44. """
  45. 负责对大尺寸卡片图像进行分块、预测、结果合并的处理器, 有仅分割并保存图块。
  46. """
  47. def __init__(self,
  48. predictor: Any, # 传入一个类似 FryBisenetV2Predictor 的实例
  49. tile_size: int = 512,
  50. overlap_ratio: float = 0.1):
  51. """
  52. 初始化聚合器。
  53. Args:
  54. predictor: 你的模型预测器实例,需要有一个 predict_from_image(img_bgr) 方法。
  55. tile_size: 模型输入的小图块尺寸。
  56. overlap_ratio: 滑窗切图的重叠率 (例如 0.1 表示 10%)。
  57. """
  58. self.predictor = predictor
  59. self.tile_size = tile_size
  60. self.overlap_ratio = overlap_ratio
  61. self.stride = int(self.tile_size * (1 - self.overlap_ratio))
  62. # _calculate_face_tiles 和 _calculate_edge_tiles 方法保持不变
  63. def _calculate_face_tiles(self, image_shape: Tuple[int, int]) -> List[Dict]:
  64. height, width = image_shape
  65. tiles = []
  66. y_steps = range(0, height, self.stride)
  67. x_steps = range(0, width, self.stride)
  68. for r, y in enumerate(y_steps):
  69. for c, x in enumerate(x_steps):
  70. y_start = min(y, height - self.tile_size) if y + self.tile_size > height else y
  71. x_start = min(x, width - self.tile_size) if x + self.tile_size > width else x
  72. y_start, x_start = max(0, y_start), max(0, x_start)
  73. tile_info = {'origin': (x_start, y_start), 'row': r, 'col': c}
  74. if tile_info not in tiles:
  75. tiles.append(tile_info)
  76. return tiles
  77. def _calculate_edge_tiles(self, image_shape: Tuple[int, int]) -> List[Dict]:
  78. height, width = image_shape
  79. tiles = []
  80. def add_tile(x, y, index):
  81. x_start = max(0, min(x, width - self.tile_size))
  82. y_start = max(0, min(y, height - self.tile_size))
  83. tile_info = {'origin': (x_start, y_start), 'index': index}
  84. if tile_info not in [t for t in tiles if t['origin'] == tile_info['origin']]:
  85. tiles.append(tile_info)
  86. idx = 0
  87. for x in range(0, width, self.stride): add_tile(x, 0, idx); idx += 1
  88. for x in range(0, width, self.stride): add_tile(x, height - self.tile_size, idx); idx += 1
  89. for y in range(self.stride, height - self.stride, self.stride): add_tile(0, y, idx); idx += 1
  90. for y in range(self.stride, height - self.stride, self.stride): add_tile(width - self.tile_size, y,
  91. idx); idx += 1
  92. return tiles
  93. # _run_prediction_on_tiles 方法保持不变
  94. def _run_prediction_on_tiles(self, image: np.ndarray, tiles: List[Dict]) -> List[Dict]:
  95. all_detections = []
  96. for i, tile_info in enumerate(tiles):
  97. x_origin, y_origin = tile_info['origin']
  98. tile_image = image[y_origin: y_origin + self.tile_size, x_origin: x_origin + self.tile_size]
  99. log_print("信息", f"正在预测图块 {i + 1}/{len(tiles)} at ({x_origin}, {y_origin})...")
  100. result_dict = self.predictor.predict_from_image(tile_image)
  101. if not result_dict or not result_dict.get('shapes'):
  102. continue
  103. for shape in result_dict['shapes']:
  104. global_points = [[p[0] + x_origin, p[1] + y_origin] for p in shape['points']]
  105. all_detections.append({
  106. "class_num": shape['class_num'],
  107. "label": shape['label'],
  108. "probability": shape['probability'],
  109. "points": global_points
  110. })
  111. log_print("成功", f"所有图块预测完成,共收集到 {len(all_detections)} 个初步检测结果。")
  112. return all_detections
  113. # split_and_save_tiles 方法保持不变
  114. def split_and_save_tiles(self, image_path: str, output_dir: str, mode: str = 'face'):
  115. log_print("组开始", f"开始分割图片: {image_path}")
  116. image = fry_cv2_imread(image_path)
  117. if image is None:
  118. log_print("错误", f"无法读取图片: {image_path}")
  119. return []
  120. if os.path.exists(output_dir):
  121. shutil.rmtree(output_dir)
  122. os.makedirs(output_dir, exist_ok=True)
  123. if mode == 'face':
  124. tiles = self._calculate_face_tiles(image.shape[:2])
  125. elif mode == 'edge':
  126. tiles = self._calculate_edge_tiles(image.shape[:2])
  127. else:
  128. raise ValueError(f"不支持的模式: {mode}。请选择 'face' 或 'edge'。")
  129. log_print("信息", f"计算得到 {len(tiles)} 个图块位置。")
  130. base_name = os.path.splitext(os.path.basename(image_path))[0]
  131. saved_files = []
  132. for tile_info in tiles:
  133. x_origin, y_origin = tile_info['origin']
  134. tile_image = image[y_origin: y_origin + self.tile_size, x_origin: x_origin + self.tile_size]
  135. if mode == 'face':
  136. filename = f"{base_name}_grid_r{tile_info['row']}_c{tile_info['col']}.jpg"
  137. else:
  138. filename = f"{base_name}_edge_{tile_info['index']}.jpg"
  139. output_path = os.path.join(output_dir, filename)
  140. if fry_cv2_imwrite(output_path, tile_image):
  141. saved_files.append(output_path)
  142. log_print("成功", f"成功保存 {len(saved_files)} 个图块到目录: {output_dir}")
  143. log_print("组结束", "图片分割完成。")
  144. return saved_files
  145. def _merge_detections_by_mask(self, image_shape: Tuple[int, int], detections: List[Dict]) -> List[Dict]:
  146. """
  147. 使用蒙版合并法来处理重叠检测,并返回合并后的不重叠多边形。
  148. Args:
  149. image_shape: 原始大图的尺寸 (height, width)。
  150. detections: 未经NMS处理的所有检测结果列表。
  151. Returns:
  152. 一个包含合并后、不重叠缺陷多边形的列表。
  153. """
  154. if not detections:
  155. return []
  156. log_print("信息", "开始使用蒙版法合并重叠缺陷...")
  157. height, width = image_shape
  158. # 1. 按类别对检测结果进行分组
  159. detections_by_class = {}
  160. for det in detections:
  161. class_label = det['label']
  162. if class_label not in detections_by_class:
  163. detections_by_class[class_label] = []
  164. detections_by_class[class_label].append(det)
  165. final_merged_defects = []
  166. # 2. 对每个类别独立进行蒙版合并
  167. for class_label, dets in detections_by_class.items():
  168. # 创建该类别的专属蒙版
  169. class_mask = np.zeros((height, width), dtype=np.uint8)
  170. # 将所有该类别的多边形画到蒙版上
  171. for det in dets:
  172. try:
  173. contour = np.array(det['points'], dtype=np.int32)
  174. cv2.fillPoly(class_mask, [contour], color=255)
  175. except Exception as e:
  176. log_print("警告", f"为类别 '{class_label}' 绘制多边形失败: {e}")
  177. continue
  178. # 3. 从合并后的蒙版中提取轮廓
  179. # cv2.RETR_EXTERNAL 只提取最外层的轮廓
  180. contours, _ = cv2.findContours(class_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  181. # 4. 将提取出的轮廓转换回JSON格式
  182. for contour in contours:
  183. # 过滤掉太小的噪点轮廓 (可选,但推荐)
  184. if cv2.contourArea(contour) < 5: # 面积小于5像素的轮廓可能是噪点
  185. continue
  186. # 找到原始检测中与当前合并轮廓重叠度最高的那个,以继承其属性
  187. # 这是一个可选的优化,也可以简单地取第一个或平均值
  188. max_prob = 0
  189. avg_prob = []
  190. class_num = -1
  191. for det in dets:
  192. avg_prob.append(det['probability'])
  193. if det['probability'] > max_prob:
  194. max_prob = det['probability']
  195. class_num = det['class_num']
  196. if not avg_prob: continue
  197. final_merged_defects.append({
  198. "class_num": class_num,
  199. "label": class_label,
  200. "probability": np.mean(avg_prob), # 使用平均置信度
  201. "points": contour.squeeze().tolist() # 将轮廓点转换为[[x1,y1], [x2,y2]...]格式
  202. })
  203. log_print("成功", f"蒙版合并完成,最终得到 {len(final_merged_defects)} 个独立缺陷。")
  204. return final_merged_defects
  205. def process_image(self, image: Union[str, np.ndarray], output_json_path: str = None, mode: str = 'face'):
  206. """
  207. 处理单张大图的完整流程:分块、预测、蒙版合并、保存结果。
  208. Args:
  209. image : 输入大图的路径或bgr图片
  210. output_json_path (str): 输出合并后JSON文件的路径。
  211. mode (str): 处理模式,'face' (全图) 或 'edge' (仅边缘)。
  212. """
  213. # log_print("组开始", f"开始处理图片: {image},模式: {mode}")
  214. if isinstance(image, str):
  215. image = fry_cv2_imread(image)
  216. if image is None:
  217. log_print("错误", f"无法读取图片: {image}")
  218. return
  219. # 1. 计算图块位置
  220. if mode == 'face':
  221. tiles = self._calculate_face_tiles(image.shape[:2])
  222. elif mode == 'edge':
  223. tiles = self._calculate_edge_tiles(image.shape[:2])
  224. else:
  225. raise ValueError(f"不支持的模式: {mode}。请选择 'face' 或 'edge'。")
  226. log_print("信息", f"步骤1: 计算得到 {len(tiles)} 个图块位置。")
  227. # 2. 在所有图块上运行预测,得到初步的、可能重叠的检测结果
  228. all_detections = self._run_prediction_on_tiles(image, tiles)
  229. # 3. 【核心修改】使用蒙版法合并重叠的检测
  230. final_defects = self._merge_detections_by_mask(image.shape[:2], all_detections)
  231. # 4. 格式化并保存最终JSON
  232. output_data = {
  233. "num": len(final_defects),
  234. "cls": [d['class_num'] for d in final_defects],
  235. "names": [d['label'] for d in final_defects],
  236. "conf": np.mean([d['probability'] for d in final_defects]) if final_defects else 0.0,
  237. "shapes": final_defects
  238. }
  239. if output_json_path is not None:
  240. os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
  241. with open(output_json_path, 'w', encoding='utf-8') as f:
  242. json.dump(output_data, f, ensure_ascii=False, indent=2, default=to_json_serializable)
  243. log_print("成功", f"最终结果已保存到: {output_json_path}")
  244. log_print("组结束", "处理完成。")
  245. return output_data