Kaynağa Gözat

推理优化以及日志埋点

袁威 1 hafta önce
ebeveyn
işleme
b5433682e7

+ 16 - 3
app/services/stitch_fusion_service.py

@@ -5,6 +5,7 @@ StitchFusion 多模态拼接缺陷推理服务。
 - 输出: 与 score_inference 接口一致的 result 结构 (center_result/defect_result/card_score 等)。
   StitchFusion 模型只检测面缺陷, 因此走同轴光 (coaxial) 评分流程, center_result 留空。
 """
+import time
 from pathlib import Path
 from threading import Lock
 from typing import Dict, Optional, Tuple
@@ -169,19 +170,26 @@ class StitchFusionService:
 
         inferrer = self._ensure_inferrer()
         scorer = self._ensure_scorer()
+        stem = card_name or f"{score_type}_card"
 
+        t0 = time.perf_counter()
         modal_imgs: Dict[str, torch.Tensor] = {}
         for variant in REQUIRED_VARIANTS:
             modal = VARIANT_TO_MODAL[variant]
             modal_imgs[modal] = _decode_image_to_tensor(variant_bytes[variant], variant)
+        logger.info(f"[StitchFusion] {score_type}/{stem} 6 张图解码完成, "
+                    f"耗时 {time.perf_counter() - t0:.2f}s, "
+                    f"shape={[(m, list(t.shape)) for m, t in modal_imgs.items()]}")
 
-        logger.info(f"[StitchFusion] {score_type}/{card_name} 开始推理")
+        logger.info(f"[StitchFusion] {score_type}/{stem} 开始推理")
+        t0 = time.perf_counter()
         canvas, raw_defects = inferrer.infer(modal_imgs)
         h, w = int(canvas.shape[1]), int(canvas.shape[2])
-        stem = card_name or f"{score_type}_card"
-        logger.info(f"[StitchFusion] {stem} 推理结束, 原始连通域={len(raw_defects)}")
+        logger.info(f"[StitchFusion] {stem} 模型推理结束, 原始连通域={len(raw_defects)}, "
+                    f"耗时 {time.perf_counter() - t0:.2f}s")
 
         # 1) 把 StitchFusion 多边形转成 score_inference 风格 defect_result
+        t0 = time.perf_counter()
         shapes_data, meta_per_shape = _build_shapes_data(raw_defects)
         processor = DefectProcessor(pixel_resolution=settings.PIXEL_RESOLUTION)
         analysis_json = processor.analyze_from_json(shapes_data)
@@ -199,8 +207,11 @@ class StitchFusionService:
         }
         defect_data = formate_face_data(defect_data)
         defect_data = formate_add_edit_type(defect_data)
+        logger.info(f"[StitchFusion] {stem} 几何指标/统计计算完成, 缺陷={len(defect_data['defects'])}, "
+                    f"耗时 {time.perf_counter() - t0:.2f}s")
 
         # 2) 同轴光评分流程: 仅算 face, 跳过 center/corner/edge
+        t0 = time.perf_counter()
         card_aspect = score_type
         card_light_type = "coaxial"
         try:
@@ -228,4 +239,6 @@ class StitchFusionService:
         result_json["result"]["image_size"] = [h, w]
         result_json["result"]["card_aspect"] = card_aspect
         result_json["result"]["card_light_type"] = card_light_type
+        logger.info(f"[StitchFusion] {stem} 评分+封装完成, 耗时 {time.perf_counter() - t0:.2f}s, "
+                    f"card_score={result_json['result'].get('card_score')}")
         return result_json

+ 55 - 26
app/utils/stitch_fusion/tiled_infer.py

@@ -7,6 +7,7 @@ StitchFusion PT 模型大图切片推理。
 """
 import json
 import math
+import time
 from typing import Dict, List, Tuple
 
 import cv2
@@ -74,27 +75,51 @@ def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
     return polygons
 
 
+def _axis_slices(length: int, count: int, overlap: float) -> List[slice]:
+    """单个轴上的切片划分。"""
+    if count == 1:
+        return [slice(0, length)]
+    patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
+    stride = max(1, int(round(patch * (1.0 - overlap))))
+    slices = []
+    for i in range(count):
+        start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
+        stop = length if i == count - 1 else min(length, start + patch)
+        slices.append(slice(start, stop))
+    return slices
+
+
 def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
     """rows x cols 个切片, 相邻块保持 overlap 比例重叠。"""
