"""
StitchFusion PT model: tiled inference on large images.

Pipeline:
    large image
    -> cut into TILE_ROWS x TILE_COLS tiles (with overlap)
    -> run inference on every tile
    -> merge tiles by Voronoi ownership (each pixel is taken from the tile
       whose center is nearest)
    -> connected-component extraction
    -> labelme-style JSON.

Only depends on model_traced.pt + model_meta.json; does not depend on any
other project utility modules.
"""
import json
import math
import time
from typing import Dict, List, Tuple

import cv2
import numpy as np
import torch
from torch.nn import functional as F
from torchvision import transforms as T

from app.core.logger import get_logger

logger = get_logger(__name__)

# ---- Tiling parameters ----
TILE_ROWS = 2
TILE_COLS = 3
OVERLAP = 0.2

# ---- Filtering parameters ----
MIN_DEFECT_AREA = 64  # minimum connected-component area, in full-image coordinates
EDGE_IGNORE_AREA = 128  # fragments touching a tile edge with area below this are ignored
EDGE_MARGIN = 4  # width (pixels) of the border band used for the "touches edge" test
# Keep holes with area >= this value, so regions enclosed by crossing thin
# scratches are not filled in as solid blobs.
MIN_HOLE_AREA = 16

# The three scratch severities, from highest to lowest priority. When the same
# scratch is detected under several severities, only the highest one is kept.
SCRATCH_PRIORITY = ['serious_scratch', 'scratch', 'slight_scratch']
SCRATCH_DEDUP_OVERLAP = 0.8


def _stitch_holes(outer: np.ndarray, holes: list) -> np.ndarray:
    """Merge holes into the outer contour with zero-width two-way bridges.

    Each hole is spliced in at the ring vertex nearest the hole's mean point:
    the walk goes ring[pi] -> around the hole (starting and ending at the hole
    vertex nearest ring[pi]) -> back to ring[pi]. The result is a single ring
    that still renders the holes as empty under even-odd filling.

    Args:
        outer: (N, 2) outer contour points.
        holes: list of (M, 2) hole contour points.

    Returns:
        (K, 2) stitched ring.
    """
    ring = outer
    for hole in holes:
        hc = hole.mean(axis=0)
        pi = int(np.argmin(((ring - hc) ** 2).sum(1)))
        qi = int(np.argmin(((hole - ring[pi]) ** 2).sum(1)))
        # Rotate the hole to start at qi and close it by repeating hole[qi].
        hole_loop = np.concatenate([hole[qi:], hole[:qi], hole[qi:qi + 1]], axis=0)
        # ring[pi] is repeated after the hole loop to walk back along the bridge.
        ring = np.concatenate(
            [ring[:pi + 1], hole_loop, ring[pi:pi + 1], ring[pi + 1:]], axis=0
        )
    return ring


