| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- """
- StitchFusion 多模态拼接缺陷推理服务。
- - 懒加载 + 单例: 模型只加载一次, 加载失败给出明确错误。
- - 输入: 6 张同侧图片 (ring/gray/stripe1-4), 解码成 {modal: [3,H,W] uint8 tensor}。
- - 输出: 与 score_inference 接口一致的 result 结构 (center_result/defect_result/card_score 等)。
- StitchFusion 模型只检测面缺陷, 因此走同轴光 (coaxial) 评分流程, center_result 留空。
- """
- import time
- from pathlib import Path
- from threading import Lock
- from typing import Dict, Optional, Tuple
- import cv2
- import numpy as np
- import torch
- from app.core.config import settings
- from app.core.logger import get_logger
- from app.utils.defect_inference.arean_anylize_draw import DefectProcessor
- from app.utils.json_data_formate import formate_add_edit_type, formate_face_data
- from app.utils.score_inference.CardScorer import CardScorer
- from app.utils.stitch_fusion.tiled_infer import TiledInfer
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Capture variant name (ring/gray/stripe1-4) -> model modal name.
# These correspond to the order of `modals` in model_meta.json.
VARIANT_TO_MODAL: Dict[str, str] = dict(
    gray="img",
    stripe1="ch1",
    stripe2="ch2",
    stripe3="ch3",
    stripe4="ch4",
    ring="ring",
)

# Every request must supply all of these variants; order follows VARIANT_TO_MODAL.
REQUIRED_VARIANTS = [*VARIANT_TO_MODAL]

# StitchFusion class -> (face-defect label CardScorer can recognise, severity_level).
# Classes not listed here are passed through under their own label, with
# severity_level defaulting to "一般".
CLASS_TO_LABEL_SEVERITY: Dict[str, Tuple[str, str]] = dict(
    serious_scratch=("scratch", "严重"),
    scratch=("scratch", "一般"),
    slight_scratch=("scratch", "轻微"),
)

# Face-defect labels CardScorer already knows. Anything outside this set
# falls back to "wear" so that scoring does not break.
KNOWN_FACE_LABELS = {
    "wear",
    "wear_and_impact",
    "wear_and_stain",
    "damaged",
    "scratch",
    "scuff",
    "pit",
    "impact",
    "protrudent",
    "stain",
}
- def _decode_image_to_tensor(image_bytes: bytes, variant: str) -> torch.Tensor:
- """字节流 -> [3,H,W] uint8 tensor (RGB)."""
- np_arr = np.frombuffer(image_bytes, np.uint8)
- img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
- if img is None:
- raise ValueError(f"无法解码图像: {variant}, 请确认是有效图片格式 (JPG/PNG 等)")
- if img.ndim == 2:
- img_rgb = np.stack([img, img, img], axis=-1)
- elif img.shape[2] == 4:
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
- elif img.shape[2] == 3:
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- else:
- raise ValueError(f"图像通道数不支持: {variant}, shape={img.shape}")
- tensor = torch.from_numpy(np.ascontiguousarray(img_rgb)).permute(2, 0, 1).contiguous()
- return tensor
def _map_class(stitch_class: str) -> Tuple[str, str]:
    """Translate a StitchFusion class name into ``(CardScorer label, severity_level)``.

    Resolution order:
      1. an explicit entry in ``CLASS_TO_LABEL_SEVERITY``;
      2. a label CardScorer already knows (``KNOWN_FACE_LABELS``), kept as-is
         with severity "一般";
      3. anything else: logged as unknown and scored as "wear" / "一般".
    """
    mapped = CLASS_TO_LABEL_SEVERITY.get(stitch_class)
    if mapped is not None:
        return mapped
    if stitch_class not in KNOWN_FACE_LABELS:
        logger.warning(f"[StitchFusion] 未知类别 '{stitch_class}', 评分时回退为 'wear'")
        return "wear", "一般"
    return stitch_class, "一般"