-    def _axis(length, count):
-        if count == 1:
-            return [slice(0, length)]
-        patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
-        stride = max(1, int(round(patch * (1.0 - overlap))))
-        slices = []
-        for i in range(count):
-            start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
-            stop = length if i == count - 1 else min(length, start + patch)
-            slices.append(slice(start, stop))
-        return slices
-
+    y_slices = _axis_slices(h, rows, overlap)
+    x_slices = _axis_slices(w, cols, overlap)
     return [
         (f"r{r}c{c}", ys, xs)
-        for r, ys in enumerate(_axis(h, rows))
-        for c, xs in enumerate(_axis(w, cols))
+        for r, ys in enumerate(y_slices)
+        for c, xs in enumerate(x_slices)
     ]
 
 
+def _build_owner_map(h: int, w: int, rows: int, cols: int, overlap: float) -> np.ndarray:
+    """计算每个像素归属的 tile 索引 (Voronoi 划分, 基于 tile 中心)。
+
+    由于 tile 网格是规则的, 行/列中心独立于另一轴, Voronoi 划分可分离为两个 1D
+    问题, 复杂度从 O(H*W*tiles) 降到 O((H+W)*tiles)。
+    """
+    y_slices = _axis_slices(h, rows, overlap)
+    x_slices = _axis_slices(w, cols, overlap)
+    y_centers = np.array([(s.start + s.stop) / 2.0 for s in y_slices], dtype=np.float32)
+    x_centers = np.array([(s.start + s.stop) / 2.0 for s in x_slices], dtype=np.float32)
+
+    ys = np.arange(h, dtype=np.float32)
+    xs = np.arange(w, dtype=np.float32)
+    y_owner = np.argmin(np.abs(ys[:, None] - y_centers[None, :]), axis=1).astype(np.int32)
+    x_owner = np.argmin(np.abs(xs[:, None] - x_centers[None, :]), axis=1).astype(np.int32)
+
+    owner = y_owner[:, None] * cols + x_owner[None, :]
+    return owner.astype(np.int32)
+
+
 class TiledInfer:
     """大图切片 PT 模型推理器。"""
 
@@ -237,26 +262,24 @@ class TiledInfer:
         大图切片推理。
         overlap 区域只采用离像素最近的 tile 中心的预测,避免多 tile OR 合并导致细线膨胀成块。
         """
+        t_all = time.perf_counter()
+
         h = min(t.shape[1] for t in modal_imgs.values())
         w = min(t.shape[2] for t in modal_imgs.values())
         modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
-
         slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
+        logger.info(f"[StitchFusion] 输入大图 {h}x{w}, 切片 {tile_rows}x{tile_cols}={len(slices)} 块, overlap={overlap}")
 
-        yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
-        min_dist2 = np.full((h, w), np.inf, dtype=np.float32)
-        owner = np.full((h, w), -1, dtype=np.int32)
-        for idx, (_, ys, xs) in enumerate(slices):
-            cy = (ys.start + ys.stop) / 2.0
-            cx = (xs.start + xs.stop) / 2.0
-            d2 = (yy - cy) ** 2 + (xx - cx) ** 2
-            upd = d2 < min_dist2
-            min_dist2[upd] = d2[upd]
-            owner[upd] = idx
+        t0 = time.perf_counter()
+        owner = _build_owner_map(h, w, tile_rows, tile_cols, overlap)
+        logger.info(f"[StitchFusion] Voronoi 归属图构建完成, 耗时 {time.perf_counter() - t0:.2f}s")
 
         canvas = np.zeros((len(self.classes), h, w), dtype=bool)
         prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
         for idx, (tname, ys, xs) in enumerate(slices):
+            t_tile = time.perf_counter()
+            logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} y={ys.start}-{ys.stop} x={xs.start}-{xs.stop} 推理开始")
+
             tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
             mask, prob = self._forward_tile(tile)
             mask = self._filter_edge_noise(mask)
@@ -264,10 +287,16 @@ class TiledInfer:
             sel = own[np.newaxis, :, :]
             canvas[:, ys, xs] |= mask & sel
             np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
-            logger.info(f"  tile {tname}: y={ys.start}-{ys.stop}, x={xs.start}-{xs.stop}")
+            logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} 完成, 耗时 {time.perf_counter() - t_tile:.2f}s")
 
+        t0 = time.perf_counter()
         canvas = self._dedup_scratch(canvas)
+        logger.info(f"[StitchFusion] 划痕跨严重度去重完成, 耗时 {time.perf_counter() - t0:.2f}s")
+
+        t0 = time.perf_counter()
         defects = self._extract_defects(canvas, prob_canvas)
+        logger.info(f"[StitchFusion] 连通域提取完成, 缺陷数={len(defects)}, 耗时 {time.perf_counter() - t0:.2f}s")
+        logger.info(f"[StitchFusion] infer 总耗时 {time.perf_counter() - t_all:.2f}s")
         return canvas, defects
 
     def make_json(self, stem: str, h: int, w: int, defects: list) -> dict: