tiled_infer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. """
  2. StitchFusion PT 模型大图切片推理。
  3. 流程: 大图 -> 切 TILE_ROWS x TILE_COLS 块(带 overlap) -> 每块推理 -> Voronoi 归属合并 ->
  4. 连通域提取 -> labelme 风格 JSON。
  5. 仅依赖 model_traced.pt + model_meta.json,不依赖项目其他工具模块。
  6. """
  7. import json
  8. import math
  9. import time
  10. from typing import Dict, List, Tuple
  11. import cv2
  12. import numpy as np
  13. import torch
  14. from torch.nn import functional as F
  15. from torchvision import transforms as T
  16. from app.core.logger import get_logger
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# ---- Tiling parameters (defaults for TiledInfer.infer) ----
TILE_ROWS = 6
TILE_COLS = 4
OVERLAP = 0.2

# ---- Filtering parameters ----
MIN_DEFECT_AREA = 64  # minimum connected-component area, measured in full-image coordinates
EDGE_IGNORE_AREA = 128  # fragments that touch a tile border and have an area below this are ignored
EDGE_MARGIN = 4  # width (in pixels) of the border band used for the "touches the border" test
MIN_HOLE_AREA = 16  # holes with area >= this are kept, so regions enclosed by crossing thin scratches are not filled solid

