# stitch_fusion_service.py
"""
StitchFusion 多模态拼接缺陷推理服务。
- 懒加载 + 单例: 模型只加载一次, 加载失败给出明确错误。
- 输入: 6 张同侧图片 (ring/gray/stripe1-4), 解码成 {modal: [3,H,W] uint8 tensor}。
- 输出: 与 score_inference 接口一致的 result 结构 (center_result/defect_result/card_score 等)。
StitchFusion 模型只检测面缺陷, 因此走同轴光 (coaxial) 评分流程, center_result 留空。
"""
  8. import time
  9. from pathlib import Path
  10. from threading import Lock
  11. from typing import Dict, Optional, Tuple
  12. import cv2
  13. import numpy as np
  14. import torch
  15. from app.core.config import settings
  16. from app.core.logger import get_logger
  17. from app.utils.defect_inference.arean_anylize_draw import DefectProcessor
  18. from app.utils.json_data_formate import formate_add_edit_type, formate_face_data
  19. from app.utils.score_inference.CardScorer import CardScorer
  20. from app.utils.stitch_fusion.tiled_infer import TiledInfer
# Module-level logger obtained from the project's get_logger helper, named after this module.
logger = get_logger(__name__)
  22. # ring/gray/stripe1-4 -> 模型 modal 名 (与 model_meta.json 里 modals 顺序对应)
  23. VARIANT_TO_MODAL: Dict[str, str] = {
  24. "gray": "img",
  25. "stripe1": "ch1",
  26. "stripe2": "ch2",
  27. "stripe3": "ch3",
  28. "stripe4": "ch4",
  29. "ring": "ring",
  30. }
  31. REQUIRED_VARIANTS = list(VARIANT_TO_MODAL.keys())
  32. # StitchFusion 类别 -> (CardScorer 面缺陷可识别的 label, severity_level)
  33. # 没列出的类别按原 label 透传, severity_level 默认 "一般"
  34. CLASS_TO_LABEL_SEVERITY: Dict[str, Tuple[str, str]] = {
  35. "serious_scratch": ("scratch", "严重"),
  36. "scratch": ("scratch", "一般"),
  37. "slight_scratch": ("scratch", "轻微"),
  38. }
  39. # CardScorer 已知的面缺陷 label, 不在此集合的统一 fallback 到 "wear" 以避免评分异常
  40. KNOWN_FACE_LABELS = {
  41. "wear", "wear_and_impact", "wear_and_stain", "damaged",
  42. "scratch", "scuff",
  43. "pit", "impact", "protrudent",
  44. "stain",
  45. }
  46. def _decode_image_to_tensor(image_bytes: bytes, variant: str) -> torch.Tensor:
  47. """字节流 -> [3,H,W] uint8 tensor (RGB)."""
  48. np_arr = np.frombuffer(image_bytes, np.uint8)
  49. img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
  50. if img is None:
  51. raise ValueError(f"无法解码图像: {variant}, 请确认是有效图片格式 (JPG/PNG 等)")
  52. if img.ndim == 2:
  53. img_rgb = np.stack([img, img, img], axis=-1)
  54. elif img.shape[2] == 4:
  55. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
  56. elif img.shape[2] == 3:
  57. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  58. else:
  59. raise ValueError(f"图像通道数不支持: {variant}, shape={img.shape}")
  60. tensor = torch.from_numpy(np.ascontiguousarray(img_rgb)).permute(2, 0, 1).contiguous()
  61. return tensor
  62. def _map_class(stitch_class: str) -> Tuple[str, str]:
  63. """StitchFusion 类别名 -> (CardScorer label, severity_level)."""
  64. if stitch_class in CLASS_TO_LABEL_SEVERITY:
  65. return CLASS_TO_LABEL_SEVERITY[stitch_class]
  66. if stitch_class in KNOWN_FACE_LABELS:
  67. return stitch_class, "一般"
  68. logger.warning(f"[StitchFusion] 未知类别 '{stitch_class}', 评分时回退为 'wear'")
  69. return "wear", "一般"
  70. def _build_shapes_data(raw_defects: list) -> Tuple[Dict, list]:
  71. """把 TiledInfer.defects 转成 DefectProcessor 可吃的 shapes 数据。
  72. 返回:
  73. shapes_data: {"shapes": [{label, points, confidence}, ...]}
  74. meta_per_shape: 与 shapes 顺序一一对应, 包含原始 class_name 与 severity_level
  75. """
  76. shapes = []
  77. meta = []
  78. for d in raw_defects:
  79. for poly in d.get("polygons", []):
  80. if len(poly) < 3:
  81. continue
  82. label, severity = _map_class(d["class_name"])
  83. shapes.append({
  84. "label": label,
  85. "points": poly,
  86. "confidence": d.get("prob"),
  87. })
  88. meta.append({
  89. "stitch_class": d["class_name"],
  90. "severity_level": severity,
  91. })
  92. return {"shapes": shapes}, meta
  93. class StitchFusionService:
  94. """单例 + 懒加载, 避免每次请求都加载模型。"""
  95. _instance: Optional["StitchFusionService"] = None
  96. _lock = Lock()
  97. def __new__(cls):
  98. with cls._lock:
  99. if cls._instance is None:
  100. cls._instance = super().__new__(cls)
  101. cls._instance._inferrer = None
  102. cls._instance._init_lock = Lock()
  103. cls._instance._scorer = None
  104. return cls._instance
  105. def _ensure_inferrer(self) -> TiledInfer:
  106. if self._inferrer is not None:
  107. return self._inferrer
  108. with self._init_lock:
  109. if self._inferrer is not None:
  110. return self._inferrer
  111. model_pt = str(settings.STITCH_FUSION_MODEL_PT)
  112. model_meta = str(settings.STITCH_FUSION_MODEL_META)
  113. for path_str, name in [(model_pt, "model_pt"), (model_meta, "model_meta")]:
  114. if not Path(path_str).exists():
  115. raise FileNotFoundError(
  116. f"StitchFusion {name} 文件不存在: {path_str},请检查 settings.STITCH_FUSION_*"
  117. )
  118. self._inferrer = TiledInfer(
  119. model_pt=model_pt,
  120. model_meta=model_meta,
  121. device=settings.STITCH_FUSION_DEVICE,
  122. threshold=settings.STITCH_FUSION_THRESHOLD,
  123. )
  124. return self._inferrer
  125. def _ensure_scorer(self) -> CardScorer:
  126. if self._scorer is None:
  127. self._scorer = CardScorer(config_path=settings.SCORE_CONFIG_PATH)
  128. return self._scorer
  129. def stitch_score_inference(self, score_type: str, card_name: str,
  130. variant_bytes: Dict[str, bytes]) -> dict:
  131. """
  132. score_type: front 或 back, 决定 card_aspect, 同时仅用于命名。
  133. variant_bytes: {variant_name: image_bytes}, 必须含 REQUIRED_VARIANTS 全部 6 张。
  134. 返回结构对齐 score_inference (同轴光分支), defect_result.defects 字段含
  135. label/confidence/pixel_area/actual_area/width/height/points/min_rect/
  136. defect_type/edit_type/severity_level/scratch_length/score/new_score。
  137. """
  138. missing = [v for v in REQUIRED_VARIANTS if v not in variant_bytes]
  139. if missing:
  140. raise ValueError(f"缺少图片字段: {missing}, 必须提供 6 张: {REQUIRED_VARIANTS}")
  141. if score_type not in ("front", "back"):
  142. raise ValueError(f"score_type 仅支持 front/back, 收到: {score_type}")
  143. inferrer = self._ensure_inferrer()
  144. scorer = self._ensure_scorer()
  145. stem = card_name or f"{score_type}_card"
  146. t0 = time.perf_counter()
  147. modal_imgs: Dict[str, torch.Tensor] = {}
  148. for variant in REQUIRED_VARIANTS:
  149. modal = VARIANT_TO_MODAL[variant]
  150. modal_imgs[modal] = _decode_image_to_tensor(variant_bytes[variant], variant)
  151. logger.info(f"[StitchFusion] {score_type}/{stem} 6 张图解码完成, "
  152. f"耗时 {time.perf_counter() - t0:.2f}s, "
  153. f"shape={[(m, list(t.shape)) for m, t in modal_imgs.items()]}")
  154. logger.info(f"[StitchFusion] {score_type}/{stem} 开始推理")
  155. t0 = time.perf_counter()
  156. canvas, raw_defects = inferrer.infer(modal_imgs)
  157. h, w = int(canvas.shape[1]), int(canvas.shape[2])
  158. logger.info(f"[StitchFusion] {stem} 模型推理结束, 原始连通域={len(raw_defects)}, "
  159. f"耗时 {time.perf_counter() - t0:.2f}s")
  160. # 1) 把 StitchFusion 多边形转成 score_inference 风格 defect_result
  161. t0 = time.perf_counter()
  162. shapes_data, meta_per_shape = _build_shapes_data(raw_defects)
  163. processor = DefectProcessor(pixel_resolution=settings.PIXEL_RESOLUTION)
  164. analysis_json = processor.analyze_from_json(shapes_data)
  165. # 把每个 defect 补齐 score_inference 所需字段 (defect_type/edit_type/severity_level)
  166. defects_full = []
  167. for defect, meta in zip(analysis_json["defects"], meta_per_shape):
  168. defect["severity_level"] = meta["severity_level"]
  169. defect["stitch_class"] = meta["stitch_class"]
  170. defects_full.append(defect)
  171. defect_data = {
  172. "defects": defects_full,
  173. "statistics": analysis_json["statistics"],
  174. }
  175. defect_data = formate_face_data(defect_data)
  176. defect_data = formate_add_edit_type(defect_data)
  177. logger.info(f"[StitchFusion] {stem} 几何指标/统计计算完成, 缺陷={len(defect_data['defects'])}, "
  178. f"耗时 {time.perf_counter() - t0:.2f}s")
  179. # 2) 同轴光评分流程: 仅算 face, 跳过 center/corner/edge
  180. t0 = time.perf_counter()
  181. card_aspect = score_type
  182. card_light_type = "coaxial"
  183. try:
  184. defect_data = scorer.calculate_defect_score(
  185. "face", card_aspect, card_light_type, defect_data, True
  186. )
  187. except Exception as e:
  188. logger.warning(f"[StitchFusion] face 评分失败, 跳过评分: {e}")
  189. defect_data[f"{card_aspect}_face_deduct_score"] = 0.0
  190. for d in defect_data["defects"]:
  191. d.setdefault("score", None)
  192. d.setdefault("new_score", None)
  193. # 3) 套用 score_inference 同款外层结构 (center_result 用 {} 占位)
  194. result_json = scorer.formate_one_card_result(
  195. center_result={},
  196. defect_result=defect_data,
  197. card_light_type=card_light_type,
  198. card_aspect=card_aspect,
  199. imageHeight=h,
  200. imageWidth=w,
  201. )
  202. result_json["result"]["image"] = stem
  203. result_json["result"]["image_size"] = [h, w]
  204. result_json["result"]["card_aspect"] = card_aspect
  205. result_json["result"]["card_light_type"] = card_light_type
  206. logger.info(f"[StitchFusion] {stem} 评分+封装完成, 耗时 {time.perf_counter() - t0:.2f}s, "
  207. f"card_score={result_json['result'].get('card_score')}")
  208. return result_json