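```python
"""
StitchFusion PT model: tiled inference on large images.
Pipeline: large image -> split into TILE_ROWS x TILE_COLS tiles (with overlap) -> per-tile inference
-> Voronoi ownership merge -> connected-component extraction -> labelme-style JSON.
Depends only on model_traced.pt + model_meta.json, not on the project's other utility modules.
"""
import json
import math
import time
from typing import Dict, List, Tuple

import cv2
import numpy as np
import torch
from torch.nn import functional as F
from torchvision import transforms as T

from app.core.logger import get_logger

logger = get_logger(__name__)

# ---- Tiling parameters ----
TILE_ROWS = 3
TILE_COLS = 2
OVERLAP = 0.2

# ---- Filtering parameters ----
MIN_DEFECT_AREA = 64  # minimum connected-component area, in full-image pixels
EDGE_IGNORE_AREA = 128  # fragments that touch a tile edge and are smaller than this are dropped
EDGE_MARGIN = 4  # width (pixels) of the border band used for the touches-edge test
MIN_HOLE_AREA = 16  # keep holes with area >= this, so regions enclosed by crossing thin scratches are not filled solid

# The three scratch severities, highest priority first. When one scratch is detected under
# several severities at once, only the highest-priority detection is kept.
SCRATCH_PRIORITY = ['serious_scratch', 'scratch', 'slight_scratch']
# A component is dropped when >= this fraction of its pixels is already claimed by a higher-priority class.
SCRATCH_DEDUP_OVERLAP = 0.8
```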
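```python
def _stitch_holes(outer: np.ndarray, holes: list) -> np.ndarray:
    """Merge each hole into the outer contour through a zero-width, two-way bridge, giving a
    single ring that still leaves the holes unfilled under even-odd filling."""
    ring = outer
    for hole in holes:
        hc = hole.mean(axis=0)
        # Bridge from the ring vertex closest to the hole centroid ...
        pi = int(np.argmin(((ring - hc) ** 2).sum(1)))
        # ... to the hole vertex closest to that ring vertex.
        qi = int(np.argmin(((hole - ring[pi]) ** 2).sum(1)))
        # Walk the whole hole, starting and ending at hole[qi].
        hole_loop = np.concatenate([hole[qi:], hole[:qi], hole[qi:qi + 1]], axis=0)
        # ring[..pi] -> hole loop -> back to ring[pi] -> rest of the ring
        ring = np.concatenate(
            [ring[:pi + 1], hole_loop, ring[pi:pi + 1], ring[pi + 1:]], axis=0
        )
    return ring
```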
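
A quick way to see why the zero-width bridge keeps the hole open: the sketch below copies `_stitch_holes` unchanged, stitches a square hole into a square outline, and classifies points with a hand-written even-odd (ray-crossing) test. `inside_even_odd` and the coordinates are illustrative assumptions, not part of the module.

```python
import numpy as np


def _stitch_holes(outer: np.ndarray, holes: list) -> np.ndarray:
    # Copied from the module: bridge each hole into the outer ring.
    ring = outer
    for hole in holes:
        hc = hole.mean(axis=0)
        pi = int(np.argmin(((ring - hc) ** 2).sum(1)))
        qi = int(np.argmin(((hole - ring[pi]) ** 2).sum(1)))
        hole_loop = np.concatenate([hole[qi:], hole[:qi], hole[qi:qi + 1]], axis=0)
        ring = np.concatenate([ring[:pi + 1], hole_loop, ring[pi:pi + 1], ring[pi + 1:]], axis=0)
    return ring


def inside_even_odd(poly: np.ndarray, x: float, y: float) -> bool:
    """Hypothetical helper: even-odd rule via a ray cast towards +x."""
    inside = False
    n = len(poly)
    for i in range(n):
        x1, y1 = poly[i]
        x2, y2 = poly[(i + 1) % n]
        if (y1 > y) != (y2 > y):  # edge straddles the ray's y (horizontal edges are skipped)
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x_cross > x:
                inside = not inside
    return inside


outer = np.array([[0, 0], [10, 0], [10, 10], [0, 10]])  # (x, y) points
hole = np.array([[3, 3], [7, 3], [7, 7], [3, 7]])
ring = _stitch_holes(outer, [hole])

print(len(ring))                        # 10: 4 outer + 5 hole-loop + 1 return vertex
print(inside_even_odd(ring, 5, 5))      # False: the centre of the hole stays empty
print(inside_even_odd(ring, 1.5, 5))    # True: solid part left of the hole
print(inside_even_odd(ring, 1.0, 1.5))  # True: the ray crosses both bridge edges, parity unchanged
```

The two bridge edges coincide, so a ray that crosses one also crosses the other and the parity does not change; that is what lets the stitched ring be filled as one polygon without filling the hole.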
```python
def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
    """Binary ROI of one connected component -> list of labelme polygons (holes preserved),
    with coordinates shifted by (ox, oy) into full-image space."""
    # RETR_CCOMP: two-level hierarchy, outer boundaries on top and their holes as children.
    contours, hierarchy = cv2.findContours(roi, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    if hierarchy is None:
        return []
    hierarchy = hierarchy[0]  # rows: [next, previous, first_child, parent]
    polygons = []
    for ci, hi in enumerate(hierarchy):
        if hi[3] != -1:  # has a parent: it is a hole, handled together with its parent
            continue
        outer = contours[ci].reshape(-1, 2)
        if outer.shape[0] < 3:
            continue
        holes = []
        child = hi[2]
        while child != -1:  # walk the first child and its siblings
            hole = contours[child].reshape(-1, 2)
            if hole.shape[0] >= 3 and cv2.contourArea(contours[child]) >= MIN_HOLE_AREA:
                holes.append(hole)
            child = hierarchy[child][0]
        ring = _stitch_holes(outer, holes) if holes else outer
        polygons.append([[int(px + ox), int(py + oy)] for px, py in ring])
    return polygons
```
```python
def _axis_slices(length: int, count: int, overlap: float) -> List[slice]:
    """Split one axis into `count` equal-size slices overlapping by `overlap` of the slice size;
    the last slice is aligned to the end of the axis."""
    if count == 1:
        return [slice(0, length)]
    # Solve count * patch - (count - 1) * overlap * patch == length for the slice size.
    patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
    stride = max(1, int(round(patch * (1.0 - overlap))))
    slices = []
    for i in range(count):
        start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
        stop = length if i == count - 1 else min(length, start + patch)
        slices.append(slice(start, stop))
    return slices
```
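
The slice size follows from requiring $n$ slices of size $p$, with each of the $n-1$ seams overlapping by $o\,p$, to cover an axis of length $L$; the stride $s$ is the non-overlapping part. Worked through (ignoring the clamping) for the default `TILE_ROWS = 3`, `OVERLAP = 0.2` and an assumed image height of 3000 px:

```math
\begin{aligned}
n\,p-(n-1)\,o\,p &= L \;\Rightarrow\; p = \operatorname{round}\!\left(\frac{L}{n-(n-1)\,o}\right), \quad s = \operatorname{round}\big(p\,(1-o)\big)\\
L=3000,\ n=3,\ o=0.2:\quad p &= \operatorname{round}(3000/2.6) = \operatorname{round}(1153.85) = 1154,\quad s = \operatorname{round}(923.2) = 923
\end{aligned}
```

This gives the slices [0, 1154), [923, 2077) and [1846, 3000), each seam overlapping by 231 px (about 0.2 × 1154). The last slice starts at $L - p = 1846$ rather than $2s$; the two happen to coincide here.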
```python
def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
    """rows x cols tiles in row-major order, named r{row}c{col}; adjacent tiles overlap by
    the `overlap` fraction."""
    y_slices = _axis_slices(h, rows, overlap)
    x_slices = _axis_slices(w, cols, overlap)
    return [
        (f"r{r}c{c}", ys, xs)
        for r, ys in enumerate(y_slices)
        for c, xs in enumerate(x_slices)
    ]
```
```python
def _build_owner_map(h: int, w: int, rows: int, cols: int, overlap: float) -> np.ndarray:
    """Owning tile index for every pixel: a Voronoi partition over the tile centers.
    Because the centers form a regular grid (every row center paired with every column
    center), the squared distance (y - y_r)^2 + (x - x_c)^2 splits into two independent 1D
    nearest-center problems. The search then costs O(H*rows + W*cols) instead of
    O(H*W*tiles); only the final broadcast to the [H, W] owner map touches every pixel.
    """
    y_slices = _axis_slices(h, rows, overlap)
    x_slices = _axis_slices(w, cols, overlap)
    y_centers = np.array([(s.start + s.stop) / 2.0 for s in y_slices], dtype=np.float32)
    x_centers = np.array([(s.start + s.stop) / 2.0 for s in x_slices], dtype=np.float32)
    ys = np.arange(h, dtype=np.float32)
    xs = np.arange(w, dtype=np.float32)
    y_owner = np.argmin(np.abs(ys[:, None] - y_centers[None, :]), axis=1).astype(np.int32)
    x_owner = np.argmin(np.abs(xs[:, None] - x_centers[None, :]), axis=1).astype(np.int32)
    # Row-major tile index, matching the order produced by _grid_slices.
    owner = y_owner[:, None] * cols + x_owner[None, :]
    return owner.astype(np.int32)
```
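
The separability claim can be checked directly: the sketch below builds the owner map the same way as `_build_owner_map` and compares it with a brute-force 2D Voronoi assignment over all tile centers in row-major order. The centers are hypothetical values for a 3 x 2 grid on a 10 x 8 image, not output of the module.

```python
import numpy as np


def owner_separable(h, w, y_centers, x_centers):
    """Same construction as _build_owner_map: nearest row center x nearest column center."""
    ys = np.arange(h, dtype=np.float32)
    xs = np.arange(w, dtype=np.float32)
    y_owner = np.argmin(np.abs(ys[:, None] - y_centers[None, :]), axis=1)
    x_owner = np.argmin(np.abs(xs[:, None] - x_centers[None, :]), axis=1)
    return y_owner[:, None] * len(x_centers) + x_owner[None, :]


def owner_brute_force(h, w, y_centers, x_centers):
    """Direct 2D Voronoi: distance from every pixel to every tile center ([H, W, tiles] memory)."""
    cy, cx = np.meshgrid(y_centers, x_centers, indexing="ij")  # [rows, cols]
    centers = np.stack([cy.ravel(), cx.ravel()], axis=1)        # row-major: index = r * cols + c
    yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
    d2 = (yy[..., None] - centers[:, 0]) ** 2 + (xx[..., None] - centers[:, 1]) ** 2
    return np.argmin(d2, axis=-1)


# Hypothetical tile centers for a 3 x 2 grid on a 10 x 8 image.
y_centers = np.array([2.0, 5.0, 8.0], dtype=np.float32)
x_centers = np.array([2.5, 5.5], dtype=np.float32)
sep = owner_separable(10, 8, y_centers, x_centers)
brute = owner_brute_force(10, 8, y_centers, x_centers)
print(np.array_equal(sep, brute))  # True
```

Ties (column 4 is equidistant from 2.5 and 5.5) resolve to the lower index in both versions, because `np.argmin` returns the first minimum and the row-major flattening puts the lower column first.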
```python
class TiledInfer:
    """Tiled inference runner for the StitchFusion PT model on large images."""

    def __init__(self, model_pt: str, model_meta: str,
                 device: str = 'cuda:0', threshold: float = None):
        with open(model_meta) as f:
            meta = json.load(f)
        self.modals = meta['modals']
        self.classes = meta['classes']
        self.palette = meta['palette']
        self.image_size = meta['image_size']
        # Per-class thresholds from the meta file; a single `threshold` argument overrides all of them.
        thr = meta['thresholds']
        if threshold is not None:
            thr = [float(threshold)] * len(self.classes)
        if device.startswith('cuda') and not torch.cuda.is_available():
            logger.warning(f"{device} unavailable, falling back to cpu")
            device = 'cpu'
        self.device = torch.device(device)
        self.model = torch.jit.load(model_pt, map_location=self.device)
        self.model.eval()
        logger.info(f"[StitchFusion] model loaded: {model_pt}, device={self.device}")
        # Shape [C, 1, 1] so it broadcasts against [C, h, w] probability maps.
        self.thresholds = torch.tensor(thr, device=self.device).view(-1, 1, 1)
        # Main modality ('img'): scale to [0, 1], then ImageNet mean/std normalization.
        self.tf_main = T.Compose([
            T.Lambda(lambda x: x / 255),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            T.Lambda(lambda x: x.unsqueeze(0)),
        ])
        # Auxiliary modalities: scale to [0, 1] only.
        self.tf_aux = T.Compose([
            T.Lambda(lambda x: x / 255),
            T.Lambda(lambda x: x.unsqueeze(0)),
        ])
```
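
`__init__` reads five keys from `model_meta.json`. The dictionary below is a hypothetical example of that file: only the key names come from the code, and every value (including the second modality name `aux`) is made up. `resolve_thresholds` is an illustrative copy of the threshold-override logic.

```python
import json

# Hypothetical model_meta.json contents; only the key names are taken from TiledInfer.__init__.
meta_text = json.dumps({
    "modals": ["img", "aux"],      # model input order; "img" gets ImageNet normalization, "aux" is hypothetical
    "classes": ["serious_scratch", "scratch", "slight_scratch"],
    "palette": [[255, 0, 0], [255, 128, 0], [255, 255, 0]],
    "image_size": [1024, 1024],    # image_size[0] is the target short side in _preprocess
    "thresholds": [0.5, 0.45, 0.4],  # one sigmoid threshold per class
})


def resolve_thresholds(meta: dict, threshold=None) -> list:
    """Mirror of __init__: a single `threshold` argument replaces every per-class value."""
    thr = meta["thresholds"]
    if threshold is not None:
        thr = [float(threshold)] * len(meta["classes"])
    return thr


meta = json.loads(meta_text)
print(resolve_thresholds(meta))       # [0.5, 0.45, 0.4]
print(resolve_thresholds(meta, 0.3))  # [0.3, 0.3, 0.3]
```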
```python
    def _preprocess(self, img: torch.Tensor, is_main: bool) -> torch.Tensor:
        """[3,H,W] uint8 -> [1,3,nH,nW] float on self.device: the short side is scaled to
        image_size[0], then both sides are rounded up to a multiple of 32."""
        H, W = img.shape[1:]
        scale = self.image_size[0] / min(H, W)
        nH = int(math.ceil(round(H * scale) / 32)) * 32
        nW = int(math.ceil(round(W * scale) / 32)) * 32
        img = T.Resize((nH, nW))(img)
        tf = self.tf_main if is_main else self.tf_aux
        return tf(img.float()).to(self.device)
```
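
To get a feel for the model input size: the short side is scaled to `image_size[0]` and both sides are then rounded up to a multiple of 32, so the aspect ratio is only approximately kept and the short side can end up larger than `image_size[0]`. A copy of that arithmetic, assuming `image_size[0] = 1024` and a 1154 x 2222 tile (what `_axis_slices` gives for a hypothetical 3000 x 4000 image with the default 3 x 2 grid and 0.2 overlap):

```python
import math


def model_input_size(h: int, w: int, short_side: int) -> tuple:
    """Same arithmetic as _preprocess: scale the short side, then round each side up to x32."""
    scale = short_side / min(h, w)
    nh = int(math.ceil(round(h * scale) / 32)) * 32
    nw = int(math.ceil(round(w * scale) / 32)) * 32
    return nh, nw


print(model_input_size(1154, 2222, 1024))  # (1024, 1984)
print(model_input_size(500, 700, 1000))    # (1024, 1408): a short side of 1000 is rounded up to 1024
```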
```python
    @torch.inference_mode()
    def _forward_tile(self, tile_modals: dict) -> Tuple[np.ndarray, np.ndarray]:
        """{modal: [3,h,w] uint8} -> (mask [C,h,w] bool, prob [C,h,w] float32)"""
        h, w = next(iter(tile_modals.values())).shape[1:]
        # Model inputs follow the modality order in meta['modals']; 'img' is the main modality.
        inputs = [self._preprocess(tile_modals[m], m == 'img') for m in self.modals]
        logits = self.model(inputs)
        probs = logits.sigmoid().squeeze(0)
        # Resize probabilities back to the tile's own resolution before thresholding.
        probs = F.interpolate(probs.unsqueeze(0), size=(h, w),
                              mode='bilinear', align_corners=False).squeeze(0)
        mask = (probs > self.thresholds).cpu().numpy()
        return mask, probs.cpu().numpy().astype(np.float32)
```
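
The per-class thresholds work through broadcasting: a `[C, 1, 1]` tensor compared with `[C, h, w]` probabilities applies one threshold per class map. A toy run with a hypothetical 3-class, 4 x 4 model output resized to a 6 x 8 tile:

```python
import torch
from torch.nn import functional as F

# Hypothetical model output at model resolution: batch 1, C = 3 classes, 4 x 4 pixels.
logits = torch.zeros(1, 3, 4, 4)
logits[0, 0] = 2.0   # sigmoid(2)  ~ 0.88
logits[0, 1] = 0.0   # sigmoid(0)  = 0.5
logits[0, 2] = -2.0  # sigmoid(-2) ~ 0.12
thresholds = torch.tensor([0.5, 0.6, 0.1]).view(-1, 1, 1)  # per class, shape [C, 1, 1]

probs = logits.sigmoid().squeeze(0)                         # [3, 4, 4]
probs = F.interpolate(probs.unsqueeze(0), size=(6, 8),      # back to the tile size h x w
                      mode="bilinear", align_corners=False).squeeze(0)
mask = (probs > thresholds).numpy()                         # [3, 6, 8] bool
print(mask.shape, mask[0].all(), mask[1].any(), mask[2].all())  # (3, 6, 8) True False True
```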
```python
    def _filter_edge_noise(self, mask: np.ndarray) -> np.ndarray:
        """Per class, drop small fragments (area < EDGE_IGNORE_AREA) that touch the tile border."""
        out = np.zeros_like(mask, dtype=bool)
        for c in range(mask.shape[0]):
            m = mask[c].astype(np.uint8)
            if not m.any():
                continue
            num, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
            for k in range(1, num):  # label 0 is the background
                area = int(stats[k, cv2.CC_STAT_AREA])
                comp = labels == k
                on_edge = (comp[:EDGE_MARGIN, :].any() or comp[-EDGE_MARGIN:, :].any() or
                           comp[:, :EDGE_MARGIN].any() or comp[:, -EDGE_MARGIN:].any())
                if area < EDGE_IGNORE_AREA and on_edge:
                    continue
                out[c][comp] = True
        return out
```
```python
    def _dedup_scratch(self, canvas: np.ndarray) -> np.ndarray:
        """Cross-severity dedup: when one scratch is detected by several scratch classes, keep
        only the highest-priority detection. Modifies `canvas` in place."""
        name2idx = {n: i for i, n in enumerate(self.classes)}
        order = [name2idx[n] for n in SCRATCH_PRIORITY if n in name2idx]
        if len(order) < 2:
            return canvas
        claimed = np.zeros(canvas.shape[1:], dtype=bool)  # pixels kept by higher-priority classes
        for ci in order:
            m = canvas[ci]
            if not m.any():
                continue
            num, labels = cv2.connectedComponents(m.astype(np.uint8), connectivity=8)
            keep = np.zeros_like(m)
            for k in range(1, num):
                comp = labels == k
                area = int(comp.sum())
                inter = int((comp & claimed).sum())
                # Drop the component if most of it was already reported at a higher severity.
                if area > 0 and inter / area >= SCRATCH_DEDUP_OVERLAP:
                    continue
                keep |= comp
            canvas[ci] = keep
            claimed |= keep
        return canvas
```
```python
    def _extract_defects(self, canvas: np.ndarray, prob_canvas: np.ndarray = None) -> list:
        """Extract one defect record per connected component (area >= MIN_DEFECT_AREA)
        from a [C,H,W] bool canvas."""
        defects, did = [], 0
        for c in range(canvas.shape[0]):
            m = canvas[c].astype(np.uint8)
            if not m.any():
                continue
            num, labels, stats, centroids = cv2.connectedComponentsWithStats(m, connectivity=8)
            for k in range(1, num):
                area = int(stats[k, cv2.CC_STAT_AREA])
                if area < MIN_DEFECT_AREA:
                    continue
                x = int(stats[k, cv2.CC_STAT_LEFT])
                y = int(stats[k, cv2.CC_STAT_TOP])
                bw = int(stats[k, cv2.CC_STAT_WIDTH])
                bh = int(stats[k, cv2.CC_STAT_HEIGHT])
                cx, cy = centroids[k]
                comp = labels == k
                # Mean probability over the component's pixels, if probabilities were provided.
                prob = round(float(prob_canvas[c][comp].mean()), 4) if prob_canvas is not None else None
                # Trace polygons on the bbox crop only, then shift back to full-image coordinates.
                roi = comp[y:y + bh, x:x + bw].astype(np.uint8)
                polygons = _mask_to_polygons(roi, x, y)
                defects.append({
                    "id": did,
                    "class_index": int(c),
                    "class_name": self.classes[c],
                    "area": area,
                    "prob": prob,
                    "bbox": [x, y, x + bw, y + bh],  # x1, y1, x2, y2 with exclusive x2/y2
                    "centroid": [round(float(cx), 2), round(float(cy), 2)],
                    "polygons": polygons,
                })
                did += 1
        return defects
```
```python
    def infer(self, modal_imgs: Dict[str, torch.Tensor],
              tile_rows: int = TILE_ROWS, tile_cols: int = TILE_COLS,
              overlap: float = OVERLAP) -> Tuple[np.ndarray, list]:
        """
        Tiled inference on a large image.
        In overlap regions only the prediction of the tile whose center is nearest to the pixel
        is used; OR-merging several tiles there would thicken thin lines into blobs.
        """
        t_all = time.perf_counter()
        # Crop all modalities to their common size.
        h = min(t.shape[1] for t in modal_imgs.values())
        w = min(t.shape[2] for t in modal_imgs.values())
        modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
        slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
        logger.info(f"[StitchFusion] input image {h}x{w}, tiles {tile_rows}x{tile_cols}={len(slices)}, overlap={overlap}")
        t0 = time.perf_counter()
        owner = _build_owner_map(h, w, tile_rows, tile_cols, overlap)
        logger.info(f"[StitchFusion] Voronoi owner map built in {time.perf_counter() - t0:.2f}s")
        canvas = np.zeros((len(self.classes), h, w), dtype=bool)
        prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
        for idx, (tname, ys, xs) in enumerate(slices):
            t_tile = time.perf_counter()
            logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} y={ys.start}-{ys.stop} x={xs.start}-{xs.stop} inference started")
            tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
            mask, prob = self._forward_tile(tile)
            mask = self._filter_edge_noise(mask)
            # idx follows the same row-major order as the owner map (row * cols + col).
            own = (owner[ys, xs] == idx)
            sel = own[np.newaxis, :, :]
            canvas[:, ys, xs] |= mask & sel
            np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
            logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} done in {time.perf_counter() - t_tile:.2f}s")
        t0 = time.perf_counter()
        canvas = self._dedup_scratch(canvas)
        logger.info(f"[StitchFusion] cross-severity scratch dedup done in {time.perf_counter() - t0:.2f}s")
        t0 = time.perf_counter()
        defects = self._extract_defects(canvas, prob_canvas)
        logger.info(f"[StitchFusion] connected components extracted, defects={len(defects)}, took {time.perf_counter() - t0:.2f}s")
        logger.info(f"[StitchFusion] infer total time {time.perf_counter() - t_all:.2f}s")
        return canvas, defects
```
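
The docstring's point about OR-merging can be reproduced on a toy canvas. Below, two hypothetical tiles overlap in columns 4 to 6 and, as an assumed worst case, predict the same thin line one row apart. The OR merge makes the line two pixels thick in the overlap, while keeping only each column's owner (the same `owner[ys, xs] == idx` selection `infer` uses) keeps it one pixel thick.

```python
import numpy as np

h, w = 10, 10
# Two hypothetical tiles side by side, overlapping in columns 4..6.
tiles = [(slice(0, 10), slice(0, 7)), (slice(0, 10), slice(4, 10))]
x_centers = np.array([(xs.start + xs.stop) / 2.0 for _, xs in tiles])  # [3.5, 7.0]
col_owner = np.argmin(np.abs(np.arange(w)[:, None] - x_centers[None, :]), axis=1)
owner = np.broadcast_to(col_owner[None, :], (h, w))  # same owner in every row

# Hypothetical predictions of one thin horizontal line, one row apart between the tiles.
pred_a = np.zeros((10, 7), dtype=bool)
pred_a[5, :] = True
pred_b = np.zeros((10, 6), dtype=bool)
pred_b[6, :] = True

or_canvas = np.zeros((h, w), dtype=bool)
own_canvas = np.zeros((h, w), dtype=bool)
for idx, ((ys, xs), pred) in enumerate(zip(tiles, [pred_a, pred_b])):
    or_canvas[ys, xs] |= pred                            # naive OR merge
    own_canvas[ys, xs] |= pred & (owner[ys, xs] == idx)  # owner-based merge, as in infer()

print(or_canvas.sum(axis=0))   # [1 1 1 1 2 2 2 1 1 1]: 2 px thick in the overlap
print(own_canvas.sum(axis=0))  # [1 1 1 1 1 1 1 1 1 1]
```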
```python
    def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
        """Build a labelme-style JSON structure: one mask entry per polygon (polygons with
        fewer than 3 points are skipped)."""
        return {
            "image": stem,
            "image_size": [h, w],
            "masks": [
                {"label": d["class_name"], "prob": d.get("prob"), "points": poly}
                for d in defects
                for poly in d.get("polygons", [])
                if len(poly) >= 3
            ],
        }
```
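
The output is labelme-style rather than labelme's own annotation file, which stores polygons under a `shapes` list. The sketch below feeds a hypothetical defect list (all values made up) through a copy of `make_json` without `self`, showing that each polygon becomes its own mask entry and degenerate polygons are dropped:

```python
import json

# Hypothetical defect list in the shape produced by _extract_defects (only the keys used here).
defects = [
    {"class_name": "scratch", "prob": 0.91,
     "polygons": [[[10, 10], [40, 10], [40, 12], [10, 12]]]},
    {"class_name": "slight_scratch", "prob": 0.63,
     "polygons": [[[5, 5], [6, 5]],                # fewer than 3 points: skipped
                  [[50, 50], [60, 50], [60, 58]]]},
]


def make_json(stem: str, h: int, w: int, defects: list) -> dict:
    """Copy of TiledInfer.make_json without `self`."""
    return {
        "image": stem,
        "image_size": [h, w],
        "masks": [
            {"label": d["class_name"], "prob": d.get("prob"), "points": poly}
            for d in defects
            for poly in d.get("polygons", [])
            if len(poly) >= 3
        ],
    }


out = make_json("sample_001", 3000, 4000, defects)
print(json.dumps(out["masks"][1]))
# {"label": "slight_scratch", "prob": 0.63, "points": [[50, 50], [60, 50], [60, 58]]}
```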