import os
import cv2
import json
import numpy as np
from typing import Dict, List, Tuple, Any
from shapely.geometry import Polygon, MultiPolygon
import shutil


def log_print(level_str: str, info_str: str):
    """Tiny logging helper: prints '[LEVEL] : message'."""
    print(f"[{level_str}] : {info_str}")


def fry_cv2_imread(filename, flags=cv2.IMREAD_COLOR):
    """Read an image via np.frombuffer + cv2.imdecode so that paths with
    non-ASCII characters also work; returns None on failure."""
    try:
        with open(filename, 'rb') as f:
            chunk = f.read()
        chunk_arr = np.frombuffer(chunk, dtype=np.uint8)
        img = cv2.imdecode(chunk_arr, flags)
        return img
    except IOError:
        return None


def fry_cv2_imwrite(filename, img, params=None):
    """Write an image via cv2.imencode + a plain file handle (safe for non-ASCII paths)."""
    try:
        ext = os.path.splitext(filename)[1].lower()
        if not ext:
            ext = ".jpg"  # default format
        if params is None:
            params = []  # cv2.imencode expects a (possibly empty) parameter list
        result, encoded_img = cv2.imencode(ext, img, params)
        if result:
            with open(filename, 'wb') as f:
                encoded_img.tofile(f)
            return True
        return False
    except Exception as e:
        log_print("ERROR", f"Failed to save image: {filename}, error: {e}")
        return False


def to_json_serializable(obj):
    """json.dump `default` hook: convert numpy and shapely objects to plain Python types."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, (Polygon, MultiPolygon)):
        return str(obj)  # or any other serializable representation
    # json expects the `default` hook to raise for unsupported types;
    # returning the object unchanged would make the encoder recurse forever.
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
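

# --- Illustrative only: a minimal predictor stub (not part of the original module). ---
# The aggregator below only assumes that the predictor exposes
# predict_from_image(img_bgr) and returns a dict shaped like
# {"shapes": [{"class_num": int, "label": str, "probability": float,
#              "points": [[x, y], ...]}, ...]}.
# This hypothetical stub sketches that contract so the class can be exercised
# without a real model; a real predictor (e.g. a FryBisenetV2Predictor instance)
# is expected to follow the same output format.
class _DummyPredictor:
    """Hypothetical stand-in that returns no detections for every tile."""

    def predict_from_image(self, img_bgr: np.ndarray) -> Dict:
        # A real model would run inference here and return its shapes.
        return {"shapes": []}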


class CardDefectAggregator:
    """
    Processor that tiles a large card image, runs prediction on the tiles, and
    merges the results; it can also just split the image into tiles and save them.
    """

    def __init__(self,
                 predictor: Any,  # an instance similar to FryBisenetV2Predictor
                 tile_size: int = 512,
                 overlap_ratio: float = 0.1):
        """
        Initialise the aggregator.

        Args:
            predictor: Your model predictor instance; it must provide a
                predict_from_image(img_bgr) method.
            tile_size: Size of the small tiles fed to the model.
            overlap_ratio: Overlap ratio of the sliding window (e.g. 0.1 means 10%).
        """
        self.predictor = predictor
        self.tile_size = tile_size
        self.overlap_ratio = overlap_ratio
        self.stride = int(self.tile_size * (1 - self.overlap_ratio))

    def _calculate_face_tiles(self, image_shape: Tuple[int, int]) -> List[Dict]:
        """Compute tile origins for a regular grid covering the whole image."""
        height, width = image_shape
        tiles = []
        seen_origins = set()
        y_steps = range(0, height, self.stride)
        x_steps = range(0, width, self.stride)
        for r, y in enumerate(y_steps):
            for c, x in enumerate(x_steps):
                # Clamp tiles that would run past the right/bottom border.
                y_start = min(y, height - self.tile_size) if y + self.tile_size > height else y
                x_start = min(x, width - self.tile_size) if x + self.tile_size > width else x
                y_start, x_start = max(0, y_start), max(0, x_start)
                origin = (x_start, y_start)
                # Clamping can produce the same origin twice; keep only the first occurrence.
                if origin in seen_origins:
                    continue
                seen_origins.add(origin)
                tiles.append({'origin': origin, 'row': r, 'col': c})
        return tiles

    def _calculate_edge_tiles(self, image_shape: Tuple[int, int]) -> List[Dict]:
        """Compute tile origins along the four borders of the image only."""
        height, width = image_shape
        tiles = []

        def add_tile(x, y, index):
            x_start = max(0, min(x, width - self.tile_size))
            y_start = max(0, min(y, height - self.tile_size))
            # Skip positions that collapse onto an origin we already have.
            if any(t['origin'] == (x_start, y_start) for t in tiles):
                return
            tiles.append({'origin': (x_start, y_start), 'index': index})

        idx = 0
        for x in range(0, width, self.stride):  # top edge
            add_tile(x, 0, idx)
            idx += 1
        for x in range(0, width, self.stride):  # bottom edge
            add_tile(x, height - self.tile_size, idx)
            idx += 1
        for y in range(self.stride, height - self.stride, self.stride):  # left edge (corners skipped)
            add_tile(0, y, idx)
            idx += 1
        for y in range(self.stride, height - self.stride, self.stride):  # right edge (corners skipped)
            add_tile(width - self.tile_size, y, idx)
            idx += 1
        return tiles

    def _run_prediction_on_tiles(self, image: np.ndarray, tiles: List[Dict]) -> List[Dict]:
        """Run the predictor on every tile and map its detections back to global coordinates."""
        all_detections = []
        for i, tile_info in enumerate(tiles):
            x_origin, y_origin = tile_info['origin']
            tile_image = image[y_origin: y_origin + self.tile_size,
                               x_origin: x_origin + self.tile_size]
            log_print("INFO", f"Predicting tile {i + 1}/{len(tiles)} at ({x_origin}, {y_origin})...")
            result_dict = self.predictor.predict_from_image(tile_image)
            if not result_dict or not result_dict.get('shapes'):
                continue
            for shape in result_dict['shapes']:
                # Shift the tile-local points into the coordinate frame of the full image.
                global_points = [[p[0] + x_origin, p[1] + y_origin] for p in shape['points']]
                all_detections.append({
                    "class_num": shape['class_num'],
                    "label": shape['label'],
                    "probability": shape['probability'],
                    "points": global_points
                })
        log_print("SUCCESS", f"All tiles predicted; collected {len(all_detections)} raw detections.")
        return all_detections

    def split_and_save_tiles(self, image_path: str, output_dir: str, mode: str = 'face'):
        """Split the image into tiles and save them to output_dir without running prediction."""
        log_print("GROUP START", f"Start splitting image: {image_path}")
        image = fry_cv2_imread(image_path)
        if image is None:
            log_print("ERROR", f"Failed to read image: {image_path}")
            return []

        # Start from a clean output directory.
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir, exist_ok=True)

        if mode == 'face':
            tiles = self._calculate_face_tiles(image.shape[:2])
        elif mode == 'edge':
            tiles = self._calculate_edge_tiles(image.shape[:2])
        else:
            raise ValueError(f"Unsupported mode: {mode}. Choose 'face' or 'edge'.")
        log_print("INFO", f"Computed {len(tiles)} tile positions.")

        base_name = os.path.splitext(os.path.basename(image_path))[0]
        saved_files = []
        for tile_info in tiles:
            x_origin, y_origin = tile_info['origin']
            tile_image = image[y_origin: y_origin + self.tile_size,
                               x_origin: x_origin + self.tile_size]
            if mode == 'face':
                filename = f"{base_name}_grid_r{tile_info['row']}_c{tile_info['col']}.jpg"
            else:
                filename = f"{base_name}_edge_{tile_info['index']}.jpg"
            output_path = os.path.join(output_dir, filename)
            if fry_cv2_imwrite(output_path, tile_image):
                saved_files.append(output_path)

        log_print("SUCCESS", f"Saved {len(saved_files)} tiles to directory: {output_dir}")
        log_print("GROUP END", "Image splitting finished.")
        return saved_files
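
    # Note on overlap handling (describes the method below): because adjacent tiles
    # overlap by `overlap_ratio`, a defect that straddles a tile boundary is usually
    # reported twice as two partially overlapping polygons. Instead of NMS, the merge
    # step rasterises all polygons of one class onto a single binary mask with
    # cv2.fillPoly and then re-extracts the external contours with cv2.findContours,
    # so overlapping or touching polygons collapse into one merged defect per
    # connected region.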

    def _merge_detections_by_mask(self, image_shape: Tuple[int, int], detections: List[Dict]) -> List[Dict]:
        """
        Merge overlapping detections with a per-class mask and return the resulting
        non-overlapping polygons.

        Args:
            image_shape: Size of the original large image (height, width).
            detections: All raw detections, before any overlap handling.

        Returns:
            A list of merged, non-overlapping defect polygons.
        """
        if not detections:
            return []
        log_print("INFO", "Merging overlapping defects with the mask method...")
        height, width = image_shape

        # 1. Group the detections by class label.
        detections_by_class = {}
        for det in detections:
            class_label = det['label']
            if class_label not in detections_by_class:
                detections_by_class[class_label] = []
            detections_by_class[class_label].append(det)

        final_merged_defects = []

        # 2. Merge each class independently on its own mask.
        for class_label, dets in detections_by_class.items():
            # Dedicated binary mask for this class.
            class_mask = np.zeros((height, width), dtype=np.uint8)

            # Rasterise every polygon of this class onto the mask.
            for det in dets:
                try:
                    contour = np.array(det['points'], dtype=np.int32)
                    cv2.fillPoly(class_mask, [contour], color=255)
                except Exception as e:
                    log_print("WARNING", f"Failed to draw a polygon for class '{class_label}': {e}")
                    continue

            # 3. Extract contours from the merged mask.
            # cv2.RETR_EXTERNAL keeps only the outermost contours.
            contours, _ = cv2.findContours(class_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            # Attributes inherited by every merged contour of this class: the class_num
            # of the highest-confidence detection and the mean probability over all
            # detections of the class. (A more precise variant would match each merged
            # contour to the original detections it overlaps.)
            max_prob = 0
            class_num = -1
            probabilities = []
            for det in dets:
                probabilities.append(det['probability'])
                if det['probability'] > max_prob:
                    max_prob = det['probability']
                    class_num = det['class_num']
            if not probabilities:
                continue
            mean_prob = float(np.mean(probabilities))

            # 4. Convert the extracted contours back to the JSON shape format.
            for contour in contours:
                # Skip tiny contours that are most likely noise (optional but recommended).
                if cv2.contourArea(contour) < 5:  # areas below 5 px are treated as noise
                    continue
                final_merged_defects.append({
                    "class_num": class_num,
                    "label": class_label,
                    "probability": mean_prob,  # mean confidence of the class
                    "points": contour.squeeze().tolist()  # contour points as [[x1, y1], [x2, y2], ...]
                })

        log_print("SUCCESS", f"Mask merge finished; {len(final_merged_defects)} independent defects remain.")
        return final_merged_defects

    def process_image(self, image: str | np.ndarray, output_json_path: str = None, mode: str = 'face'):
        """
        Full pipeline for a single large image: tiling, prediction, mask-based merging,
        and (optionally) saving the merged result.

        Args:
            image: Path to the large input image, or a BGR image array.
            output_json_path (str): Path of the merged JSON file to write (optional).
            mode (str): Processing mode, 'face' (whole image) or 'edge' (borders only).
        """
        source_desc = image if isinstance(image, str) else f"in-memory image, shape={getattr(image, 'shape', None)}"
        log_print("GROUP START", f"Start processing image: {source_desc}, mode: {mode}")
        if isinstance(image, str):
            image_path = image
            image = fry_cv2_imread(image_path)
            if image is None:
                log_print("ERROR", f"Failed to read image: {image_path}")
                return None
        if image is None:
            log_print("ERROR", "No image data was provided.")
            return None

        # 1. Compute tile positions.
        if mode == 'face':
            tiles = self._calculate_face_tiles(image.shape[:2])
        elif mode == 'edge':
            tiles = self._calculate_edge_tiles(image.shape[:2])
        else:
            raise ValueError(f"Unsupported mode: {mode}. Choose 'face' or 'edge'.")
        log_print("INFO", f"Step 1: computed {len(tiles)} tile positions.")

        # 2. Run prediction on every tile; the raw detections may overlap.
        all_detections = self._run_prediction_on_tiles(image, tiles)

        # 3. Merge overlapping detections with the per-class mask method.
        final_defects = self._merge_detections_by_mask(image.shape[:2], all_detections)

        # 4. Format and (optionally) save the final JSON.
        output_data = {
            "num": len(final_defects),
            "cls": [d['class_num'] for d in final_defects],
            "names": [d['label'] for d in final_defects],
            "conf": np.mean([d['probability'] for d in final_defects]) if final_defects else 0.0,
            "shapes": final_defects
        }
        if output_json_path is not None:
            out_dir = os.path.dirname(output_json_path)
            if out_dir:  # makedirs('') raises when the path has no directory part
                os.makedirs(out_dir, exist_ok=True)
            with open(output_json_path, 'w', encoding='utf-8') as f:
                json.dump(output_data, f, ensure_ascii=False, indent=2, default=to_json_serializable)
            log_print("SUCCESS", f"Final result saved to: {output_json_path}")
        log_print("GROUP END", "Processing finished.")
        return output_data
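

# --- Minimal usage sketch (not part of the original pipeline). The predictor is the
# illustrative _DummyPredictor stub defined above, and "card.jpg", "tiles_out", and
# "output/card_merged.json" are placeholder paths; swap in a real predictor such as a
# FryBisenetV2Predictor instance and real files before relying on this. ---
if __name__ == "__main__":
    aggregator = CardDefectAggregator(predictor=_DummyPredictor(),
                                      tile_size=512,
                                      overlap_ratio=0.1)

    # 1) Only split a large card image into tiles and save them (no prediction).
    tile_paths = aggregator.split_and_save_tiles("card.jpg", "tiles_out", mode="face")
    log_print("INFO", f"Tiles written: {len(tile_paths)}")

    # 2) Full pipeline: tile, predict, mask-merge, and write the merged JSON.
    result = aggregator.process_image("card.jpg", "output/card_merged.json", mode="face")
    if result is not None:
        log_print("INFO", f"Merged defect count: {result['num']}")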