소스 검색

新算法以及对应的新接口

袁威 1 주 전
부모
커밋
7a4db756a8

+ 76 - 0
Model/stitch_fusion/model_meta.json

@@ -0,0 +1,76 @@
+{
+  "modals": [
+    "img",
+    "ch1",
+    "ch2",
+    "ch3",
+    "ch4",
+    "ring"
+  ],
+  "classes": [
+    "slight_scratch",
+    "scratch",
+    "serious_scratch",
+    "damaged",
+    "impact",
+    "pit",
+    "stain",
+    "wear"
+  ],
+  "palette": [
+    [
+      128,
+      0,
+      0
+    ],
+    [
+      0,
+      128,
+      0
+    ],
+    [
+      128,
+      128,
+      0
+    ],
+    [
+      0,
+      0,
+      128
+    ],
+    [
+      128,
+      0,
+      128
+    ],
+    [
+      0,
+      128,
+      128
+    ],
+    [
+      255,
+      128,
+      0
+    ],
+    [
+      64,
+      0,
+      0
+    ]
+  ],
+  "image_size": [
+    768,
+    768
+  ],
+  "thresholds": [
+    0.04,
+    0.96,
+    0.02,
+    0.505,
+    0.01,
+    0.051,
+    0.323,
+    0.343
+  ]
+}

BIN
Model/stitch_fusion/model_traced.pt


+ 67 - 1
app/api/score_inference.py

@@ -5,6 +5,7 @@ from enum import Enum
 from typing import Optional, Dict, Any
 from ..core.config import settings
 from app.services.score_service import ScoreService
+from app.services.stitch_fusion_service import StitchFusionService
 import numpy as np
 import cv2
 import json
@@ -17,6 +18,9 @@ router = APIRouter()
 score_names = settings.SCORE_TYPE
 ScoreType = Enum("InferenceType", {name: name for name in score_names})
 
+stitch_score_names = settings.STITCH_SCORE_TYPE
+StitchScoreType = Enum("StitchScoreType", {name: name for name in stitch_score_names})
+
 
 @router.post("/score_inference", summary="输入卡片类型(正反面, 缺陷类型), 是否为反射卡")
 async def card_model_inference(
@@ -73,4 +77,66 @@ async def score_recalculate(score_type: ScoreType, json_data: Dict[str, Any]):
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
+        raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
+
+
+@router.post("/stitch_score_inference",
+             summary="StitchFusion 多模态拼接缺陷推理 (一次提交6张同侧图)",
+             description="""
+一次提交一组 6 张同侧图片, 经 StitchFusion 单 PT 模型大图切片推理后, 返回 labelme 风格 mask 列表。
+
+字段对应 (按正反面分别上传, 一次只提交一侧):
+- ring     -> 环光图   (front_ring / back_ring)
+- gray     -> 灰度图   (front_gray / back_gray)
+- stripe1  -> 调光1    (front_stripe1 / back_stripe1)
+- stripe2  -> 调光2    (front_stripe2 / back_stripe2)
+- stripe3  -> 调光3    (front_stripe3 / back_stripe3)
+- stripe4  -> 调光4    (front_stripe4 / back_stripe4)
+""")
+async def stitch_score_inference(
+        score_type: StitchScoreType,
+        card_name: str = "",
+        ring: UploadFile = File(..., description="环光图"),
+        gray: UploadFile = File(..., description="灰度图"),
+        stripe1: UploadFile = File(..., description="调光1"),
+        stripe2: UploadFile = File(..., description="调光2"),
+        stripe3: UploadFile = File(..., description="调光3"),
+        stripe4: UploadFile = File(..., description="调光4"),
+):
+    """接收同一类型(正面或反面) 6 张图, 输出该组的缺陷 JSON。"""
+    variant_files = {
+        "ring": ring,
+        "gray": gray,
+        "stripe1": stripe1,
+        "stripe2": stripe2,
+        "stripe3": stripe3,
+        "stripe4": stripe4,
+    }
+    variant_bytes: Dict[str, bytes] = {}
+    for variant, upload in variant_files.items():
+        try:
+            variant_bytes[variant] = await upload.read()
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"读取上传文件 {variant} 失败: {e}")
+        if not variant_bytes[variant]:
+            raise HTTPException(status_code=400, detail=f"上传文件 {variant} 内容为空")
+
+    service = StitchFusionService()
+    try:
+        json_result = await run_in_threadpool(
+            service.stitch_score_inference,
+            score_type=score_type.value,
+            card_name=card_name,
+            variant_bytes=variant_bytes,
+        )
+        return json_result
+    except FileNotFoundError as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+    except Exception as e:
+        logger.exception("stitch_score_inference 失败")
+        raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
+
+
+

