tiled_infer.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
"""
Tiled inference for a StitchFusion TorchScript (PT) model on large images.

Pipeline: large image -> split into TILE_ROWS x TILE_COLS tiles (with overlap)
-> run the model on each tile -> merge tiles by Voronoi ownership (each pixel
takes the prediction of the tile whose centre is nearest) -> connected-component
extraction -> labelme-like JSON (a ``masks`` list of label/prob/points entries).

Model-wise it needs only model_traced.pt + model_meta.json; apart from the
project logger (app.core.logger) it does not use other project utility modules.
"""
  7. import json
  8. import math
  9. from typing import Dict, List, Tuple
  10. import cv2
  11. import numpy as np
  12. import torch
  13. from torch.nn import functional as F
  14. from torchvision import transforms as T
  15. from app.core.logger import get_logger
  16. logger = get_logger(__name__)
  17. # ---- 切片参数 ----
  18. TILE_ROWS = 6
  19. TILE_COLS = 4
  20. OVERLAP = 0.2
  21. # ---- 过滤参数 ----
  22. MIN_DEFECT_AREA = 64 # 整图坐标下连通域最小面积
  23. EDGE_IGNORE_AREA = 128 # tile 内贴边且面积小于此值的碎片忽略
  24. EDGE_MARGIN = 4 # 贴边判断的边缘宽度(像素)
  25. MIN_HOLE_AREA = 16 # 保留面积 >= 此值的孔洞,避免细长划痕交叉围成的区域被填成块
  26. # 三种划痕严重度按优先级从高到低;同一条划痕被多个严重度同时检出时只保留最高优先级
  27. SCRATCH_PRIORITY = ['serious_scratch', 'scratch', 'slight_scratch']
  28. SCRATCH_DEDUP_OVERLAP = 0.8
  29. def _stitch_holes(outer: np.ndarray, holes: list) -> np.ndarray:
  30. """把孔洞用零宽双向桥并入外轮廓,得到 even-odd 填充时仍保留孔洞的单环。"""
  31. ring = outer
  32. for hole in holes:
  33. hc = hole.mean(axis=0)
  34. pi = int(np.argmin(((ring - hc) ** 2).sum(1)))
  35. qi = int(np.argmin(((hole - ring[pi]) ** 2).sum(1)))
  36. hole_loop = np.concatenate([hole[qi:], hole[:qi], hole[qi:qi + 1]], axis=0)
  37. ring = np.concatenate(
  38. [ring[:pi + 1], hole_loop, ring[pi:pi + 1], ring[pi + 1:]], axis=0
  39. )
  40. return ring
  41. def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
  42. """单连通域二值 ROI -> labelme 多边形列表(保留孔洞), 坐标偏移到整图。"""
  43. contours, hierarchy = cv2.findContours(roi, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
  44. if hierarchy is None:
  45. return []
  46. hierarchy = hierarchy[0]
  47. polygons = []
  48. for ci, hi in enumerate(hierarchy):
  49. if hi[3] != -1:
  50. continue
  51. outer = contours[ci].reshape(-1, 2)
  52. if outer.shape[0] < 3:
  53. continue
  54. holes = []
  55. child = hi[2]
  56. while child != -1:
  57. hole = contours[child].reshape(-1, 2)
  58. if hole.shape[0] >= 3 and cv2.contourArea(contours[child]) >= MIN_HOLE_AREA:
  59. holes.append(hole)
  60. child = hierarchy[child][0]
  61. ring = _stitch_holes(outer, holes) if holes else outer
  62. polygons.append([[int(px + ox), int(py + oy)] for px, py in ring])
  63. return polygons
  64. def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
  65. """rows x cols 个切片, 相邻块保持 overlap 比例重叠。"""
  66. def _axis(length, count):
  67. if count == 1:
  68. return [slice(0, length)]
  69. patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
  70. stride = max(1, int(round(patch * (1.0 - overlap))))
  71. slices = []
  72. for i in range(count):
  73. start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
  74. stop = length if i == count - 1 else min(length, start + patch)
  75. slices.append(slice(start, stop))
  76. return slices
  77. return [
  78. (f"r{r}c{c}", ys, xs)
  79. for r, ys in enumerate(_axis(h, rows))
  80. for c, xs in enumerate(_axis(w, cols))
  81. ]
  82. class TiledInfer:
  83. """大图切片 PT 模型推理器。"""
  84. def __init__(self, model_pt: str, model_meta: str,
  85. device: str = 'cuda:0', threshold: float = None):
  86. with open(model_meta) as f:
  87. meta = json.load(f)
  88. self.modals = meta['modals']
  89. self.classes = meta['classes']
  90. self.palette = meta['palette']
  91. self.image_size = meta['image_size']
  92. thr = meta['thresholds']
  93. if threshold is not None:
  94. thr = [float(threshold)] * len(self.classes)
  95. if device.startswith('cuda') and not torch.cuda.is_available():
  96. logger.warning(f"{device} 不可用, 回退到 cpu")
  97. device = 'cpu'
  98. self.device = torch.device(device)
  99. self.model = torch.jit.load(model_pt, map_location=self.device)
  100. self.model.eval()
  101. logger.info(f"[StitchFusion] 模型已加载: {model_pt}, device={self.device}")
  102. self.thresholds = torch.tensor(thr, device=self.device).view(-1, 1, 1)
  103. self.tf_main = T.Compose([
  104. T.Lambda(lambda x: x / 255),
  105. T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  106. T.Lambda(lambda x: x.unsqueeze(0)),
  107. ])
  108. self.tf_aux = T.Compose([
  109. T.Lambda(lambda x: x / 255),
  110. T.Lambda(lambda x: x.unsqueeze(0)),
  111. ])
  112. def _preprocess(self, img: torch.Tensor, is_main: bool) -> torch.Tensor:
  113. H, W = img.shape[1:]
  114. scale = self.image_size[0] / min(H, W)
  115. nH = int(math.ceil(round(H * scale) / 32)) * 32
  116. nW = int(math.ceil(round(W * scale) / 32)) * 32
  117. img = T.Resize((nH, nW))(img)
  118. tf = self.tf_main if is_main else self.tf_aux
  119. return tf(img.float()).to(self.device)
  120. @torch.inference_mode()
  121. def _forward_tile(self, tile_modals: dict) -> Tuple[np.ndarray, np.ndarray]:
  122. """{modal: [3,h,w] uint8} -> (mask [C,h,w] bool, prob [C,h,w] float32)"""
  123. h, w = next(iter(tile_modals.values())).shape[1:]
  124. inputs = [self._preprocess(tile_modals[m], m == 'img') for m in self.modals]
  125. logits = self.model(inputs)
  126. probs = logits.sigmoid().squeeze(0)
  127. probs = F.interpolate(probs.unsqueeze(0), size=(h, w),
  128. mode='bilinear', align_corners=False).squeeze(0)
  129. mask = (probs > self.thresholds).cpu().numpy()
  130. return mask, probs.cpu().numpy().astype(np.float32)
  131. def _filter_edge_noise(self, mask: np.ndarray) -> np.ndarray:
  132. """去除 tile 内贴边的小碎片。"""
  133. out = np.zeros_like(mask, dtype=bool)
  134. for c in range(mask.shape[0]):
  135. m = mask[c].astype(np.uint8)
  136. if not m.any():
  137. continue
  138. num, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)
  139. for k in range(1, num):
  140. area = int(stats[k, cv2.CC_STAT_AREA])
  141. comp = labels == k
  142. on_edge = (comp[:EDGE_MARGIN, :].any() or comp[-EDGE_MARGIN:, :].any() or
  143. comp[:, :EDGE_MARGIN].any() or comp[:, -EDGE_MARGIN:].any())
  144. if area < EDGE_IGNORE_AREA and on_edge:
  145. continue
  146. out[c][comp] = True
  147. return out
  148. def _dedup_scratch(self, canvas: np.ndarray) -> np.ndarray:
  149. """跨严重度去重: 同一条划痕被多个 scratch 类检出时只保留最高优先级。"""
  150. name2idx = {n: i for i, n in enumerate(self.classes)}
  151. order = [name2idx[n] for n in SCRATCH_PRIORITY if n in name2idx]
  152. if len(order) < 2:
  153. return canvas
  154. claimed = np.zeros(canvas.shape[1:], dtype=bool)
  155. for ci in order:
  156. m = canvas[ci]
  157. if not m.any():
  158. continue
  159. num, labels = cv2.connectedComponents(m.astype(np.uint8), connectivity=8)
  160. keep = np.zeros_like(m)
  161. for k in range(1, num):
  162. comp = labels == k
  163. area = int(comp.sum())
  164. inter = int((comp & claimed).sum())
  165. if area > 0 and inter / area >= SCRATCH_DEDUP_OVERLAP:
  166. continue
  167. keep |= comp
  168. canvas[ci] = keep
  169. claimed |= keep
  170. return canvas
  171. def _extract_defects(self, canvas: np.ndarray, prob_canvas: np.ndarray = None) -> list:
  172. """从 [C,H,W] bool canvas 提取每个连通域的缺陷信息。"""
  173. defects, did = [], 0
  174. for c in range(canvas.shape[0]):
  175. m = canvas[c].astype(np.uint8)
  176. if not m.any():
  177. continue
  178. num, labels, stats, centroids = cv2.connectedComponentsWithStats(m, connectivity=8)
  179. for k in range(1, num):
  180. area = int(stats[k, cv2.CC_STAT_AREA])
  181. if area < MIN_DEFECT_AREA:
  182. continue
  183. x = int(stats[k, cv2.CC_STAT_LEFT])
  184. y = int(stats[k, cv2.CC_STAT_TOP])
  185. bw = int(stats[k, cv2.CC_STAT_WIDTH])
  186. bh = int(stats[k, cv2.CC_STAT_HEIGHT])
  187. cx, cy = centroids[k]
  188. comp = labels == k
  189. prob = round(float(prob_canvas[c][comp].mean()), 4) if prob_canvas is not None else None
  190. roi = comp[y:y + bh, x:x + bw].astype(np.uint8)
  191. polygons = _mask_to_polygons(roi, x, y)
  192. defects.append({
  193. "id": did,
  194. "class_index": int(c),
  195. "class_name": self.classes[c],
  196. "area": area,
  197. "prob": prob,
  198. "bbox": [x, y, x + bw, y + bh],
  199. "centroid": [round(float(cx), 2), round(float(cy), 2)],
  200. "polygons": polygons,
  201. })
  202. did += 1
  203. return defects
  204. def infer(self, modal_imgs: Dict[str, torch.Tensor],
  205. tile_rows: int = TILE_ROWS, tile_cols: int = TILE_COLS,
  206. overlap: float = OVERLAP) -> Tuple[np.ndarray, list]:
  207. """
  208. 大图切片推理。
  209. overlap 区域只采用离像素最近的 tile 中心的预测,避免多 tile OR 合并导致细线膨胀成块。
  210. """
  211. h = min(t.shape[1] for t in modal_imgs.values())
  212. w = min(t.shape[2] for t in modal_imgs.values())
  213. modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
  214. slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
  215. yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
  216. min_dist2 = np.full((h, w), np.inf, dtype=np.float32)
  217. owner = np.full((h, w), -1, dtype=np.int32)
  218. for idx, (_, ys, xs) in enumerate(slices):
  219. cy = (ys.start + ys.stop) / 2.0
  220. cx = (xs.start + xs.stop) / 2.0
  221. d2 = (yy - cy) ** 2 + (xx - cx) ** 2
  222. upd = d2 < min_dist2
  223. min_dist2[upd] = d2[upd]
  224. owner[upd] = idx
  225. canvas = np.zeros((len(self.classes), h, w), dtype=bool)
  226. prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
  227. for idx, (tname, ys, xs) in enumerate(slices):
  228. tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
  229. mask, prob = self._forward_tile(tile)
  230. mask = self._filter_edge_noise(mask)
  231. own = (owner[ys, xs] == idx)
  232. sel = own[np.newaxis, :, :]
  233. canvas[:, ys, xs] |= mask & sel
  234. np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
  235. logger.info(f" tile {tname}: y={ys.start}-{ys.stop}, x={xs.start}-{xs.stop}")
  236. canvas = self._dedup_scratch(canvas)
  237. defects = self._extract_defects(canvas, prob_canvas)
  238. return canvas, defects
  239. def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
  240. """生成符合 labelme 格式的 JSON 结构。"""
  241. return {
  242. "image": stem,
  243. "image_size": [h, w],
  244. "masks": [
  245. {"label": d["class_name"], "prob": d.get("prob"), "points": poly}
  246. for d in defects
  247. for poly in d.get("polygons", [])
  248. if len(poly) >= 3
  249. ],
  250. }