def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
    """Convert the binary ROI of one connected component into labelme polygons.

    Holes with area >= MIN_HOLE_AREA are preserved by stitching them into
    their outer contour. Coordinates are offset by (ox, oy) into full-image
    coordinates.

    Args:
        roi: uint8 binary mask (nonzero = foreground).
        ox: x offset of the ROI in the full image.
        oy: y offset of the ROI in the full image.

    Returns:
        List of polygons, each a list of [x, y] int points.
    """
    contours, hierarchy = cv2.findContours(roi, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    if hierarchy is None:
        return []
    hierarchy = hierarchy[0]
    polygons = []
    for ci, hi in enumerate(hierarchy):
        # Entries are [next, prev, first_child, parent]; with RETR_CCOMP a
        # parent of -1 marks an outer boundary.
        if hi[3] != -1:
            continue
        outer = contours[ci].reshape(-1, 2)
        if outer.shape[0] < 3:
            continue
        holes = []
        child = hi[2]
        while child != -1:
            hole = contours[child].reshape(-1, 2)
            if hole.shape[0] >= 3 and cv2.contourArea(contours[child]) >= MIN_HOLE_AREA:
                holes.append(hole)
            child = hierarchy[child][0]  # next sibling hole
        ring = _stitch_holes(outer, holes) if holes else outer
        polygons.append([[int(px + ox), int(py + oy)] for px, py in ring])
    return polygons


def _axis_slices(length: int, count: int, overlap: float) -> List[slice]:
    """Split one axis of ``length`` pixels into ``count`` overlapping slices.

    All slices share the same size ``patch``, and their starts are spread
    evenly over ``[0, length - patch]``. The first slice starts at 0, the last
    ends at ``length``, and neighbouring slices always overlap (or touch when
    ``overlap == 0``). ``patch`` is rounded *up* so the union of the slices is
    guaranteed to cover the whole axis. Rounding patch/stride to nearest could
    leave uncovered pixels, e.g. length=1000, count=3, overlap=0.

    Args:
        length: axis length in pixels.
        count: number of slices (>= 1).
        overlap: fraction of ``patch`` shared by neighbouring slices, in [0, 1).

    Returns:
        ``count`` slices in increasing order.

    Raises:
        ValueError: if ``count < 1`` or ``overlap`` is outside [0, 1).
    """
    if count < 1:
        raise ValueError(f"count must be >= 1, got {count}")
    if not 0.0 <= overlap < 1.0:
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")
    if count == 1:
        return [slice(0, length)]
    # Smallest tile size such that `count` tiles with this overlap cover the axis.
    patch = max(1, min(math.ceil(length / (count - (count - 1) * overlap)), length))
    span = max(0, length - patch)
    # Integer floor keeps consecutive start differences <= ceil(span / (count - 1)),
    # which is <= patch, so there are never gaps between slices.
    starts = [(i * span) // (count - 1) for i in range(count)]
    return [slice(s, min(length, s + patch)) for s in starts]


def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
    """Return rows x cols tiles as (name, y_slice, x_slice), in row-major order.

    Neighbouring tiles overlap by the ``overlap`` fraction along each axis.
    """
    y_slices = _axis_slices(h, rows, overlap)
    x_slices = _axis_slices(w, cols, overlap)
    return [
        (f"r{r}c{c}", ys, xs)
        for r, ys in enumerate(y_slices)
        for c, xs in enumerate(x_slices)
    ]


def _build_owner_map(h: int, w: int, rows: int, cols: int, overlap: float) -> np.ndarray:
    """Return an (h, w) int32 map holding the index of the tile that owns each pixel.

    Each pixel is owned by the tile whose center is nearest (Voronoi
    partition). Tile indices follow the same row-major order as
    ``_grid_slices``. Because the tile grid is regular, the row/column centers
    do not depend on the other axis, so the partition separates into two 1-D
    problems. This costs O((H + W) * tiles) instead of O(H * W * tiles).

    Distances are measured from pixel *centers* (index + 0.5), matching the
    continuous tile centers ``(start + stop) / 2``. This guarantees every
    owned pixel lies inside its owner tile. Measuring from the raw index
    biased pixels toward the lower tile: with ``overlap == 0`` the first pixel
    past a boundary tied and was given, by ``argmin``, to the tile that does
    not contain it, so it was never written.
    """
    y_slices = _axis_slices(h, rows, overlap)
    x_slices = _axis_slices(w, cols, overlap)
    y_centers = np.array([(s.start + s.stop) / 2.0 for s in y_slices], dtype=np.float64)
    x_centers = np.array([(s.start + s.stop) / 2.0 for s in x_slices], dtype=np.float64)
    ys = np.arange(h, dtype=np.float64) + 0.5
    xs = np.arange(w, dtype=np.float64) + 0.5
    y_owner = np.argmin(np.abs(ys[:, None] - y_centers[None, :]), axis=1).astype(np.int32)
    x_owner = np.argmin(np.abs(xs[:, None] - x_centers[None, :]), axis=1).astype(np.int32)
    owner = y_owner[:, None] * cols + x_owner[None, :]
    return owner.astype(np.int32)


class TiledInfer:
    """Tiled inference runner for a TorchScript (PT) segmentation model on large images."""

    def __init__(self, model_pt: str, model_meta: str, device: str = 'cuda:0',
                 threshold: float = None):
        """Load the model and its metadata.

        Args:
            model_pt: path to the TorchScript model file (loaded with torch.jit.load).
            model_meta: path to a JSON file providing 'modals', 'classes',
                'palette', 'image_size' and per-class 'thresholds'.
            device: torch device string; a CUDA device falls back to CPU when
                CUDA is unavailable.
            threshold: if given, overrides every per-class threshold from meta.
        """
        with open(model_meta) as f:
            meta = json.load(f)
        self.modals = meta['modals']
        self.classes = meta['classes']
        self.palette = meta['palette']
        self.image_size = meta['image_size']
        thr = meta['thresholds']
        if threshold is not None:
            thr = [float(threshold)] * len(self.classes)

        if device.startswith('cuda') and not torch.cuda.is_available():
            logger.warning(f"{device} 不可用, 回退到 cpu")
            device = 'cpu'
        self.device = torch.device(device)
        self.model = torch.jit.load(model_pt, map_location=self.device)
        self.model.eval()
        logger.info(f"[StitchFusion] 模型已加载: {model_pt}, device={self.device}")

        # Shape (C, 1, 1) so it broadcasts against [C, h, w] probability maps.
        self.thresholds = torch.tensor(thr, device=self.device).view(-1, 1, 1)
        # Main modality: scale to [0, 1], ImageNet-normalize, add batch dim.
        self.tf_main = T.Compose([
            T.Lambda(lambda x: x / 255),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            T.Lambda(lambda x: x.unsqueeze(0)),
        ])
        # Auxiliary modalities: scale to [0, 1], add batch dim (no normalization).
        self.tf_aux = T.Compose([
            T.Lambda(lambda x: x / 255),
            T.Lambda(lambda x: x.unsqueeze(0)),
        ])

    def _preprocess(self, img: torch.Tensor, is_main: bool) -> torch.Tensor:
        """Resize and normalize one modality tile, returning a batched tensor on device.

        The tile is scaled so its short side matches ``image_size[0]``, then
        each side is rounded up to a multiple of 32.
        """
        H, W = img.shape[1:]
        scale = self.image_size[0] / min(H, W)
        nH = int(math.ceil(round(H * scale) / 32)) * 32
        nW = int(math.ceil(round(W * scale) / 32)) * 32
        img = T.Resize((nH, nW))(img)
        tf = self.tf_main if is_main else self.tf_aux
        return tf(img.float()).to(self.device)

    @torch.inference_mode()
    def _forward_tile(self, tile_modals: dict) -> Tuple[np.ndarray, np.ndarray]:
        """Run the model on one tile.

        Args:
            tile_modals: {modal: [3, h, w] uint8 tensor}; the model receives
                the modalities in ``self.modals`` order, and 'img' is treated
                as the main modality.

        Returns:
            (mask [C, h, w] bool, prob [C, h, w] float32), with probabilities
            resized back to the tile size.
        """
        h, w = next(iter(tile_modals.values())).shape[1:]
        inputs = [self._preprocess(tile_modals[m], m == 'img') for m in self.modals]
        logits = self.model(inputs)  # assumes [1, C, h', w'] logits -- TODO confirm
        probs = logits.sigmoid().squeeze(0)
        probs = F.interpolate(probs.unsqueeze(0), size=(h, w), mode='bilinear',
                              align_corners=False).squeeze(0)
        mask = (probs > self.thresholds).cpu().numpy()
        return mask, probs.cpu().numpy().astype(np.float32)

    def _filter_edge_noise(self, mask: np.ndarray) -> np.ndarray:
        """Remove small fragments that touch the tile border.

        A component is dropped when its area is below EDGE_IGNORE_AREA and it
        has a pixel within EDGE_MARGIN pixels of any tile edge. The edge test
        uses the component bounding box, whose extreme rows/columns are by
        definition occupied. This is equivalent to scanning the border bands
        but avoids building a full-tile mask for every component.

        Args:
            mask: [C, h, w] bool per-class masks of one tile.

        Returns:
            [C, h, w] bool filtered masks.
        """
        out = np.zeros_like(mask, dtype=bool)
        th, tw = mask.shape[1:]
        for c in range(mask.shape[0]):
            m = mask[c].astype(np.uint8)
            if not m.any():
                continue
            num, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
            left = stats[:, cv2.CC_STAT_LEFT]
            top = stats[:, cv2.CC_STAT_TOP]
            right = left + stats[:, cv2.CC_STAT_WIDTH] - 1  # last occupied column
            bottom = top + stats[:, cv2.CC_STAT_HEIGHT] - 1  # last occupied row
            on_edge = ((top < EDGE_MARGIN) | (bottom >= th - EDGE_MARGIN)
                       | (left < EDGE_MARGIN) | (right >= tw - EDGE_MARGIN))
            keep = ~((stats[:, cv2.CC_STAT_AREA] < EDGE_IGNORE_AREA) & on_edge)
            keep[0] = False  # label 0 is the background
            out[c] = keep[labels]
        return out

    def _dedup_scratch(self, canvas: np.ndarray) -> np.ndarray:
        """Cross-severity dedup: keep only the highest-priority class for each scratch.

        Classes are visited in SCRATCH_PRIORITY order. A component whose
        overlap with pixels already kept by higher-priority classes is at
        least SCRATCH_DEDUP_OVERLAP of its own area is removed. ``canvas`` is
        modified in place and returned. Per-component areas and intersections
        are computed with ``np.bincount`` instead of one full-frame mask per
        component.
        """
        name2idx = {n: i for i, n in enumerate(self.classes)}
        order = [name2idx[n] for n in SCRATCH_PRIORITY if n in name2idx]
        if len(order) < 2:
            return canvas
        claimed = np.zeros(canvas.shape[1:], dtype=bool)
        for ci in order:
            m = canvas[ci]
            if not m.any():
                continue
            num, labels = cv2.connectedComponents(m.astype(np.uint8), connectivity=8)
            area = np.bincount(labels.ravel(), minlength=num)
            inter = np.bincount(labels[claimed], minlength=num)
            keep = inter / np.maximum(area, 1) < SCRATCH_DEDUP_OVERLAP
            keep[0] = False  # label 0 is the background
            kept = keep[labels]
            canvas[ci] = kept
            claimed |= kept
        return canvas

    def _extract_defects(self, canvas: np.ndarray, prob_canvas: np.ndarray = None) -> list:
        """Extract one defect record per connected component of a [C, H, W] bool canvas.

        Components smaller than MIN_DEFECT_AREA are skipped. ``prob`` is the
        mean of ``prob_canvas`` over the component, or None if no prob canvas
        is given. Component masks are built inside the bounding box only.
        """
        defects, did = [], 0
        for c in range(canvas.shape[0]):
            m = canvas[c].astype(np.uint8)
            if not m.any():
                continue
            num, labels, stats, centroids = cv2.connectedComponentsWithStats(m, connectivity=8)
            for k in range(1, num):
                area = int(stats[k, cv2.CC_STAT_AREA])
                if area < MIN_DEFECT_AREA:
                    continue
                x = int(stats[k, cv2.CC_STAT_LEFT])
                y = int(stats[k, cv2.CC_STAT_TOP])
                bw = int(stats[k, cv2.CC_STAT_WIDTH])
                bh = int(stats[k, cv2.CC_STAT_HEIGHT])
                cx, cy = centroids[k]
                # The bbox contains every pixel of component k, so this local
                # mask selects exactly the same pixels as a full-frame mask.
                comp = labels[y:y + bh, x:x + bw] == k
                prob = (round(float(prob_canvas[c, y:y + bh, x:x + bw][comp].mean()), 4)
                        if prob_canvas is not None else None)
                polygons = _mask_to_polygons(comp.astype(np.uint8), x, y)
                defects.append({
                    "id": did,
                    "class_index": int(c),
                    "class_name": self.classes[c],
                    "area": area,
                    "prob": prob,
                    "bbox": [x, y, x + bw, y + bh],
                    "centroid": [round(float(cx), 2), round(float(cy), 2)],
                    "polygons": polygons,
                })
                did += 1
        return defects

    def infer(self, modal_imgs: Dict[str, torch.Tensor], tile_rows: int = TILE_ROWS,
              tile_cols: int = TILE_COLS, overlap: float = OVERLAP) -> Tuple[np.ndarray, list]:
        """Tiled inference on a large image.

        In overlap regions, each pixel takes only the prediction of the tile
        whose center is nearest. This avoids OR-merging several tiles, which
        thickens thin lines into blobs.

        Args:
            modal_imgs: {modal: [3, H, W] tensor}; all modalities are cropped
                to the smallest common H x W.
            tile_rows: number of tile rows.
            tile_cols: number of tile columns.
            overlap: overlap fraction between neighbouring tiles, in [0, 1).

        Returns:
            (canvas [C, H, W] bool, list of defect dicts).
        """
        t_all = time.perf_counter()
        h = min(t.shape[1] for t in modal_imgs.values())
        w = min(t.shape[2] for t in modal_imgs.values())
        modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}

        slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
        logger.info(f"[StitchFusion] 输入大图 {h}x{w}, 切片 {tile_rows}x{tile_cols}={len(slices)} 块, overlap={overlap}")

        t0 = time.perf_counter()
        owner = _build_owner_map(h, w, tile_rows, tile_cols, overlap)
        logger.info(f"[StitchFusion] Voronoi 归属图构建完成, 耗时 {time.perf_counter() - t0:.2f}s")

        canvas = np.zeros((len(self.classes), h, w), dtype=bool)
        prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)

        for idx, (tname, ys, xs) in enumerate(slices):
            t_tile = time.perf_counter()
            logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} y={ys.start}-{ys.stop} x={xs.start}-{xs.stop} 推理开始")
            tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
            mask, prob = self._forward_tile(tile)
            mask = self._filter_edge_noise(mask)
            # Only write the pixels this tile owns in the Voronoi partition.
            own = (owner[ys, xs] == idx)
            sel = own[np.newaxis, :, :]
            canvas[:, ys, xs] |= mask & sel
            np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
            logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} 完成, 耗时 {time.perf_counter() - t_tile:.2f}s")

        t0 = time.perf_counter()
        canvas = self._dedup_scratch(canvas)
        logger.info(f"[StitchFusion] 划痕跨严重度去重完成, 耗时 {time.perf_counter() - t0:.2f}s")

        t0 = time.perf_counter()
        defects = self._extract_defects(canvas, prob_canvas)
        logger.info(f"[StitchFusion] 连通域提取完成, 缺陷数={len(defects)}, 耗时 {time.perf_counter() - t0:.2f}s")
        logger.info(f"[StitchFusion] infer 总耗时 {time.perf_counter() - t_all:.2f}s")
        return canvas, defects

    def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
        """Build the labelme-style JSON structure: one mask entry per polygon with >= 3 points."""
        return {
            "image": stem,
            "image_size": [h, w],
            "masks": [
                {"label": d["class_name"], "prob": d.get("prob"), "points": poly}
                for d in defects
                for poly in d.get("polygons", [])
                if len(poly) >= 3
            ],
        }