stitch_fusion_service.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. """
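StitchFusion multi-modal stitched-defect inference service.
- Lazy loading + singleton: the model is loaded only once, and a load failure produces an explicit error.
- Input: 6 images of the same card side (ring/gray/stripe1-4), decoded into {modal: [3,H,W] uint8 tensor}.
- Output: a result structure consistent with the score_inference interface (center_result/defect_result/card_score, etc.).
The StitchFusion model only detects face defects, so it follows the coaxial-light scoring flow and center_result is left empty.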
  7. """
  8. from pathlib import Path
  9. from threading import Lock
  10. from typing import Dict, Optional, Tuple
  11. import cv2
  12. import numpy as np
  13. import torch
  14. from app.core.config import settings
  15. from app.core.logger import get_logger
  16. from app.utils.defect_inference.arean_anylize_draw import DefectProcessor
  17. from app.utils.json_data_formate import formate_add_edit_type, formate_face_data
  18. from app.utils.score_inference.CardScorer import CardScorer
  19. from app.utils.stitch_fusion.tiled_infer import TiledInfer
  20. logger = get_logger(__name__)
  21. # ring/gray/stripe1-4 -> 模型 modal 名 (与 model_meta.json 里 modals 顺序对应)
  22. VARIANT_TO_MODAL: Dict[str, str] = {
  23. "gray": "img",
  24. "stripe1": "ch1",
  25. "stripe2": "ch2",
  26. "stripe3": "ch3",
  27. "stripe4": "ch4",
  28. "ring": "ring",
  29. }
  30. REQUIRED_VARIANTS = list(VARIANT_TO_MODAL.keys())
  31. # StitchFusion 类别 -> (CardScorer 面缺陷可识别的 label, severity_level)
  32. # 没列出的类别按原 label 透传, severity_level 默认 "一般"
  33. CLASS_TO_LABEL_SEVERITY: Dict[str, Tuple[str, str]] = {
  34. "serious_scratch": ("scratch", "严重"),
  35. "scratch": ("scratch", "一般"),
  36. "slight_scratch": ("scratch", "轻微"),
  37. }
  38. # CardScorer 已知的面缺陷 label, 不在此集合的统一 fallback 到 "wear" 以避免评分异常
  39. KNOWN_FACE_LABELS = {
  40. "wear", "wear_and_impact", "wear_and_stain", "damaged",
  41. "scratch", "scuff",
  42. "pit", "impact", "protrudent",
  43. "stain",
  44. }
  45. def _decode_image_to_tensor(image_bytes: bytes, variant: str) -> torch.Tensor:
  46. """字节流 -> [3,H,W] uint8 tensor (RGB)."""
  47. np_arr = np.frombuffer(image_bytes, np.uint8)
  48. img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
  49. if img is None:
  50. raise ValueError(f"无法解码图像: {variant}, 请确认是有效图片格式 (JPG/PNG 等)")
  51. if img.ndim == 2:
  52. img_rgb = np.stack([img, img, img], axis=-1)
  53. elif img.shape[2] == 4:
  54. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
  55. elif img.shape[2] == 3:
  56. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  57. else:
  58. raise ValueError(f"图像通道数不支持: {variant}, shape={img.shape}")
  59. tensor = torch.from_numpy(np.ascontiguousarray(img_rgb)).permute(2, 0, 1).contiguous()
  60. return tensor
  61. def _map_class(stitch_class: str) -> Tuple[str, str]:
  62. """StitchFusion 类别名 -> (CardScorer label, severity_level)."""
  63. if stitch_class in CLASS_TO_LABEL_SEVERITY:
  64. return CLASS_TO_LABEL_SEVERITY[stitch_class]
  65. if stitch_class in KNOWN_FACE_LABELS:
  66. return stitch_class, "一般"
  67. logger.warning(f"[StitchFusion] 未知类别 '{stitch_class}', 评分时回退为 'wear'")
  68. return "wear", "一般"
  69. def _build_shapes_data(raw_defects: list) -> Tuple[Dict, list]:
  70. """把 TiledInfer.defects 转成 DefectProcessor 可吃的 shapes 数据。
  71. 返回:
  72. shapes_data: {"shapes": [{label, points, confidence}, ...]}
  73. meta_per_shape: 与 shapes 顺序一一对应, 包含原始 class_name 与 severity_level
  74. """
  75. shapes = []
  76. meta = []
  77. for d in raw_defects:
  78. for poly in d.get("polygons", []):
  79. if len(poly) < 3:
  80. continue
  81. label, severity = _map_class(d["class_name"])
  82. shapes.append({
  83. "label": label,
  84. "points": poly,
  85. "confidence": d.get("prob"),
  86. })
  87. meta.append({
  88. "stitch_class": d["class_name"],
  89. "severity_level": severity,
  90. })
  91. return {"shapes": shapes}, meta
  92. class StitchFusionService:
  93. """单例 + 懒加载, 避免每次请求都加载模型。"""
  94. _instance: Optional["StitchFusionService"] = None
  95. _lock = Lock()
  96. def __new__(cls):
  97. with cls._lock:
  98. if cls._instance is None:
  99. cls._instance = super().__new__(cls)
  100. cls._instance._inferrer = None
  101. cls._instance._init_lock = Lock()
  102. cls._instance._scorer = None
  103. return cls._instance
  104. def _ensure_inferrer(self) -> TiledInfer:
  105. if self._inferrer is not None:
  106. return self._inferrer
  107. with self._init_lock:
  108. if self._inferrer is not None:
  109. return self._inferrer
  110. model_pt = str(settings.STITCH_FUSION_MODEL_PT)
  111. model_meta = str(settings.STITCH_FUSION_MODEL_META)
  112. for path_str, name in [(model_pt, "model_pt"), (model_meta, "model_meta")]:
  113. if not Path(path_str).exists():
  114. raise FileNotFoundError(
  115. f"StitchFusion {name} 文件不存在: {path_str},请检查 settings.STITCH_FUSION_*"
  116. )
  117. self._inferrer = TiledInfer(
  118. model_pt=model_pt,
  119. model_meta=model_meta,
  120. device=settings.STITCH_FUSION_DEVICE,
  121. threshold=settings.STITCH_FUSION_THRESHOLD,
  122. )
  123. return self._inferrer
  124. def _ensure_scorer(self) -> CardScorer:
  125. if self._scorer is None:
  126. self._scorer = CardScorer(config_path=settings.SCORE_CONFIG_PATH)
  127. return self._scorer
  128. def stitch_score_inference(self, score_type: str, card_name: str,
  129. variant_bytes: Dict[str, bytes]) -> dict:
  130. """
  131. score_type: front 或 back, 决定 card_aspect, 同时仅用于命名。
  132. variant_bytes: {variant_name: image_bytes}, 必须含 REQUIRED_VARIANTS 全部 6 张。
  133. 返回结构对齐 score_inference (同轴光分支), defect_result.defects 字段含
  134. label/confidence/pixel_area/actual_area/width/height/points/min_rect/
  135. defect_type/edit_type/severity_level/scratch_length/score/new_score。
  136. """
  137. missing = [v for v in REQUIRED_VARIANTS if v not in variant_bytes]
  138. if missing:
  139. raise ValueError(f"缺少图片字段: {missing}, 必须提供 6 张: {REQUIRED_VARIANTS}")
  140. if score_type not in ("front", "back"):
  141. raise ValueError(f"score_type 仅支持 front/back, 收到: {score_type}")
  142. inferrer = self._ensure_inferrer()
  143. scorer = self._ensure_scorer()
  144. modal_imgs: Dict[str, torch.Tensor] = {}
  145. for variant in REQUIRED_VARIANTS:
  146. modal = VARIANT_TO_MODAL[variant]
  147. modal_imgs[modal] = _decode_image_to_tensor(variant_bytes[variant], variant)
  148. logger.info(f"[StitchFusion] {score_type}/{card_name} 开始推理")
  149. canvas, raw_defects = inferrer.infer(modal_imgs)
  150. h, w = int(canvas.shape[1]), int(canvas.shape[2])
  151. stem = card_name or f"{score_type}_card"
  152. logger.info(f"[StitchFusion] {stem} 推理结束, 原始连通域={len(raw_defects)}")
  153. # 1) 把 StitchFusion 多边形转成 score_inference 风格 defect_result
  154. shapes_data, meta_per_shape = _build_shapes_data(raw_defects)
  155. processor = DefectProcessor(pixel_resolution=settings.PIXEL_RESOLUTION)
  156. analysis_json = processor.analyze_from_json(shapes_data)
  157. # 把每个 defect 补齐 score_inference 所需字段 (defect_type/edit_type/severity_level)
  158. defects_full = []
  159. for defect, meta in zip(analysis_json["defects"], meta_per_shape):
  160. defect["severity_level"] = meta["severity_level"]
  161. defect["stitch_class"] = meta["stitch_class"]
  162. defects_full.append(defect)
  163. defect_data = {
  164. "defects": defects_full,
  165. "statistics": analysis_json["statistics"],
  166. }
  167. defect_data = formate_face_data(defect_data)
  168. defect_data = formate_add_edit_type(defect_data)
  169. # 2) 同轴光评分流程: 仅算 face, 跳过 center/corner/edge
  170. card_aspect = score_type
  171. card_light_type = "coaxial"
  172. try:
  173. defect_data = scorer.calculate_defect_score(
  174. "face", card_aspect, card_light_type, defect_data, True
  175. )
  176. except Exception as e:
  177. logger.warning(f"[StitchFusion] face 评分失败, 跳过评分: {e}")
  178. defect_data[f"{card_aspect}_face_deduct_score"] = 0.0
  179. for d in defect_data["defects"]:
  180. d.setdefault("score", None)
  181. d.setdefault("new_score", None)
  182. # 3) 套用 score_inference 同款外层结构 (center_result 用 {} 占位)
  183. result_json = scorer.formate_one_card_result(
  184. center_result={},
  185. defect_result=defect_data,
  186. card_light_type=card_light_type,
  187. card_aspect=card_aspect,
  188. imageHeight=h,
  189. imageWidth=w,
  190. )
  191. result_json["result"]["image"] = stem
  192. result_json["result"]["image_size"] = [h, w]
  193. result_json["result"]["card_aspect"] = card_aspect
  194. result_json["result"]["card_light_type"] = card_light_type
  195. return result_json