""" StitchFusion 多模态拼接缺陷推理服务。 - 懒加载 + 单例: 模型只加载一次, 加载失败给出明确错误。 - 输入: 6 张同侧图片 (ring/gray/stripe1-4), 解码成 {modal: [3,H,W] uint8 tensor}。 - 输出: 与 score_inference 接口一致的 result 结构 (center_result/defect_result/card_score 等)。 StitchFusion 模型只检测面缺陷, 因此走同轴光 (coaxial) 评分流程, center_result 留空。 """ from pathlib import Path from threading import Lock from typing import Dict, Optional, Tuple import cv2 import numpy as np import torch from app.core.config import settings from app.core.logger import get_logger from app.utils.defect_inference.arean_anylize_draw import DefectProcessor from app.utils.json_data_formate import formate_add_edit_type, formate_face_data from app.utils.score_inference.CardScorer import CardScorer from app.utils.stitch_fusion.tiled_infer import TiledInfer logger = get_logger(__name__) # ring/gray/stripe1-4 -> 模型 modal 名 (与 model_meta.json 里 modals 顺序对应) VARIANT_TO_MODAL: Dict[str, str] = { "gray": "img", "stripe1": "ch1", "stripe2": "ch2", "stripe3": "ch3", "stripe4": "ch4", "ring": "ring", } REQUIRED_VARIANTS = list(VARIANT_TO_MODAL.keys()) # StitchFusion 类别 -> (CardScorer 面缺陷可识别的 label, severity_level) # 没列出的类别按原 label 透传, severity_level 默认 "一般" CLASS_TO_LABEL_SEVERITY: Dict[str, Tuple[str, str]] = { "serious_scratch": ("scratch", "严重"), "scratch": ("scratch", "一般"), "slight_scratch": ("scratch", "轻微"), } # CardScorer 已知的面缺陷 label, 不在此集合的统一 fallback 到 "wear" 以避免评分异常 KNOWN_FACE_LABELS = { "wear", "wear_and_impact", "wear_and_stain", "damaged", "scratch", "scuff", "pit", "impact", "protrudent", "stain", } def _decode_image_to_tensor(image_bytes: bytes, variant: str) -> torch.Tensor: """字节流 -> [3,H,W] uint8 tensor (RGB).""" np_arr = np.frombuffer(image_bytes, np.uint8) img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED) if img is None: raise ValueError(f"无法解码图像: {variant}, 请确认是有效图片格式 (JPG/PNG 等)") if img.ndim == 2: img_rgb = np.stack([img, img, img], axis=-1) elif img.shape[2] == 4: img_rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) elif img.shape[2] == 3: img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) else: raise ValueError(f"图像通道数不支持: {variant}, shape={img.shape}") tensor = torch.from_numpy(np.ascontiguousarray(img_rgb)).permute(2, 0, 1).contiguous() return tensor def _map_class(stitch_class: str) -> Tuple[str, str]: """StitchFusion 类别名 -> (CardScorer label, severity_level).""" if stitch_class in CLASS_TO_LABEL_SEVERITY: return CLASS_TO_LABEL_SEVERITY[stitch_class] if stitch_class in KNOWN_FACE_LABELS: return stitch_class, "一般" logger.warning(f"[StitchFusion] 未知类别 '{stitch_class}', 评分时回退为 'wear'") return "wear", "一般" def _build_shapes_data(raw_defects: list) -> Tuple[Dict, list]: """把 TiledInfer.defects 转成 DefectProcessor 可吃的 shapes 数据。 返回: shapes_data: {"shapes": [{label, points, confidence}, ...]} meta_per_shape: 与 shapes 顺序一一对应, 包含原始 class_name 与 severity_level """ shapes = [] meta = [] for d in raw_defects: for poly in d.get("polygons", []): if len(poly) < 3: continue label, severity = _map_class(d["class_name"]) shapes.append({ "label": label, "points": poly, "confidence": d.get("prob"), }) meta.append({ "stitch_class": d["class_name"], "severity_level": severity, }) return {"shapes": shapes}, meta class StitchFusionService: """单例 + 懒加载, 避免每次请求都加载模型。""" _instance: Optional["StitchFusionService"] = None _lock = Lock() def __new__(cls): with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._inferrer = None cls._instance._init_lock = Lock() cls._instance._scorer = None return cls._instance def _ensure_inferrer(self) -> TiledInfer: if self._inferrer is not None: return self._inferrer with self._init_lock: if self._inferrer is not None: return self._inferrer model_pt = str(settings.STITCH_FUSION_MODEL_PT) model_meta = str(settings.STITCH_FUSION_MODEL_META) for path_str, name in [(model_pt, "model_pt"), (model_meta, "model_meta")]: if not Path(path_str).exists(): raise FileNotFoundError( f"StitchFusion {name} 文件不存在: {path_str},请检查 settings.STITCH_FUSION_*" ) self._inferrer = TiledInfer( model_pt=model_pt, model_meta=model_meta, device=settings.STITCH_FUSION_DEVICE, threshold=settings.STITCH_FUSION_THRESHOLD, ) return self._inferrer def _ensure_scorer(self) -> CardScorer: if self._scorer is None: self._scorer = CardScorer(config_path=settings.SCORE_CONFIG_PATH) return self._scorer def stitch_score_inference(self, score_type: str, card_name: str, variant_bytes: Dict[str, bytes]) -> dict: """ score_type: front 或 back, 决定 card_aspect, 同时仅用于命名。 variant_bytes: {variant_name: image_bytes}, 必须含 REQUIRED_VARIANTS 全部 6 张。 返回结构对齐 score_inference (同轴光分支), defect_result.defects 字段含 label/confidence/pixel_area/actual_area/width/height/points/min_rect/ defect_type/edit_type/severity_level/scratch_length/score/new_score。 """ missing = [v for v in REQUIRED_VARIANTS if v not in variant_bytes] if missing: raise ValueError(f"缺少图片字段: {missing}, 必须提供 6 张: {REQUIRED_VARIANTS}") if score_type not in ("front", "back"): raise ValueError(f"score_type 仅支持 front/back, 收到: {score_type}") inferrer = self._ensure_inferrer() scorer = self._ensure_scorer() modal_imgs: Dict[str, torch.Tensor] = {} for variant in REQUIRED_VARIANTS: modal = VARIANT_TO_MODAL[variant] modal_imgs[modal] = _decode_image_to_tensor(variant_bytes[variant], variant) logger.info(f"[StitchFusion] {score_type}/{card_name} 开始推理") canvas, raw_defects = inferrer.infer(modal_imgs) h, w = int(canvas.shape[1]), int(canvas.shape[2]) stem = card_name or f"{score_type}_card" logger.info(f"[StitchFusion] {stem} 推理结束, 原始连通域={len(raw_defects)}") # 1) 把 StitchFusion 多边形转成 score_inference 风格 defect_result shapes_data, meta_per_shape = _build_shapes_data(raw_defects) processor = DefectProcessor(pixel_resolution=settings.PIXEL_RESOLUTION) analysis_json = processor.analyze_from_json(shapes_data) # 把每个 defect 补齐 score_inference 所需字段 (defect_type/edit_type/severity_level) defects_full = [] for defect, meta in zip(analysis_json["defects"], meta_per_shape): defect["severity_level"] = meta["severity_level"] defect["stitch_class"] = meta["stitch_class"] defects_full.append(defect) defect_data = { "defects": defects_full, "statistics": analysis_json["statistics"], } defect_data = formate_face_data(defect_data) defect_data = formate_add_edit_type(defect_data) # 2) 同轴光评分流程: 仅算 face, 跳过 center/corner/edge card_aspect = score_type card_light_type = "coaxial" try: defect_data = scorer.calculate_defect_score( "face", card_aspect, card_light_type, defect_data, True ) except Exception as e: logger.warning(f"[StitchFusion] face 评分失败, 跳过评分: {e}") defect_data[f"{card_aspect}_face_deduct_score"] = 0.0 for d in defect_data["defects"]: d.setdefault("score", None) d.setdefault("new_score", None) # 3) 套用 score_inference 同款外层结构 (center_result 用 {} 占位) result_json = scorer.formate_one_card_result( center_result={}, defect_result=defect_data, card_light_type=card_light_type, card_aspect=card_aspect, imageHeight=h, imageWidth=w, ) result_json["result"]["image"] = stem result_json["result"]["image_size"] = [h, w] result_json["result"]["card_aspect"] = card_aspect result_json["result"]["card_light_type"] = card_light_type return result_json