+ 10 - 0
app/core/config.py

@@ -152,6 +152,16 @@ class Settings:
     SCORE_TYPE: List[str] = ["front_coaxial", "front_ring",
                              "back_coaxial", "back_ring"]
 
+    # ===================== StitchFusion 多模态拼接缺陷推理 =====================
+    # 单一 PT 模型 (TorchScript) + meta 配置, 一次推理 6 张同侧图。
+    STITCH_FUSION_MODEL_PT: Path = BASE_PATH / "Model/stitch_fusion/model_traced.pt"
+    STITCH_FUSION_MODEL_META: Path = BASE_PATH / "Model/stitch_fusion/model_meta.json"
+    STITCH_FUSION_DEVICE: str = "cuda:0"
+    STITCH_FUSION_THRESHOLD: float = 0.5
+
+    # 仅区分正反面, 用于命名/落盘和接口下拉选项
+    STITCH_SCORE_TYPE: List[str] = ["front", "back"]
+
 
 settings = Settings()
 print(f"项目根目录: {settings.BASE_PATH}")

+ 231 - 0
app/services/stitch_fusion_service.py

@@ -0,0 +1,231 @@
+"""
+StitchFusion 多模态拼接缺陷推理服务。
+- 懒加载 + 单例: 模型只加载一次, 加载失败给出明确错误。
+- 输入: 6 张同侧图片 (ring/gray/stripe1-4), 解码成 {modal: [3,H,W] uint8 tensor}。
+- 输出: 与 score_inference 接口一致的 result 结构 (center_result/defect_result/card_score 等)。
+  StitchFusion 模型只检测面缺陷, 因此走同轴光 (coaxial) 评分流程, center_result 留空。
+"""
+from pathlib import Path
+from threading import Lock
+from typing import Dict, Optional, Tuple
+
+import cv2
+import numpy as np
+import torch
+
+from app.core.config import settings
+from app.core.logger import get_logger
+from app.utils.defect_inference.arean_anylize_draw import DefectProcessor
+from app.utils.json_data_formate import formate_add_edit_type, formate_face_data
+from app.utils.score_inference.CardScorer import CardScorer
+from app.utils.stitch_fusion.tiled_infer import TiledInfer
+
+logger = get_logger(__name__)
+
+# Upload field name (ring/gray/stripe1-4) -> model modal name.
+# Listed in the same order as "modals" in model_meta.json.
+VARIANT_TO_MODAL: Dict[str, str] = {
+    "gray":    "img",
+    "stripe1": "ch1",
+    "stripe2": "ch2",
+    "stripe3": "ch3",
+    "stripe4": "ch4",
+    "ring":    "ring",
+}
+
+REQUIRED_VARIANTS = list(VARIANT_TO_MODAL.keys())
+
+# StitchFusion class -> (face-defect label used for CardScorer, severity_level).
+# Classes not listed here keep their own label; severity_level defaults to "一般".
+CLASS_TO_LABEL_SEVERITY: Dict[str, Tuple[str, str]] = {
+    "serious_scratch": ("scratch", "严重"),
+    "scratch":         ("scratch", "一般"),
+    "slight_scratch":  ("scratch", "轻微"),
+}
+
+# Face-defect labels treated as known to CardScorer. Anything outside this set
+# falls back to "wear" (see _map_class) to avoid scoring errors.
+KNOWN_FACE_LABELS = {
+    "wear", "wear_and_impact", "wear_and_stain", "damaged",
+    "scratch", "scuff",
+    "pit", "impact", "protrudent",
+    "stain",
+}
+
+
+def _decode_image_to_tensor(image_bytes: bytes, variant: str) -> torch.Tensor:
+    """字节流 -> [3,H,W] uint8 tensor (RGB)."""
+    np_arr = np.frombuffer(image_bytes, np.uint8)
+    img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
+    if img is None:
+        raise ValueError(f"无法解码图像: {variant}, 请确认是有效图片格式 (JPG/PNG 等)")
+
+    if img.ndim == 2:
+        img_rgb = np.stack([img, img, img], axis=-1)
+    elif img.shape[2] == 4:
+        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
+    elif img.shape[2] == 3:
+        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    else:
+        raise ValueError(f"图像通道数不支持: {variant}, shape={img.shape}")
+
+    tensor = torch.from_numpy(np.ascontiguousarray(img_rgb)).permute(2, 0, 1).contiguous()
+    return tensor
+
+
+def _map_class(stitch_class: str) -> Tuple[str, str]:
+    """StitchFusion 类别名 -> (CardScorer label, severity_level)."""
+    if stitch_class in CLASS_TO_LABEL_SEVERITY:
+        return CLASS_TO_LABEL_SEVERITY[stitch_class]
+    if stitch_class in KNOWN_FACE_LABELS:
+        return stitch_class, "一般"
+    logger.warning(f"[StitchFusion] 未知类别 '{stitch_class}', 评分时回退为 'wear'")
+    return "wear", "一般"
+
+
+def _build_shapes_data(raw_defects: list) -> Tuple[Dict, list]:
+    """把 TiledInfer.defects 转成 DefectProcessor 可吃的 shapes 数据。
+
+    返回:
+        shapes_data: {"shapes": [{label, points, confidence}, ...]}
+        meta_per_shape: 与 shapes 顺序一一对应, 包含原始 class_name 与 severity_level
+    """
+    shapes = []
+    meta = []
+    for d in raw_defects:
+        for poly in d.get("polygons", []):
+            if len(poly) < 3:
+                continue
+            label, severity = _map_class(d["class_name"])
+            shapes.append({
+                "label": label,
+                "points": poly,
+                "confidence": d.get("prob"),
+            })
+            meta.append({
+                "stitch_class": d["class_name"],
+                "severity_level": severity,
+            })
+    return {"shapes": shapes}, meta
+
+
+class StitchFusionService:
+    """单例 + 懒加载, 避免每次请求都加载模型。"""
+
+    _instance: Optional["StitchFusionService"] = None
+    _lock = Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                cls._instance._inferrer = None
+                cls._instance._init_lock = Lock()
+                cls._instance._scorer = None
+        return cls._instance
+
+    def _ensure_inferrer(self) -> TiledInfer:
+        if self._inferrer is not None:
+            return self._inferrer
+        with self._init_lock:
+            if self._inferrer is not None:
+                return self._inferrer
+            model_pt = str(settings.STITCH_FUSION_MODEL_PT)
+            model_meta = str(settings.STITCH_FUSION_MODEL_META)
+
+            for path_str, name in [(model_pt, "model_pt"), (model_meta, "model_meta")]:
+                if not Path(path_str).exists():
+                    raise FileNotFoundError(
+                        f"StitchFusion {name} 文件不存在: {path_str},请检查 settings.STITCH_FUSION_*"
+                    )
+
+            self._inferrer = TiledInfer(
+                model_pt=model_pt,
+                model_meta=model_meta,
+                device=settings.STITCH_FUSION_DEVICE,
+                threshold=settings.STITCH_FUSION_THRESHOLD,
+            )
+        return self._inferrer
+
+    def _ensure_scorer(self) -> CardScorer:
+        if self._scorer is None:
+            self._scorer = CardScorer(config_path=settings.SCORE_CONFIG_PATH)
+        return self._scorer
+
+    def stitch_score_inference(self, score_type: str, card_name: str,
+                               variant_bytes: Dict[str, bytes]) -> dict:
+        """
+        score_type: front 或 back, 决定 card_aspect, 同时仅用于命名。
+        variant_bytes: {variant_name: image_bytes}, 必须含 REQUIRED_VARIANTS 全部 6 张。
+
+        返回结构对齐 score_inference (同轴光分支), defect_result.defects 字段含
+        label/confidence/pixel_area/actual_area/width/height/points/min_rect/
+        defect_type/edit_type/severity_level/scratch_length/score/new_score。
+        """
+        missing = [v for v in REQUIRED_VARIANTS if v not in variant_bytes]
+        if missing:
+            raise ValueError(f"缺少图片字段: {missing}, 必须提供 6 张: {REQUIRED_VARIANTS}")
+
+        if score_type not in ("front", "back"):
+            raise ValueError(f"score_type 仅支持 front/back, 收到: {score_type}")
+
+        inferrer = self._ensure_inferrer()
+        scorer = self._ensure_scorer()
+
+        modal_imgs: Dict[str, torch.Tensor] = {}
+        for variant in REQUIRED_VARIANTS:
+            modal = VARIANT_TO_MODAL[variant]
+            modal_imgs[modal] = _decode_image_to_tensor(variant_bytes[variant], variant)
+
+        logger.info(f"[StitchFusion] {score_type}/{card_name} 开始推理")
+        canvas, raw_defects = inferrer.infer(modal_imgs)
+        h, w = int(canvas.shape[1]), int(canvas.shape[2])
+        stem = card_name or f"{score_type}_card"
+        logger.info(f"[StitchFusion] {stem} 推理结束, 原始连通域={len(raw_defects)}")
+
+        # 1) 把 StitchFusion 多边形转成 score_inference 风格 defect_result
+        shapes_data, meta_per_shape = _build_shapes_data(raw_defects)
+        processor = DefectProcessor(pixel_resolution=settings.PIXEL_RESOLUTION)
+        analysis_json = processor.analyze_from_json(shapes_data)
+
+        # 把每个 defect 补齐 score_inference 所需字段 (defect_type/edit_type/severity_level)
+        defects_full = []
+        for defect, meta in zip(analysis_json["defects"], meta_per_shape):
+            defect["severity_level"] = meta["severity_level"]
+            defect["stitch_class"] = meta["stitch_class"]
+            defects_full.append(defect)
+
+        defect_data = {
+            "defects": defects_full,
+            "statistics": analysis_json["statistics"],
+        }
+        defect_data = formate_face_data(defect_data)
+        defect_data = formate_add_edit_type(defect_data)
+
+        # 2) 同轴光评分流程: 仅算 face, 跳过 center/corner/edge
+        card_aspect = score_type
+        card_light_type = "coaxial"
+        try:
+            defect_data = scorer.calculate_defect_score(
+                "face", card_aspect, card_light_type, defect_data, True
+            )
+        except Exception as e:
+            logger.warning(f"[StitchFusion] face 评分失败, 跳过评分: {e}")
+            defect_data[f"{card_aspect}_face_deduct_score"] = 0.0
+            for d in defect_data["defects"]:
+                d.setdefault("score", None)
+                d.setdefault("new_score", None)
+
+        # 3) 套用 score_inference 同款外层结构 (center_result 用 {} 占位)
+        result_json = scorer.formate_one_card_result(
+            center_result={},
+            defect_result=defect_data,
+            card_light_type=card_light_type,
+            card_aspect=card_aspect,
+            imageHeight=h,
+            imageWidth=w,
+        )
+
+        result_json["result"]["image"] = stem
+        result_json["result"]["image_size"] = [h, w]
+        result_json["result"]["card_aspect"] = card_aspect
+        result_json["result"]["card_light_type"] = card_light_type
+        return result_json

