
Split and merge (分割合并)

AnlaAnla 3 months ago
parent
commit
16b2c0a9f9
3 changed files with 447 additions and 8 deletions
  1. Test/model_test01.py (+26 -8)
  2. Test/切割合并.py (+133 -0)
  3. app/utils/CardDefectAggregator.py (+288 -0)

+ 26 - 8
Test/model_test01.py

@@ -1,5 +1,8 @@
 import os
 from pathlib import Path
+import json
+import cv2
+
 from app.utils.fry_bisenetv2_predictor_V04_250819 import FryBisenetV2Predictor
 
 BASE_PATH = Path(__file__).parent.parent.absolute()
@@ -7,7 +10,8 @@ BASE_PATH = Path(__file__).parent.parent.absolute()
 
 def predict_single_image(config_params: dict,
                          img_path: str,
-                         output_dir: str):
+                         output_dir: str,
+                         only_json=False):
     # Configuration parameters
     model_path = BASE_PATH / config_params["pth_path"]
     real_seg_class_dict = config_params['class_dict']
@@ -35,12 +39,18 @@ def predict_single_image(config_params: dict,
     now_img_path = img_path
     answer_json_dir_str = output_dir
 
-    result = predictor.predict_single_image(
-        img_path=now_img_path,
-        save_visualization=True,
-        save_json=True,
-        answer_json_dir_str=answer_json_dir_str
-    )
+    if not only_json:
+        result = predictor.predict_single_image(
+            img_path=now_img_path,
+            save_visualization=True,
+            save_json=True,
+            answer_json_dir_str=answer_json_dir_str
+        )
+    else:
+        img_bgr = cv2.imread(now_img_path)
+        result = predictor.predict_from_image(img_bgr)
+
+    return result
 
 
 if __name__ == '__main__':
@@ -104,8 +114,16 @@ if __name__ == '__main__':
     #                      output_dir=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\back_corner")
 
     # predict_single_image(config['pokemon_front_face_no_reflect_defect'],
-    #                      img_path=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\250805_pokemon_0001_grid_r5_c2.jpg",
+    #                      img_path=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\250805_pokemon_0001_grid_r4_c3.jpg",
     #                      output_dir=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\face_no_reflect")
 
+    result = predict_single_image(config['pokemon_front_corner_no_reflect_defect'],
+                                  img_path=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\00006_250805_pokemon_0001_bottom_grid_r0_c5.jpg",
+                                  output_dir=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\corner_no_reflect")
 
+    # result = predict_single_image(config['pokemon_front_face_no_reflect_defect'],
+    #                      img_path=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\250805_pokemon_0001_grid_r3_c4.jpg",
+    #                      output_dir=r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\face_no_reflect",
+    #                      only_json=True)
 
+    # print(result)
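
A minimal usage sketch for the new only_json path (not part of the commit): it assumes predict_from_image() returns the same 'shapes'-style dict (label, probability, points) that CardDefectAggregator consumes later in this commit, and the paths are placeholders.

    # Hypothetical usage of the only_json branch; paths are placeholders.
    result = predict_single_image(config['pokemon_front_face_no_reflect_defect'],
                                  img_path=r"C:\path\to\some_tile.jpg",     # placeholder path
                                  output_dir=r"C:\path\to\unused_output",   # not used when only_json=True
                                  only_json=True)
    # Inspect the raw prediction dict instead of relying on saved files.
    for shape in (result or {}).get('shapes', []):
        print(shape['label'], shape['probability'], len(shape['points']))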

+ 133 - 0
Test/切割合并.py

@@ -0,0 +1,133 @@
+import json
+import numpy as np
+import cv2
+from typing import Dict
+from app.utils.CardDefectAggregator import CardDefectAggregator
+from pathlib import Path
+
+# The predictor class is assumed to live here; for testing, a MockPredictor could be substituted.
+from app.utils.fry_bisenetv2_predictor_V04_250819 import FryBisenetV2Predictor
+
+BASE_PATH = Path(__file__).parent.parent.absolute()
+
+
+def get_predictor(config_params: dict):
+    # Configuration parameters
+    model_path = BASE_PATH / config_params["pth_path"]
+    real_seg_class_dict = config_params['class_dict']
+    imgSize_train_dict = config_params['img_size']
+    confidence = config_params['confidence']
+    input_channels = config_params['input_channels']
+
+    # Set different colors for different classes (optional)
+    label_colors_dict = {
+        'outer_box': (255, 0, 0),
+    }
+
+    # Create the predictor
+    predictor = FryBisenetV2Predictor(
+        pth_path=str(model_path),
+        real_seg_class_dict=real_seg_class_dict,
+        imgSize_train_dict=imgSize_train_dict,
+        confidence=confidence,
+        label_colors_dict=label_colors_dict,
+        input_channels=input_channels,
+    )
+    return predictor
+
+
+def _test_face_big_img():
+    large_image_path = r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\250805_pokemon_0001.jpg"
+
+    pokemon_front_face_no_reflect_defect = {
+        "pth_path": "Model/pokemon_front_face_no_reflect_defect.pth",
+        "class_dict": {"1": "scratch", "2": "pit", "3": "stain"},
+        "img_size": {'width': 512, 'height': 512},
+        "confidence": 0.5,
+        "input_channels": 3,
+    }
+    predictor = get_predictor(pokemon_front_face_no_reflect_defect)
+
+    # Instantiate the aggregator and pass in the predictor
+    aggregator = CardDefectAggregator(
+        predictor=predictor,
+        tile_size=512,
+        overlap_ratio=0.1,  # 10% overlap
+    )
+
+    # --- Run the tasks ---
+    # Task 1: run defect detection over the whole card face
+    aggregator.process_image(
+        image_path=large_image_path,
+        output_json_path="output/final_face_defects.json",
+        mode='face'
+    )
+
+    print("\n" + "=" * 50 + "\n")
+
+
+
+def _test_corner_big_img():
+    large_image_path = r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\250805_pokemon_0001.jpg"
+
+    pokemon_front_corner_no_reflect_defect = {
+        "pth_path": "Model/pokemon_front_corner_no_reflect_defect.pth",
+        "class_dict": {"1": "wear", "2": "wear_and_impact", "3": "impact", "4": "damaged"},
+        "img_size": {'width': 512, 'height': 512},
+        "confidence": 0.5,
+        "input_channels": 3,
+    }
+    predictor = get_predictor(pokemon_front_corner_no_reflect_defect)
+
+    # Instantiate the aggregator and pass in the predictor
+    aggregator = CardDefectAggregator(
+        predictor=predictor,
+        tile_size=512,
+        overlap_ratio=0.1,  # 10% overlap
+    )
+
+    # Task 2: run defect detection on the card edges only (with a different model)
+    # Assumes a dedicated corner/edge model is available
+    aggregator.process_image(
+        image_path=large_image_path,
+        output_json_path="output/final_edge_defects.json",
+        mode='edge'
+    )
+
+
+def _test_split_img(split_mode):
+    # split_mode: "edge" or "face"
+    # --- Input and output path configuration ---
+    large_image_path = r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\250805_pokemon_0001.jpg"
+    output_dir = r"C:\Code\ML\Project\CheckCardBoxAndDefectServer\temp\split_output"
+
+    # Print run info
+    print(f"Input image: {large_image_path}")
+    print(f"Output directory: {output_dir}")
+    print("-" * 50)
+
+    # --- Instantiate the aggregator ---
+    # Since we only split the image here, no predictor is needed, so predictor=None is fine.
+    aggregator = CardDefectAggregator(
+        predictor=None,  # no predictor needed
+        tile_size=512,
+        overlap_ratio=0.1,  # 10% overlap
+    )
+
+    # --- Call the new split-and-save feature ---
+    # mode='face' splits the whole image on a grid
+    saved_tile_paths = aggregator.split_and_save_tiles(
+        image_path=large_image_path,
+        output_dir=output_dir,
+        mode=split_mode
+    )
+
+    # Print some result info
+    if saved_tile_paths:
+        print(f"\nFirst tile saved: {saved_tile_paths[0]}")
+        print(f"Last tile saved: {saved_tile_paths[-1]}")
+
+if __name__ == "__main__":
+    # _test_face_big_img()
+    # _test_corner_big_img()
+    _test_split_img(split_mode='edge')
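
For sanity-checking the merged result, a standalone sketch (paths are placeholders) that redraws the polygons from the JSON written by CardDefectAggregator.process_image onto the source image; it relies only on the 'shapes', 'points', and 'label' keys defined in this commit.

    import json
    import cv2
    import numpy as np

    def preview_merged_defects(image_path: str, json_path: str, out_path: str):
        # Overlay the merged polygons from process_image() onto the large image.
        image = cv2.imread(image_path)
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for shape in data.get('shapes', []):
            pts = np.array(shape['points'], dtype=np.int32).reshape(-1, 1, 2)
            cv2.polylines(image, [pts], isClosed=True, color=(0, 0, 255), thickness=2)
            x, y = pts[0, 0]
            cv2.putText(image, str(shape['label']), (int(x), int(y) - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
        cv2.imwrite(out_path, image)

    # preview_merged_defects(large_image_path, "output/final_face_defects.json", "output/preview.jpg")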

+ 288 - 0
app/utils/CardDefectAggregator.py

@@ -0,0 +1,288 @@
+import os
+import cv2
+import json
+import numpy as np
+from typing import Dict, List, Tuple, Any
+from shapely.geometry import Polygon, MultiPolygon
+import shutil
+
+
+def log_print(level_str: str, info_str: str):
+    print(f"[{level_str}] : {info_str}")
+
+
+def fry_cv2_imread(filename, flags=cv2.IMREAD_COLOR):
+    try:
+        with open(filename, 'rb') as f:
+            chunk = f.read()
+        chunk_arr = np.frombuffer(chunk, dtype=np.uint8)
+        img = cv2.imdecode(chunk_arr, flags)
+        return img
+    except IOError:
+        return None
+
+
+def fry_cv2_imwrite(filename, img, params=None):
+    try:
+        ext = os.path.splitext(filename)[1].lower()
+        if not ext: ext = ".jpg"  # default format
+        result, encoded_img = cv2.imencode(ext, img, params)
+        if result:
+            with open(filename, 'wb') as f:
+                encoded_img.tofile(f)
+            return True
+        else:
+            return False
+    except Exception as e:
+        log_print("ERROR", f"Failed to save image: {filename}, error: {e}")
+        return False
+
+
+def to_json_serializable(obj):
+    if isinstance(obj, (np.ndarray,)):
+        return obj.tolist()
+    if isinstance(obj, (np.integer,)):
+        return int(obj)
+    if isinstance(obj, (np.floating,)):
+        return float(obj)
+    if isinstance(obj, (Polygon, MultiPolygon)):
+        return str(obj)  # Or any other serializable representation
+    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+
+class CardDefectAggregator:
+    """
+    Handles tiling a large card image, running predictions on the tiles, and merging the results; it can also just split the image and save the tiles.
+    """
+
+    def __init__(self,
+                 predictor: Any,  # 传入一个类似 FryBisenetV2Predictor 的实例
+                 tile_size: int = 512,
+                 overlap_ratio: float = 0.1):
+        """
+        Initialize the aggregator.
+
+        Args:
+            predictor: your model predictor instance; it must provide a predict_from_image(img_bgr) method.
+            tile_size: size of each square tile fed to the model.
+            overlap_ratio: overlap ratio of the sliding-window tiling (e.g. 0.1 means 10%).
+        """
+        self.predictor = predictor
+        self.tile_size = tile_size
+        self.overlap_ratio = overlap_ratio
+        self.stride = int(self.tile_size * (1 - self.overlap_ratio))
+
+    # The _calculate_face_tiles and _calculate_edge_tiles methods are unchanged
+    def _calculate_face_tiles(self, image_shape: Tuple[int, int]) -> List[Dict]:
+        height, width = image_shape
+        tiles = []
+        y_steps = range(0, height, self.stride)
+        x_steps = range(0, width, self.stride)
+        for r, y in enumerate(y_steps):
+            for c, x in enumerate(x_steps):
+                y_start = min(y, height - self.tile_size) if y + self.tile_size > height else y
+                x_start = min(x, width - self.tile_size) if x + self.tile_size > width else x
+                y_start, x_start = max(0, y_start), max(0, x_start)
+                tile_info = {'origin': (x_start, y_start), 'row': r, 'col': c}
+                if tile_info not in tiles:
+                    tiles.append(tile_info)
+        return tiles
+
+    def _calculate_edge_tiles(self, image_shape: Tuple[int, int]) -> List[Dict]:
+        height, width = image_shape
+        tiles = []
+
+        def add_tile(x, y, index):
+            x_start = max(0, min(x, width - self.tile_size))
+            y_start = max(0, min(y, height - self.tile_size))
+            tile_info = {'origin': (x_start, y_start), 'index': index}
+            # Only add a tile if no tile with the same origin has been added yet
+            if all(t['origin'] != tile_info['origin'] for t in tiles):
+                tiles.append(tile_info)
+
+        idx = 0
+        for x in range(0, width, self.stride):  # top edge
+            add_tile(x, 0, idx); idx += 1
+        for x in range(0, width, self.stride):  # bottom edge
+            add_tile(x, height - self.tile_size, idx); idx += 1
+        for y in range(self.stride, height - self.stride, self.stride):  # left edge
+            add_tile(0, y, idx); idx += 1
+        for y in range(self.stride, height - self.stride, self.stride):  # right edge
+            add_tile(width - self.tile_size, y, idx); idx += 1
+        return tiles
+
+    # The _run_prediction_on_tiles method is unchanged
+    def _run_prediction_on_tiles(self, image: np.ndarray, tiles: List[Dict]) -> List[Dict]:
+        all_detections = []
+        for i, tile_info in enumerate(tiles):
+            x_origin, y_origin = tile_info['origin']
+            tile_image = image[y_origin: y_origin + self.tile_size, x_origin: x_origin + self.tile_size]
+            log_print("INFO", f"Predicting tile {i + 1}/{len(tiles)} at ({x_origin}, {y_origin})...")
+            result_dict = self.predictor.predict_from_image(tile_image)
+            if not result_dict or not result_dict.get('shapes'):
+                continue
+            for shape in result_dict['shapes']:
+                global_points = [[p[0] + x_origin, p[1] + y_origin] for p in shape['points']]
+                all_detections.append({
+                    "class_num": shape['class_num'],
+                    "label": shape['label'],
+                    "probability": shape['probability'],
+                    "points": global_points
+                })
+        log_print("SUCCESS", f"All tiles predicted; collected {len(all_detections)} preliminary detections.")
+        return all_detections
+
+    # The split_and_save_tiles method is unchanged
+    def split_and_save_tiles(self, image_path: str, output_dir: str, mode: str = 'face'):
+        log_print("GROUP START", f"Splitting image: {image_path}")
+        image = fry_cv2_imread(image_path)
+        if image is None:
+            log_print("ERROR", f"Failed to read image: {image_path}")
+            return []
+        if os.path.exists(output_dir):
+            shutil.rmtree(output_dir)
+        os.makedirs(output_dir, exist_ok=True)
+        if mode == 'face':
+            tiles = self._calculate_face_tiles(image.shape[:2])
+        elif mode == 'edge':
+            tiles = self._calculate_edge_tiles(image.shape[:2])
+        else:
+            raise ValueError(f"Unsupported mode: {mode}. Choose 'face' or 'edge'.")
+        log_print("INFO", f"Computed {len(tiles)} tile positions.")
+        base_name = os.path.splitext(os.path.basename(image_path))[0]
+        saved_files = []
+        for tile_info in tiles:
+            x_origin, y_origin = tile_info['origin']
+            tile_image = image[y_origin: y_origin + self.tile_size, x_origin: x_origin + self.tile_size]
+            if mode == 'face':
+                filename = f"{base_name}_grid_r{tile_info['row']}_c{tile_info['col']}.jpg"
+            else:
+                filename = f"{base_name}_edge_{tile_info['index']}.jpg"
+            output_path = os.path.join(output_dir, filename)
+            if fry_cv2_imwrite(output_path, tile_image):
+                saved_files.append(output_path)
+        log_print("SUCCESS", f"Saved {len(saved_files)} tiles to directory: {output_dir}")
+        log_print("GROUP END", "Image splitting finished.")
+        return saved_files
+
+    def _merge_detections_by_mask(self, image_shape: Tuple[int, int], detections: List[Dict]) -> List[Dict]:
+        """
+        Merge overlapping detections using a mask-based approach and return the merged, non-overlapping polygons.
+
+        Args:
+            image_shape: size of the original large image as (height, width).
+            detections: list of all raw detections, before any NMS.
+
+        Returns:
+            A list of merged, non-overlapping defect polygons.
+        """
+        if not detections:
+            return []
+
+        log_print("INFO", "Merging overlapping defects with the mask-based approach...")
+
+        height, width = image_shape
+
+        # 1. Group detections by class
+        detections_by_class = {}
+        for det in detections:
+            class_label = det['label']
+            if class_label not in detections_by_class:
+                detections_by_class[class_label] = []
+            detections_by_class[class_label].append(det)
+
+        final_merged_defects = []
+
+        # 2. Merge each class independently on its own mask
+        for class_label, dets in detections_by_class.items():
+            # Create a dedicated mask for this class
+            class_mask = np.zeros((height, width), dtype=np.uint8)
+
+            # Draw all polygons of this class onto the mask
+            for det in dets:
+                try:
+                    contour = np.array(det['points'], dtype=np.int32)
+                    cv2.fillPoly(class_mask, [contour], color=255)
+                except Exception as e:
+                    log_print("WARNING", f"Failed to draw polygon for class '{class_label}': {e}")
+                    continue
+
+            # 3. Extract contours from the merged mask
+            # cv2.RETR_EXTERNAL keeps only the outermost contours
+            contours, _ = cv2.findContours(class_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+            # 4. Convert the extracted contours back to JSON shapes
+            for contour in contours:
+                # Filter out tiny noise contours (optional, but recommended)
+                if cv2.contourArea(contour) < 5:  # contours under 5 px in area are likely noise
+                    continue
+
+                # Ideally, inherit the attributes of the original detection that overlaps this merged
+                # contour the most; as a simpler option, take the class-wide maximum / average instead.
+                max_prob = 0
+                avg_prob = []
+                class_num = -1
+
+                for det in dets:
+                    avg_prob.append(det['probability'])
+                    if det['probability'] > max_prob:
+                        max_prob = det['probability']
+                        class_num = det['class_num']
+
+                if not avg_prob: continue
+
+                final_merged_defects.append({
+                    "class_num": class_num,
+                    "label": class_label,
+                    "probability": np.mean(avg_prob),  # use the average confidence
+                    "points": contour.squeeze().tolist()  # contour points as [[x1, y1], [x2, y2], ...]
+                })
+
+        log_print("SUCCESS", f"Mask merging finished; {len(final_merged_defects)} independent defects remain.")
+        return final_merged_defects
+
+    # =========================================================
+    # ==================== Updated main pipeline ====================
+    # =========================================================
+    def process_image(self, image_path: str, output_json_path: str, mode: str = 'face'):
+        """
+        Full pipeline for a single large image: tiling, prediction, mask-based merging, and saving the result.
+
+        Args:
+            image_path (str): path to the input large image.
+            output_json_path (str): path of the merged output JSON file.
+            mode (str): processing mode, 'face' (whole image) or 'edge' (edges only).
+        """
+        log_print("GROUP START", f"Processing image: {image_path}, mode: {mode}")
+
+        image = fry_cv2_imread(image_path)
+        if image is None:
+            log_print("ERROR", f"Failed to read image: {image_path}")
+            return
+
+        # 1. Compute tile positions
+        if mode == 'face':
+            tiles = self._calculate_face_tiles(image.shape[:2])
+        elif mode == 'edge':
+            tiles = self._calculate_edge_tiles(image.shape[:2])
+        else:
+            raise ValueError(f"Unsupported mode: {mode}. Choose 'face' or 'edge'.")
+        log_print("INFO", f"Step 1: computed {len(tiles)} tile positions.")
+
+        # 2. Run prediction on all tiles to get preliminary, possibly overlapping detections
+        all_detections = self._run_prediction_on_tiles(image, tiles)
+
+        # 3. [Key step] Merge overlapping detections using the mask-based approach
+        final_defects = self._merge_detections_by_mask(image.shape[:2], all_detections)
+
+        # 4. Format and save the final JSON
+        output_data = {
+            "num": len(final_defects),
+            "cls": [d['class_num'] for d in final_defects],
+            "names": [d['label'] for d in final_defects],
+            "conf": np.mean([d['probability'] for d in final_defects]) if final_defects else 0.0,
+            "shapes": final_defects
+        }
+
+        os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
+        with open(output_json_path, 'w', encoding='utf-8') as f:
+            json.dump(output_data, f, ensure_ascii=False, indent=2, default=to_json_serializable)
+
+        log_print("SUCCESS", f"Final result saved to: {output_json_path}")
+        log_print("GROUP END", "Processing finished.")
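
To illustrate the per-class mask merge used in _merge_detections_by_mask, a tiny standalone sketch (not part of the module): two overlapping polygons of the same class, as neighbouring tiles might report, are rasterized onto one mask and come back out of cv2.findContours as a single outer contour.

    import cv2
    import numpy as np

    # Two overlapping rectangles of the same class, e.g. reported by neighbouring tiles.
    poly_a = np.array([[10, 10], [60, 10], [60, 60], [10, 60]], dtype=np.int32)
    poly_b = np.array([[40, 40], [90, 40], [90, 90], [40, 90]], dtype=np.int32)

    mask = np.zeros((120, 120), dtype=np.uint8)
    cv2.fillPoly(mask, [poly_a, poly_b], color=255)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    print(len(contours))                              # 1: the overlap collapses into one outer contour
    merged_points = contours[0].squeeze().tolist()    # same [[x, y], ...] format written to the JSON

Compared with polygon NMS, this rasterization handles partial overlaps across tile borders directly; the trade-off is that the merged contour's probability is aggregated over the whole class (the module uses the class-wide average) rather than matched per detection.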