- def _build_shapes_data(raw_defects: list) -> Tuple[Dict, list]:
- """把 TiledInfer.defects 转成 DefectProcessor 可吃的 shapes 数据。
- 返回:
- shapes_data: {"shapes": [{label, points, confidence}, ...]}
- meta_per_shape: 与 shapes 顺序一一对应, 包含原始 class_name 与 severity_level
- """
- shapes = []
- meta = []
- for d in raw_defects:
- for poly in d.get("polygons", []):
- if len(poly) < 3:
- continue
- label, severity = _map_class(d["class_name"])
- shapes.append({
- "label": label,
- "points": poly,
- "confidence": d.get("prob"),
- })
- meta.append({
- "stitch_class": d["class_name"],
- "severity_level": severity,
- })
- return {"shapes": shapes}, meta
class StitchFusionService:
    """Process-wide singleton that lazily builds and reuses the StitchFusion model and scorer.

    ``StitchFusionService()`` always returns the same instance. The
    ``TiledInfer`` model and the ``CardScorer`` are constructed on first use
    and reused by every later request, so requests never reload them.
    """

    _instance: Optional["StitchFusionService"] = None
    # Serializes creation of the singleton in __new__.
    _lock = Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._inferrer = None
                instance._scorer = None
                # Serializes the lazy construction in _ensure_inferrer/_ensure_scorer.
                instance._init_lock = Lock()
                # Publish only after the instance is fully initialised.
                cls._instance = instance
            return cls._instance

    def _ensure_inferrer(self) -> TiledInfer:
        """Return the shared ``TiledInfer``, building it on the first call.

        Raises:
            FileNotFoundError: if the configured model weights or meta file
                does not exist.
        """
        # Lock-free fast path, re-checked under the lock before building.
        if self._inferrer is not None:
            return self._inferrer
        with self._init_lock:
            if self._inferrer is None:
                model_pt = str(settings.STITCH_FUSION_MODEL_PT)
                model_meta = str(settings.STITCH_FUSION_MODEL_META)
                for path_str, name in ((model_pt, "model_pt"), (model_meta, "model_meta")):
                    if not Path(path_str).exists():
                        raise FileNotFoundError(
                            f"StitchFusion {name} 文件不存在: {path_str},请检查 settings.STITCH_FUSION_*"
                        )
                self._inferrer = TiledInfer(
                    model_pt=model_pt,
                    model_meta=model_meta,
                    device=settings.STITCH_FUSION_DEVICE,
                    threshold=settings.STITCH_FUSION_THRESHOLD,
                )
            return self._inferrer

    def _ensure_scorer(self) -> CardScorer:
        """Return the shared ``CardScorer``, building it on the first call.

        Uses the same lock and re-check as ``_ensure_inferrer`` so concurrent
        first requests construct only one scorer.
        """
        if self._scorer is not None:
            return self._scorer
        with self._init_lock:
            if self._scorer is None:
                self._scorer = CardScorer(config_path=settings.SCORE_CONFIG_PATH)
            return self._scorer

    @staticmethod
    def _decode_variants(score_type: str, stem: str,
                         variant_bytes: Dict[str, bytes]) -> Dict[str, torch.Tensor]:
        """Decode every required variant and key the tensors by the model's modal name."""
        t0 = time.perf_counter()
        modal_imgs = {
            VARIANT_TO_MODAL[variant]: _decode_image_to_tensor(variant_bytes[variant], variant)
            for variant in REQUIRED_VARIANTS
        }
        logger.info(f"[StitchFusion] {score_type}/{stem} 6 张图解码完成, "
                    f"耗时 {time.perf_counter() - t0:.2f}s, "
                    f"shape={[(m, list(t.shape)) for m, t in modal_imgs.items()]}")
        return modal_imgs

    @staticmethod
    def _build_defect_data(raw_defects: list) -> dict:
        """Turn raw ``TiledInfer`` defects into a score_inference-style defect_result.

        Runs ``DefectProcessor`` for geometry/statistics, attaches each
        shape's ``severity_level`` and ``stitch_class``, then applies
        ``formate_face_data`` and ``formate_add_edit_type``.
        """
        shapes_data, meta_per_shape = _build_shapes_data(raw_defects)
        processor = DefectProcessor(pixel_resolution=settings.PIXEL_RESOLUTION)
        analysis_json = processor.analyze_from_json(shapes_data)
        defects = list(analysis_json["defects"])
        if len(defects) != len(meta_per_shape):
            # Metadata is paired with DefectProcessor output purely by position;
            # a count mismatch means zip() below truncates and pairing may be off.
            logger.warning(
                f"[StitchFusion] DefectProcessor 输出缺陷数({len(defects)})与输入 shapes 数"
                f"({len(meta_per_shape)})不一致, 严重程度映射可能错位"
            )
        # Fill in the per-defect fields score_inference expects.
        defects_full = []
        for defect, meta in zip(defects, meta_per_shape):
            defect["severity_level"] = meta["severity_level"]
            defect["stitch_class"] = meta["stitch_class"]
            defects_full.append(defect)
        defect_data = {
            "defects": defects_full,
            "statistics": analysis_json["statistics"],
        }
        defect_data = formate_face_data(defect_data)
        return formate_add_edit_type(defect_data)

    @staticmethod
    def _score_face(scorer: CardScorer, card_aspect: str, card_light_type: str,
                    defect_data: dict) -> dict:
        """Score face defects, degrading to an unscored result if scoring fails.

        On any exception from ``calculate_defect_score`` the failure is logged,
        ``<card_aspect>_face_deduct_score`` is set to 0.0 and every defect gets
        ``score``/``new_score`` = None where not already present.
        """
        try:
            return scorer.calculate_defect_score(
                "face", card_aspect, card_light_type, defect_data, True
            )
        except Exception as e:
            # Deliberate best-effort: a scoring failure must not fail the request.
            logger.warning(f"[StitchFusion] face 评分失败, 跳过评分: {e}")
            defect_data[f"{card_aspect}_face_deduct_score"] = 0.0
            for d in defect_data["defects"]:
                d.setdefault("score", None)
                d.setdefault("new_score", None)
            return defect_data

    def stitch_score_inference(self, score_type: str, card_name: str,
                               variant_bytes: Dict[str, bytes]) -> dict:
        """Run StitchFusion on one card side and score the detected face defects.

        Args:
            score_type: "front" or "back"; used as ``card_aspect`` and for naming.
            card_name: reported as ``result["image"]``; falls back to
                ``"<score_type>_card"`` when empty.
            variant_bytes: ``{variant_name: image_bytes}``; must contain all
                six ``REQUIRED_VARIANTS``.

        Returns:
            A result dict aligned with score_inference (coaxial branch), built by
            ``CardScorer.formate_one_card_result`` with ``center_result={}``.
            Its ``result`` entry additionally carries image / image_size
            ([H, W]) / card_aspect / card_light_type. defect_result.defects
            entries contain label/confidence/pixel_area/actual_area/width/height/
            points/min_rect/defect_type/edit_type/severity_level/scratch_length/
            score/new_score.

        Raises:
            ValueError: a required variant is missing, ``score_type`` is not
                front/back, or an image cannot be decoded.
            FileNotFoundError: the configured model weights/meta file is missing.
        """
        missing = [v for v in REQUIRED_VARIANTS if v not in variant_bytes]
        if missing:
            raise ValueError(f"缺少图片字段: {missing}, 必须提供 6 张: {REQUIRED_VARIANTS}")
        if score_type not in ("front", "back"):
            raise ValueError(f"score_type 仅支持 front/back, 收到: {score_type}")

        inferrer = self._ensure_inferrer()
        scorer = self._ensure_scorer()
        stem = card_name or f"{score_type}_card"

        modal_imgs = self._decode_variants(score_type, stem, variant_bytes)

        logger.info(f"[StitchFusion] {score_type}/{stem} 开始推理")
        t0 = time.perf_counter()
        canvas, raw_defects = inferrer.infer(modal_imgs)
        # The canvas is indexed as [C, H, W] here (dims 1 and 2 are height/width).
        h, w = int(canvas.shape[1]), int(canvas.shape[2])
        logger.info(f"[StitchFusion] {stem} 模型推理结束, 原始连通域={len(raw_defects)}, "
                    f"耗时 {time.perf_counter() - t0:.2f}s")

        # 1) StitchFusion polygons -> score_inference-style defect_result.
        t0 = time.perf_counter()
        defect_data = self._build_defect_data(raw_defects)
        logger.info(f"[StitchFusion] {stem} 几何指标/统计计算完成, 缺陷={len(defect_data['defects'])}, "
                    f"耗时 {time.perf_counter() - t0:.2f}s")

        # 2) Coaxial-light scoring: face only; center/corner/edge are skipped.
        t0 = time.perf_counter()
        card_aspect = score_type
        card_light_type = "coaxial"
        defect_data = self._score_face(scorer, card_aspect, card_light_type, defect_data)

        # 3) Wrap in the same outer structure as score_inference ({} placeholder for center_result).
        result_json = scorer.formate_one_card_result(
            center_result={},
            defect_result=defect_data,
            card_light_type=card_light_type,
            card_aspect=card_aspect,
            imageHeight=h,
            imageWidth=w,
        )
        result = result_json["result"]
        result["image"] = stem
        result["image_size"] = [h, w]
        result["card_aspect"] = card_aspect
        result["card_light_type"] = card_light_type
        logger.info(f"[StitchFusion] {stem} 评分+封装完成, 耗时 {time.perf_counter() - t0:.2f}s, "
                    f"card_score={result.get('card_score')}")
        return result_json
|