|
|
@@ -0,0 +1,284 @@
|
|
|
+"""
|
|
|
+StitchFusion PT 模型大图切片推理。
|
|
|
+流程: 大图 -> 切 TILE_ROWS x TILE_COLS 块(带 overlap) -> 每块推理 -> Voronoi 归属合并 ->
|
|
|
+ 连通域提取 -> labelme 风格 JSON。
|
|
|
+
|
|
|
+仅依赖 model_traced.pt + model_meta.json,不依赖项目其他工具模块。
|
|
|
+"""
|
|
|
+import json
|
|
|
+import math
|
|
|
+from typing import Dict, List, Tuple
|
|
|
+
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+from torch.nn import functional as F
|
|
|
+from torchvision import transforms as T
|
|
|
+
|
|
|
+from app.core.logger import get_logger
|
|
|
+
|
|
|
+logger = get_logger(__name__)
|
|
|
+
|
|
|
+# ---- 切片参数 ----
|
|
|
+TILE_ROWS = 6
|
|
|
+TILE_COLS = 4
|
|
|
+OVERLAP = 0.2
|
|
|
+
|
|
|
+# ---- 过滤参数 ----
|
|
|
+MIN_DEFECT_AREA = 64 # 整图坐标下连通域最小面积
|
|
|
+EDGE_IGNORE_AREA = 128 # tile 内贴边且面积小于此值的碎片忽略
|
|
|
+EDGE_MARGIN = 4 # 贴边判断的边缘宽度(像素)
|
|
|
+MIN_HOLE_AREA = 16 # 保留面积 >= 此值的孔洞,避免细长划痕交叉围成的区域被填成块
|
|
|
+
|
|
|
+# 三种划痕严重度按优先级从高到低;同一条划痕被多个严重度同时检出时只保留最高优先级
|
|
|
+SCRATCH_PRIORITY = ['serious_scratch', 'scratch', 'slight_scratch']
|
|
|
+SCRATCH_DEDUP_OVERLAP = 0.8
|
|
|
+
|
|
|
+
|
|
|
+def _stitch_holes(outer: np.ndarray, holes: list) -> np.ndarray:
|
|
|
+ """把孔洞用零宽双向桥并入外轮廓,得到 even-odd 填充时仍保留孔洞的单环。"""
|
|
|
+ ring = outer
|
|
|
+ for hole in holes:
|
|
|
+ hc = hole.mean(axis=0)
|
|
|
+ pi = int(np.argmin(((ring - hc) ** 2).sum(1)))
|
|
|
+ qi = int(np.argmin(((hole - ring[pi]) ** 2).sum(1)))
|
|
|
+ hole_loop = np.concatenate([hole[qi:], hole[:qi], hole[qi:qi + 1]], axis=0)
|
|
|
+ ring = np.concatenate(
|
|
|
+ [ring[:pi + 1], hole_loop, ring[pi:pi + 1], ring[pi + 1:]], axis=0
|
|
|
+ )
|
|
|
+ return ring
|
|
|
+
|
|
|
+
|
|
|
+def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
|
|
|
+ """单连通域二值 ROI -> labelme 多边形列表(保留孔洞), 坐标偏移到整图。"""
|
|
|
+ contours, hierarchy = cv2.findContours(roi, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
+ if hierarchy is None:
|
|
|
+ return []
|
|
|
+ hierarchy = hierarchy[0]
|
|
|
+ polygons = []
|
|
|
+ for ci, hi in enumerate(hierarchy):
|
|
|
+ if hi[3] != -1:
|
|
|
+ continue
|
|
|
+ outer = contours[ci].reshape(-1, 2)
|
|
|
+ if outer.shape[0] < 3:
|
|
|
+ continue
|
|
|
+ holes = []
|
|
|
+ child = hi[2]
|
|
|
+ while child != -1:
|
|
|
+ hole = contours[child].reshape(-1, 2)
|
|
|
+ if hole.shape[0] >= 3 and cv2.contourArea(contours[child]) >= MIN_HOLE_AREA:
|
|
|
+ holes.append(hole)
|
|
|
+ child = hierarchy[child][0]
|
|
|
+ ring = _stitch_holes(outer, holes) if holes else outer
|
|
|
+ polygons.append([[int(px + ox), int(py + oy)] for px, py in ring])
|
|
|
+ return polygons
|
|
|
+
|
|
|
+
|
|
|
+def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
|
|
|
+ """rows x cols 个切片, 相邻块保持 overlap 比例重叠。"""
|
|
|
+ def _axis(length, count):
|
|
|
+ if count == 1:
|
|
|
+ return [slice(0, length)]
|
|
|
+ patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
|
|
|
+ stride = max(1, int(round(patch * (1.0 - overlap))))
|
|
|
+ slices = []
|
|
|
+ for i in range(count):
|
|
|
+ start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
|
|
|
+ stop = length if i == count - 1 else min(length, start + patch)
|
|
|
+ slices.append(slice(start, stop))
|
|
|
+ return slices
|
|
|
+
|
|
|
+ return [
|
|
|
+ (f"r{r}c{c}", ys, xs)
|
|
|
+ for r, ys in enumerate(_axis(h, rows))
|
|
|
+ for c, xs in enumerate(_axis(w, cols))
|
|
|
+ ]
|
|
|
+
|
|
|
+
|
|
|
+class TiledInfer:
|
|
|
+ """大图切片 PT 模型推理器。"""
|
|
|
+
|
|
|
+ def __init__(self, model_pt: str, model_meta: str,
|
|
|
+ device: str = 'cuda:0', threshold: float = None):
|
|
|
+ with open(model_meta) as f:
|
|
|
+ meta = json.load(f)
|
|
|
+ self.modals = meta['modals']
|
|
|
+ self.classes = meta['classes']
|
|
|
+ self.palette = meta['palette']
|
|
|
+ self.image_size = meta['image_size']
|
|
|
+
|
|
|
+ thr = meta['thresholds']
|
|
|
+ if threshold is not None:
|
|
|
+ thr = [float(threshold)] * len(self.classes)
|
|
|
+
|
|
|
+ if device.startswith('cuda') and not torch.cuda.is_available():
|
|
|
+ logger.warning(f"{device} 不可用, 回退到 cpu")
|
|
|
+ device = 'cpu'
|
|
|
+ self.device = torch.device(device)
|
|
|
+
|
|
|
+ self.model = torch.jit.load(model_pt, map_location=self.device)
|
|
|
+ self.model.eval()
|
|
|
+ logger.info(f"[StitchFusion] 模型已加载: {model_pt}, device={self.device}")
|
|
|
+
|
|
|
+ self.thresholds = torch.tensor(thr, device=self.device).view(-1, 1, 1)
|
|
|
+
|
|
|
+ self.tf_main = T.Compose([
|
|
|
+ T.Lambda(lambda x: x / 255),
|
|
|
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
|
|
+ T.Lambda(lambda x: x.unsqueeze(0)),
|
|
|
+ ])
|
|
|
+ self.tf_aux = T.Compose([
|
|
|
+ T.Lambda(lambda x: x / 255),
|
|
|
+ T.Lambda(lambda x: x.unsqueeze(0)),
|
|
|
+ ])
|
|
|
+
|
|
|
+ def _preprocess(self, img: torch.Tensor, is_main: bool) -> torch.Tensor:
|
|
|
+ H, W = img.shape[1:]
|
|
|
+ scale = self.image_size[0] / min(H, W)
|
|
|
+ nH = int(math.ceil(round(H * scale) / 32)) * 32
|
|
|
+ nW = int(math.ceil(round(W * scale) / 32)) * 32
|
|
|
+ img = T.Resize((nH, nW))(img)
|
|
|
+ tf = self.tf_main if is_main else self.tf_aux
|
|
|
+ return tf(img.float()).to(self.device)
|
|
|
+
|
|
|
+ @torch.inference_mode()
|
|
|
+ def _forward_tile(self, tile_modals: dict) -> Tuple[np.ndarray, np.ndarray]:
|
|
|
+ """{modal: [3,h,w] uint8} -> (mask [C,h,w] bool, prob [C,h,w] float32)"""
|
|
|
+ h, w = next(iter(tile_modals.values())).shape[1:]
|
|
|
+ inputs = [self._preprocess(tile_modals[m], m == 'img') for m in self.modals]
|
|
|
+ logits = self.model(inputs)
|
|
|
+ probs = logits.sigmoid().squeeze(0)
|
|
|
+ probs = F.interpolate(probs.unsqueeze(0), size=(h, w),
|
|
|
+ mode='bilinear', align_corners=False).squeeze(0)
|
|
|
+ mask = (probs > self.thresholds).cpu().numpy()
|
|
|
+ return mask, probs.cpu().numpy().astype(np.float32)
|
|
|
+
|
|
|
+ def _filter_edge_noise(self, mask: np.ndarray) -> np.ndarray:
|
|
|
+ """去除 tile 内贴边的小碎片。"""
|
|
|
+ out = np.zeros_like(mask, dtype=bool)
|
|
|
+ for c in range(mask.shape[0]):
|
|
|
+ m = mask[c].astype(np.uint8)
|
|
|
+ if not m.any():
|
|
|
+ continue
|
|
|
+ num, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
|
|
|
+ for k in range(1, num):
|
|
|
+ area = int(stats[k, cv2.CC_STAT_AREA])
|
|
|
+ comp = labels == k
|
|
|
+ on_edge = (comp[:EDGE_MARGIN, :].any() or comp[-EDGE_MARGIN:, :].any() or
|
|
|
+ comp[:, :EDGE_MARGIN].any() or comp[:, -EDGE_MARGIN:].any())
|
|
|
+ if area < EDGE_IGNORE_AREA and on_edge:
|
|
|
+ continue
|
|
|
+ out[c][comp] = True
|
|
|
+ return out
|
|
|
+
|
|
|
+ def _dedup_scratch(self, canvas: np.ndarray) -> np.ndarray:
|
|
|
+ """跨严重度去重: 同一条划痕被多个 scratch 类检出时只保留最高优先级。"""
|
|
|
+ name2idx = {n: i for i, n in enumerate(self.classes)}
|
|
|
+ order = [name2idx[n] for n in SCRATCH_PRIORITY if n in name2idx]
|
|
|
+ if len(order) < 2:
|
|
|
+ return canvas
|
|
|
+
|
|
|
+ claimed = np.zeros(canvas.shape[1:], dtype=bool)
|
|
|
+ for ci in order:
|
|
|
+ m = canvas[ci]
|
|
|
+ if not m.any():
|
|
|
+ continue
|
|
|
+ num, labels = cv2.connectedComponents(m.astype(np.uint8), connectivity=8)
|
|
|
+ keep = np.zeros_like(m)
|
|
|
+ for k in range(1, num):
|
|
|
+ comp = labels == k
|
|
|
+ area = int(comp.sum())
|
|
|
+ inter = int((comp & claimed).sum())
|
|
|
+ if area > 0 and inter / area >= SCRATCH_DEDUP_OVERLAP:
|
|
|
+ continue
|
|
|
+ keep |= comp
|
|
|
+ canvas[ci] = keep
|
|
|
+ claimed |= keep
|
|
|
+ return canvas
|
|
|
+
|
|
|
+ def _extract_defects(self, canvas: np.ndarray, prob_canvas: np.ndarray = None) -> list:
|
|
|
+ """从 [C,H,W] bool canvas 提取每个连通域的缺陷信息。"""
|
|
|
+ defects, did = [], 0
|
|
|
+ for c in range(canvas.shape[0]):
|
|
|
+ m = canvas[c].astype(np.uint8)
|
|
|
+ if not m.any():
|
|
|
+ continue
|
|
|
+ num, labels, stats, centroids = cv2.connectedComponentsWithStats(m, connectivity=8)
|
|
|
+ for k in range(1, num):
|
|
|
+ area = int(stats[k, cv2.CC_STAT_AREA])
|
|
|
+ if area < MIN_DEFECT_AREA:
|
|
|
+ continue
|
|
|
+ x = int(stats[k, cv2.CC_STAT_LEFT])
|
|
|
+ y = int(stats[k, cv2.CC_STAT_TOP])
|
|
|
+ bw = int(stats[k, cv2.CC_STAT_WIDTH])
|
|
|
+ bh = int(stats[k, cv2.CC_STAT_HEIGHT])
|
|
|
+ cx, cy = centroids[k]
|
|
|
+ comp = labels == k
|
|
|
+ prob = round(float(prob_canvas[c][comp].mean()), 4) if prob_canvas is not None else None
|
|
|
+ roi = comp[y:y + bh, x:x + bw].astype(np.uint8)
|
|
|
+ polygons = _mask_to_polygons(roi, x, y)
|
|
|
+ defects.append({
|
|
|
+ "id": did,
|
|
|
+ "class_index": int(c),
|
|
|
+ "class_name": self.classes[c],
|
|
|
+ "area": area,
|
|
|
+ "prob": prob,
|
|
|
+ "bbox": [x, y, x + bw, y + bh],
|
|
|
+ "centroid": [round(float(cx), 2), round(float(cy), 2)],
|
|
|
+ "polygons": polygons,
|
|
|
+ })
|
|
|
+ did += 1
|
|
|
+ return defects
|
|
|
+
|
|
|
+ def infer(self, modal_imgs: Dict[str, torch.Tensor],
|
|
|
+ tile_rows: int = TILE_ROWS, tile_cols: int = TILE_COLS,
|
|
|
+ overlap: float = OVERLAP) -> Tuple[np.ndarray, list]:
|
|
|
+ """
|
|
|
+ 大图切片推理。
|
|
|
+ overlap 区域只采用离像素最近的 tile 中心的预测,避免多 tile OR 合并导致细线膨胀成块。
|
|
|
+ """
|
|
|
+ h = min(t.shape[1] for t in modal_imgs.values())
|
|
|
+ w = min(t.shape[2] for t in modal_imgs.values())
|
|
|
+ modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
|
|
|
+
|
|
|
+ slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
|
|
|
+
|
|
|
+ yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
|
|
|
+ min_dist2 = np.full((h, w), np.inf, dtype=np.float32)
|
|
|
+ owner = np.full((h, w), -1, dtype=np.int32)
|
|
|
+ for idx, (_, ys, xs) in enumerate(slices):
|
|
|
+ cy = (ys.start + ys.stop) / 2.0
|
|
|
+ cx = (xs.start + xs.stop) / 2.0
|
|
|
+ d2 = (yy - cy) ** 2 + (xx - cx) ** 2
|
|
|
+ upd = d2 < min_dist2
|
|
|
+ min_dist2[upd] = d2[upd]
|
|
|
+ owner[upd] = idx
|
|
|
+
|
|
|
+ canvas = np.zeros((len(self.classes), h, w), dtype=bool)
|
|
|
+ prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
|
|
|
+ for idx, (tname, ys, xs) in enumerate(slices):
|
|
|
+ tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
|
|
|
+ mask, prob = self._forward_tile(tile)
|
|
|
+ mask = self._filter_edge_noise(mask)
|
|
|
+ own = (owner[ys, xs] == idx)
|
|
|
+ sel = own[np.newaxis, :, :]
|
|
|
+ canvas[:, ys, xs] |= mask & sel
|
|
|
+ np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
|
|
|
+ logger.info(f" tile {tname}: y={ys.start}-{ys.stop}, x={xs.start}-{xs.stop}")
|
|
|
+
|
|
|
+ canvas = self._dedup_scratch(canvas)
|
|
|
+ defects = self._extract_defects(canvas, prob_canvas)
|
|
|
+ return canvas, defects
|
|
|
+
|
|
|
+ def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
|
|
|
+ """生成符合 labelme 格式的 JSON 结构。"""
|
|
|
+ return {
|
|
|
+ "image": stem,
|
|
|
+ "image_size": [h, w],
|
|
|
+ "masks": [
|
|
|
+ {"label": d["class_name"], "prob": d.get("prob"), "points": poly}
|
|
|
+ for d in defects
|
|
|
+ for poly in d.get("polygons", [])
|
|
|
+ if len(poly) >= 3
|
|
|
+ ],
|
|
|
+ }
|