# The three scratch severities, ordered from highest to lowest priority; when the
# same scratch is detected under several severities only the highest one is kept.
SCRATCH_PRIORITY = ['serious_scratch', 'scratch', 'slight_scratch']
# A scratch component is dropped when at least this fraction of its area is
# already claimed by a higher-priority severity (see TiledInfer._dedup_scratch).
SCRATCH_DEDUP_OVERLAP = 0.8
  30. def _stitch_holes(outer: np.ndarray, holes: list) -> np.ndarray:
  31. """把孔洞用零宽双向桥并入外轮廓,得到 even-odd 填充时仍保留孔洞的单环。"""
  32. ring = outer
  33. for hole in holes:
  34. hc = hole.mean(axis=0)
  35. pi = int(np.argmin(((ring - hc) ** 2).sum(1)))
  36. qi = int(np.argmin(((hole - ring[pi]) ** 2).sum(1)))
  37. hole_loop = np.concatenate([hole[qi:], hole[:qi], hole[qi:qi + 1]], axis=0)
  38. ring = np.concatenate(
  39. [ring[:pi + 1], hole_loop, ring[pi:pi + 1], ring[pi + 1:]], axis=0
  40. )
  41. return ring
  42. def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
  43. """单连通域二值 ROI -> labelme 多边形列表(保留孔洞), 坐标偏移到整图。"""
  44. contours, hierarchy = cv2.findContours(roi, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
  45. if hierarchy is None:
  46. return []
  47. hierarchy = hierarchy[0]
  48. polygons = []
  49. for ci, hi in enumerate(hierarchy):
  50. if hi[3] != -1:
  51. continue
  52. outer = contours[ci].reshape(-1, 2)
  53. if outer.shape[0] < 3:
  54. continue
  55. holes = []
  56. child = hi[2]
  57. while child != -1:
  58. hole = contours[child].reshape(-1, 2)
  59. if hole.shape[0] >= 3 and cv2.contourArea(contours[child]) >= MIN_HOLE_AREA:
  60. holes.append(hole)
  61. child = hierarchy[child][0]
  62. ring = _stitch_holes(outer, holes) if holes else outer
  63. polygons.append([[int(px + ox), int(py + oy)] for px, py in ring])
  64. return polygons
  65. def _axis_slices(length: int, count: int, overlap: float) -> List[slice]:
  66. """单个轴上的切片划分。"""
  67. if count == 1:
  68. return [slice(0, length)]
  69. patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
  70. stride = max(1, int(round(patch * (1.0 - overlap))))
  71. slices = []
  72. for i in range(count):
  73. start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
  74. stop = length if i == count - 1 else min(length, start + patch)
  75. slices.append(slice(start, stop))
  76. return slices
  77. def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
  78. """rows x cols 个切片, 相邻块保持 overlap 比例重叠。"""
  79. y_slices = _axis_slices(h, rows, overlap)
  80. x_slices = _axis_slices(w, cols, overlap)
  81. return [
  82. (f"r{r}c{c}", ys, xs)
  83. for r, ys in enumerate(y_slices)
  84. for c, xs in enumerate(x_slices)
  85. ]
  86. def _build_owner_map(h: int, w: int, rows: int, cols: int, overlap: float) -> np.ndarray:
  87. """计算每个像素归属的 tile 索引 (Voronoi 划分, 基于 tile 中心)。
  88. 由于 tile 网格是规则的, 行/列中心独立于另一轴, Voronoi 划分可分离为两个 1D
  89. 问题, 复杂度从 O(H*W*tiles) 降到 O((H+W)*tiles)。
  90. """
  91. y_slices = _axis_slices(h, rows, overlap)
  92. x_slices = _axis_slices(w, cols, overlap)
  93. y_centers = np.array([(s.start + s.stop) / 2.0 for s in y_slices], dtype=np.float32)
  94. x_centers = np.array([(s.start + s.stop) / 2.0 for s in x_slices], dtype=np.float32)
  95. ys = np.arange(h, dtype=np.float32)
  96. xs = np.arange(w, dtype=np.float32)
  97. y_owner = np.argmin(np.abs(ys[:, None] - y_centers[None, :]), axis=1).astype(np.int32)
  98. x_owner = np.argmin(np.abs(xs[:, None] - x_centers[None, :]), axis=1).astype(np.int32)
  99. owner = y_owner[:, None] * cols + x_owner[None, :]
  100. return owner.astype(np.int32)
  101. class TiledInfer:
  102. """大图切片 PT 模型推理器。"""
  103. def __init__(self, model_pt: str, model_meta: str,
  104. device: str = 'cuda:0', threshold: float = None):
  105. with open(model_meta) as f:
  106. meta = json.load(f)
  107. self.modals = meta['modals']
  108. self.classes = meta['classes']
  109. self.palette = meta['palette']
  110. self.image_size = meta['image_size']
  111. thr = meta['thresholds']
  112. if threshold is not None:
  113. thr = [float(threshold)] * len(self.classes)
  114. if device.startswith('cuda') and not torch.cuda.is_available():
  115. logger.warning(f"{device} 不可用, 回退到 cpu")
  116. device = 'cpu'
  117. self.device = torch.device(device)
  118. self.model = torch.jit.load(model_pt, map_location=self.device)
  119. self.model.eval()
  120. logger.info(f"[StitchFusion] 模型已加载: {model_pt}, device={self.device}")
  121. self.thresholds = torch.tensor(thr, device=self.device).view(-1, 1, 1)
  122. self.tf_main = T.Compose([
  123. T.Lambda(lambda x: x / 255),
  124. T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  125. T.Lambda(lambda x: x.unsqueeze(0)),
  126. ])
  127. self.tf_aux = T.Compose([
  128. T.Lambda(lambda x: x / 255),
  129. T.Lambda(lambda x: x.unsqueeze(0)),
  130. ])
  131. def _preprocess(self, img: torch.Tensor, is_main: bool) -> torch.Tensor:
  132. H, W = img.shape[1:]
  133. scale = self.image_size[0] / min(H, W)
  134. nH = int(math.ceil(round(H * scale) / 32)) * 32
  135. nW = int(math.ceil(round(W * scale) / 32)) * 32
  136. img = T.Resize((nH, nW))(img)
  137. tf = self.tf_main if is_main else self.tf_aux
  138. return tf(img.float()).to(self.device)
  139. @torch.inference_mode()
  140. def _forward_tile(self, tile_modals: dict) -> Tuple[np.ndarray, np.ndarray]:
  141. """{modal: [3,h,w] uint8} -> (mask [C,h,w] bool, prob [C,h,w] float32)"""
  142. h, w = next(iter(tile_modals.values())).shape[1:]
  143. inputs = [self._preprocess(tile_modals[m], m == 'img') for m in self.modals]
  144. logits = self.model(inputs)
  145. probs = logits.sigmoid().squeeze(0)
  146. probs = F.interpolate(probs.unsqueeze(0), size=(h, w),
  147. mode='bilinear', align_corners=False).squeeze(0)
  148. mask = (probs > self.thresholds).cpu().numpy()
  149. return mask, probs.cpu().numpy().astype(np.float32)
  150. def _filter_edge_noise(self, mask: np.ndarray) -> np.ndarray:
  151. """去除 tile 内贴边的小碎片。"""
  152. out = np.zeros_like(mask, dtype=bool)
  153. for c in range(mask.shape[0]):
  154. m = mask[c].astype(np.uint8)
  155. if not m.any():
  156. continue
  157. num, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
  158. for k in range(1, num):
  159. area = int(stats[k, cv2.CC_STAT_AREA])
  160. comp = labels == k
  161. on_edge = (comp[:EDGE_MARGIN, :].any() or comp[-EDGE_MARGIN:, :].any() or
  162. comp[:, :EDGE_MARGIN].any() or comp[:, -EDGE_MARGIN:].any())
  163. if area < EDGE_IGNORE_AREA and on_edge:
  164. continue
  165. out[c][comp] = True
  166. return out
  167. def _dedup_scratch(self, canvas: np.ndarray) -> np.ndarray:
  168. """跨严重度去重: 同一条划痕被多个 scratch 类检出时只保留最高优先级。"""
  169. name2idx = {n: i for i, n in enumerate(self.classes)}
  170. order = [name2idx[n] for n in SCRATCH_PRIORITY if n in name2idx]
  171. if len(order) < 2:
  172. return canvas
  173. claimed = np.zeros(canvas.shape[1:], dtype=bool)
  174. for ci in order:
  175. m = canvas[ci]
  176. if not m.any():
  177. continue
  178. num, labels = cv2.connectedComponents(m.astype(np.uint8), connectivity=8)
  179. keep = np.zeros_like(m)
  180. for k in range(1, num):
  181. comp = labels == k
  182. area = int(comp.sum())
  183. inter = int((comp & claimed).sum())
  184. if area > 0 and inter / area >= SCRATCH_DEDUP_OVERLAP:
  185. continue
  186. keep |= comp
  187. canvas[ci] = keep
  188. claimed |= keep
  189. return canvas
  190. def _extract_defects(self, canvas: np.ndarray, prob_canvas: np.ndarray = None) -> list:
  191. """从 [C,H,W] bool canvas 提取每个连通域的缺陷信息。"""
  192. defects, did = [], 0
  193. for c in range(canvas.shape[0]):
  194. m = canvas[c].astype(np.uint8)
  195. if not m.any():
  196. continue
  197. num, labels, stats, centroids = cv2.connectedComponentsWithStats(m, connectivity=8)
  198. for k in range(1, num):
  199. area = int(stats[k, cv2.CC_STAT_AREA])
  200. if area < MIN_DEFECT_AREA:
  201. continue
  202. x = int(stats[k, cv2.CC_STAT_LEFT])
  203. y = int(stats[k, cv2.CC_STAT_TOP])
  204. bw = int(stats[k, cv2.CC_STAT_WIDTH])
  205. bh = int(stats[k, cv2.CC_STAT_HEIGHT])
  206. cx, cy = centroids[k]
  207. comp = labels == k
  208. prob = round(float(prob_canvas[c][comp].mean()), 4) if prob_canvas is not None else None
  209. roi = comp[y:y + bh, x:x + bw].astype(np.uint8)
  210. polygons = _mask_to_polygons(roi, x, y)
  211. defects.append({
  212. "id": did,
  213. "class_index": int(c),
  214. "class_name": self.classes[c],
  215. "area": area,
  216. "prob": prob,
  217. "bbox": [x, y, x + bw, y + bh],
  218. "centroid": [round(float(cx), 2), round(float(cy), 2)],
  219. "polygons": polygons,
  220. })
  221. did += 1
  222. return defects
  223. def infer(self, modal_imgs: Dict[str, torch.Tensor],
  224. tile_rows: int = TILE_ROWS, tile_cols: int = TILE_COLS,
  225. overlap: float = OVERLAP) -> Tuple[np.ndarray, list]:
  226. """
  227. 大图切片推理。
  228. overlap 区域只采用离像素最近的 tile 中心的预测,避免多 tile OR 合并导致细线膨胀成块。
  229. """
  230. t_all = time.perf_counter()
  231. h = min(t.shape[1] for t in modal_imgs.values())
  232. w = min(t.shape[2] for t in modal_imgs.values())
  233. modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
  234. slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
  235. logger.info(f"[StitchFusion] 输入大图 {h}x{w}, 切片 {tile_rows}x{tile_cols}={len(slices)} 块, overlap={overlap}")
  236. t0 = time.perf_counter()
  237. owner = _build_owner_map(h, w, tile_rows, tile_cols, overlap)
  238. logger.info(f"[StitchFusion] Voronoi 归属图构建完成, 耗时 {time.perf_counter() - t0:.2f}s")
  239. canvas = np.zeros((len(self.classes), h, w), dtype=bool)
  240. prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
  241. for idx, (tname, ys, xs) in enumerate(slices):
  242. t_tile = time.perf_counter()
  243. logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} y={ys.start}-{ys.stop} x={xs.start}-{xs.stop} 推理开始")
  244. tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
  245. mask, prob = self._forward_tile(tile)
  246. mask = self._filter_edge_noise(mask)
  247. own = (owner[ys, xs] == idx)
  248. sel = own[np.newaxis, :, :]
  249. canvas[:, ys, xs] |= mask & sel
  250. np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
  251. logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} 完成, 耗时 {time.perf_counter() - t_tile:.2f}s")
  252. t0 = time.perf_counter()
  253. canvas = self._dedup_scratch(canvas)
  254. logger.info(f"[StitchFusion] 划痕跨严重度去重完成, 耗时 {time.perf_counter() - t0:.2f}s")
  255. t0 = time.perf_counter()
  256. defects = self._extract_defects(canvas, prob_canvas)
  257. logger.info(f"[StitchFusion] 连通域提取完成, 缺陷数={len(defects)}, 耗时 {time.perf_counter() - t0:.2f}s")
  258. logger.info(f"[StitchFusion] infer 总耗时 {time.perf_counter() - t_all:.2f}s")
  259. return canvas, defects
  260. def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
  261. """生成符合 labelme 格式的 JSON 结构。"""
  262. return {
  263. "image": stem,
  264. "image_size": [h, w],
  265. "masks": [
  266. {"label": d["class_name"], "prob": d.get("prob"), "points": poly}
  267. for d in defects
  268. for poly in d.get("polygons", [])
  269. if len(poly) >= 3
  270. ],
  271. }