+ 0 - 0
app/utils/stitch_fusion/__init__.py


+ 284 - 0
app/utils/stitch_fusion/tiled_infer.py

@@ -0,0 +1,284 @@
+"""
+StitchFusion PT 模型大图切片推理。
+流程: 大图 -> 切 TILE_ROWS x TILE_COLS 块(带 overlap) -> 每块推理 -> Voronoi 归属合并 ->
+       连通域提取 -> labelme 风格 JSON。
+
+仅依赖 model_traced.pt + model_meta.json,不依赖项目其他工具模块。
+"""
+import json
+import math
+from typing import Dict, List, Tuple
+
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision import transforms as T
+
+from app.core.logger import get_logger
+
+logger = get_logger(__name__)
+
+# ---- Tiling parameters ----
+TILE_ROWS = 6
+TILE_COLS = 4
+OVERLAP = 0.2
+
+# ---- Filtering parameters ----
+MIN_DEFECT_AREA = 64       # minimum connected-component area, in full-image coordinates
+EDGE_IGNORE_AREA = 128     # fragments touching a tile edge and smaller than this are ignored
+EDGE_MARGIN = 4            # width (pixels) of the band that counts as "touching the edge"
+MIN_HOLE_AREA = 16         # keep holes with area >= this, so regions enclosed by crossing thin scratches are not filled in
+
+# The three scratch severities, highest priority first. When the same scratch
+# is detected at several severities, only the highest-priority one is kept.
+SCRATCH_PRIORITY = ['serious_scratch', 'scratch', 'slight_scratch']
+SCRATCH_DEDUP_OVERLAP = 0.8
+
+
+def _stitch_holes(outer: np.ndarray, holes: list) -> np.ndarray:
+    """把孔洞用零宽双向桥并入外轮廓,得到 even-odd 填充时仍保留孔洞的单环。"""
+    ring = outer
+    for hole in holes:
+        hc = hole.mean(axis=0)
+        pi = int(np.argmin(((ring - hc) ** 2).sum(1)))
+        qi = int(np.argmin(((hole - ring[pi]) ** 2).sum(1)))
+        hole_loop = np.concatenate([hole[qi:], hole[:qi], hole[qi:qi + 1]], axis=0)
+        ring = np.concatenate(
+            [ring[:pi + 1], hole_loop, ring[pi:pi + 1], ring[pi + 1:]], axis=0
+        )
+    return ring
+
+
+def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
+    """单连通域二值 ROI -> labelme 多边形列表(保留孔洞), 坐标偏移到整图。"""
+    contours, hierarchy = cv2.findContours(roi, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
+    if hierarchy is None:
+        return []
+    hierarchy = hierarchy[0]
+    polygons = []
+    for ci, hi in enumerate(hierarchy):
+        if hi[3] != -1:
+            continue
+        outer = contours[ci].reshape(-1, 2)
+        if outer.shape[0] < 3:
+            continue
+        holes = []
+        child = hi[2]
+        while child != -1:
+            hole = contours[child].reshape(-1, 2)
+            if hole.shape[0] >= 3 and cv2.contourArea(contours[child]) >= MIN_HOLE_AREA:
+                holes.append(hole)
+            child = hierarchy[child][0]
+        ring = _stitch_holes(outer, holes) if holes else outer
+        polygons.append([[int(px + ox), int(py + oy)] for px, py in ring])
+    return polygons
+
+
+def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
+    """rows x cols 个切片, 相邻块保持 overlap 比例重叠。"""
+    def _axis(length, count):
+        if count == 1:
+            return [slice(0, length)]
+        patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
+        stride = max(1, int(round(patch * (1.0 - overlap))))
+        slices = []
+        for i in range(count):
+            start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
+            stop = length if i == count - 1 else min(length, start + patch)
+            slices.append(slice(start, stop))
+        return slices
+
+    return [
+        (f"r{r}c{c}", ys, xs)
+        for r, ys in enumerate(_axis(h, rows))
+        for c, xs in enumerate(_axis(w, cols))
+    ]
+
+
+class TiledInfer:
+    """大图切片 PT 模型推理器。"""
+
+    def __init__(self, model_pt: str, model_meta: str,
+                 device: str = 'cuda:0', threshold: float = None):
+        with open(model_meta) as f:
+            meta = json.load(f)
+        self.modals = meta['modals']
+        self.classes = meta['classes']
+        self.palette = meta['palette']
+        self.image_size = meta['image_size']
+
+        thr = meta['thresholds']
+        if threshold is not None:
+            thr = [float(threshold)] * len(self.classes)
+
+        if device.startswith('cuda') and not torch.cuda.is_available():
+            logger.warning(f"{device} 不可用, 回退到 cpu")
+            device = 'cpu'
+        self.device = torch.device(device)
+
+        self.model = torch.jit.load(model_pt, map_location=self.device)
+        self.model.eval()
+        logger.info(f"[StitchFusion] 模型已加载: {model_pt}, device={self.device}")
+
+        self.thresholds = torch.tensor(thr, device=self.device).view(-1, 1, 1)
+
+        self.tf_main = T.Compose([
+            T.Lambda(lambda x: x / 255),
+            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+            T.Lambda(lambda x: x.unsqueeze(0)),
+        ])
+        self.tf_aux = T.Compose([
+            T.Lambda(lambda x: x / 255),
+            T.Lambda(lambda x: x.unsqueeze(0)),
+        ])
+
+    def _preprocess(self, img: torch.Tensor, is_main: bool) -> torch.Tensor:
+        H, W = img.shape[1:]
+        scale = self.image_size[0] / min(H, W)
+        nH = int(math.ceil(round(H * scale) / 32)) * 32
+        nW = int(math.ceil(round(W * scale) / 32)) * 32
+        img = T.Resize((nH, nW))(img)
+        tf = self.tf_main if is_main else self.tf_aux
+        return tf(img.float()).to(self.device)
+
+    @torch.inference_mode()
+    def _forward_tile(self, tile_modals: dict) -> Tuple[np.ndarray, np.ndarray]:
+        """{modal: [3,h,w] uint8} -> (mask [C,h,w] bool, prob [C,h,w] float32)"""
+        h, w = next(iter(tile_modals.values())).shape[1:]
+        inputs = [self._preprocess(tile_modals[m], m == 'img') for m in self.modals]
+        logits = self.model(inputs)
+        probs = logits.sigmoid().squeeze(0)
+        probs = F.interpolate(probs.unsqueeze(0), size=(h, w),
+                              mode='bilinear', align_corners=False).squeeze(0)
+        mask = (probs > self.thresholds).cpu().numpy()
+        return mask, probs.cpu().numpy().astype(np.float32)
+
+    def _filter_edge_noise(self, mask: np.ndarray) -> np.ndarray:
+        """去除 tile 内贴边的小碎片。"""
+        out = np.zeros_like(mask, dtype=bool)
+        for c in range(mask.shape[0]):
+            m = mask[c].astype(np.uint8)
+            if not m.any():
+                continue
+            num, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
+            for k in range(1, num):
+                area = int(stats[k, cv2.CC_STAT_AREA])
+                comp = labels == k
+                on_edge = (comp[:EDGE_MARGIN, :].any() or comp[-EDGE_MARGIN:, :].any() or
+                           comp[:, :EDGE_MARGIN].any() or comp[:, -EDGE_MARGIN:].any())
+                if area < EDGE_IGNORE_AREA and on_edge:
+                    continue
+                out[c][comp] = True
+        return out
+
+    def _dedup_scratch(self, canvas: np.ndarray) -> np.ndarray:
+        """跨严重度去重: 同一条划痕被多个 scratch 类检出时只保留最高优先级。"""
+        name2idx = {n: i for i, n in enumerate(self.classes)}
+        order = [name2idx[n] for n in SCRATCH_PRIORITY if n in name2idx]
+        if len(order) < 2:
+            return canvas
+
+        claimed = np.zeros(canvas.shape[1:], dtype=bool)
+        for ci in order:
+            m = canvas[ci]
+            if not m.any():
+                continue
+            num, labels = cv2.connectedComponents(m.astype(np.uint8), connectivity=8)
+            keep = np.zeros_like(m)
+            for k in range(1, num):
+                comp = labels == k
+                area = int(comp.sum())
+                inter = int((comp & claimed).sum())
+                if area > 0 and inter / area >= SCRATCH_DEDUP_OVERLAP:
+                    continue
+                keep |= comp
+            canvas[ci] = keep
+            claimed |= keep
+        return canvas
+
+    def _extract_defects(self, canvas: np.ndarray, prob_canvas: np.ndarray = None) -> list:
+        """从 [C,H,W] bool canvas 提取每个连通域的缺陷信息。"""
+        defects, did = [], 0
+        for c in range(canvas.shape[0]):
+            m = canvas[c].astype(np.uint8)
+            if not m.any():
+                continue
+            num, labels, stats, centroids = cv2.connectedComponentsWithStats(m, connectivity=8)
+            for k in range(1, num):
+                area = int(stats[k, cv2.CC_STAT_AREA])
+                if area < MIN_DEFECT_AREA:
+                    continue
+                x = int(stats[k, cv2.CC_STAT_LEFT])
+                y = int(stats[k, cv2.CC_STAT_TOP])
+                bw = int(stats[k, cv2.CC_STAT_WIDTH])
+                bh = int(stats[k, cv2.CC_STAT_HEIGHT])
+                cx, cy = centroids[k]
+                comp = labels == k
+                prob = round(float(prob_canvas[c][comp].mean()), 4) if prob_canvas is not None else None
+                roi = comp[y:y + bh, x:x + bw].astype(np.uint8)
+                polygons = _mask_to_polygons(roi, x, y)
+                defects.append({
+                    "id": did,
+                    "class_index": int(c),
+                    "class_name": self.classes[c],
+                    "area": area,
+                    "prob": prob,
+                    "bbox": [x, y, x + bw, y + bh],
+                    "centroid": [round(float(cx), 2), round(float(cy), 2)],
+                    "polygons": polygons,
+                })
+                did += 1
+        return defects
+
+    def infer(self, modal_imgs: Dict[str, torch.Tensor],
+              tile_rows: int = TILE_ROWS, tile_cols: int = TILE_COLS,
+              overlap: float = OVERLAP) -> Tuple[np.ndarray, list]:
+        """
+        大图切片推理。
+        overlap 区域只采用离像素最近的 tile 中心的预测,避免多 tile OR 合并导致细线膨胀成块。
+        """
+        h = min(t.shape[1] for t in modal_imgs.values())
+        w = min(t.shape[2] for t in modal_imgs.values())
+        modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
+
+        slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
+
+        yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
+        min_dist2 = np.full((h, w), np.inf, dtype=np.float32)
+        owner = np.full((h, w), -1, dtype=np.int32)
+        for idx, (_, ys, xs) in enumerate(slices):
+            cy = (ys.start + ys.stop) / 2.0
+            cx = (xs.start + xs.stop) / 2.0
+            d2 = (yy - cy) ** 2 + (xx - cx) ** 2
+            upd = d2 < min_dist2
+            min_dist2[upd] = d2[upd]
+            owner[upd] = idx
+
+        canvas = np.zeros((len(self.classes), h, w), dtype=bool)
+        prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
+        for idx, (tname, ys, xs) in enumerate(slices):
+            tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
+            mask, prob = self._forward_tile(tile)
+            mask = self._filter_edge_noise(mask)
+            own = (owner[ys, xs] == idx)
+            sel = own[np.newaxis, :, :]
+            canvas[:, ys, xs] |= mask & sel
+            np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
+            logger.info(f"  tile {tname}: y={ys.start}-{ys.stop}, x={xs.start}-{xs.stop}")
+
+        canvas = self._dedup_scratch(canvas)
+        defects = self._extract_defects(canvas, prob_canvas)
+        return canvas, defects
+
+    def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
+        """生成符合 labelme 格式的 JSON 结构。"""
+        return {
+            "image": stem,
+            "image_size": [h, w],
+            "masks": [
+                {"label": d["class_name"], "prob": d.get("prob"), "points": poly}
+                for d in defects
+                for poly in d.get("polygons", [])
+                if len(poly) >= 3
+            ],
+        }

+ 1 - 1
run_defect_score_server.py

@@ -1,6 +1,6 @@
 import uvicorn
 
 if __name__ == "__main__":
-    port = 7754
+    port = 7744
     print(f"http://127.0.0.1:{port}/docs")
     uvicorn.run("app.main:app", host="0.0.0.0", port=port, reload=True)