|
@@ -7,6 +7,7 @@ StitchFusion PT 模型大图切片推理。
|
|
|
"""
|
|
"""
|
|
|
import json
|
|
import json
|
|
|
import math
|
|
import math
|
|
|
|
|
+import time
|
|
|
from typing import Dict, List, Tuple
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
import cv2
|
|
import cv2
|
|
@@ -74,27 +75,51 @@ def _mask_to_polygons(roi: np.ndarray, ox: int, oy: int) -> list:
|
|
|
return polygons
|
|
return polygons
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _axis_slices(length: int, count: int, overlap: float) -> List[slice]:
|
|
|
|
|
+ """单个轴上的切片划分。"""
|
|
|
|
|
+ if count == 1:
|
|
|
|
|
+ return [slice(0, length)]
|
|
|
|
|
+ patch = max(1, min(int(round(length / (count - (count - 1) * overlap))), length))
|
|
|
|
|
+ stride = max(1, int(round(patch * (1.0 - overlap))))
|
|
|
|
|
+ slices = []
|
|
|
|
|
+ for i in range(count):
|
|
|
|
|
+ start = max(0, length - patch) if i == count - 1 else min(i * stride, max(0, length - patch))
|
|
|
|
|
+ stop = length if i == count - 1 else min(length, start + patch)
|
|
|
|
|
+ slices.append(slice(start, stop))
|
|
|
|
|
+ return slices
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def _grid_slices(h: int, w: int, rows: int, cols: int, overlap: float):
    """Build a rows x cols grid of tiles whose neighbours overlap by ``overlap``.

    Each axis is split independently with ``_axis_slices``. The result is a
    list of ``(name, y_slice, x_slice)`` tuples in row-major order, where
    ``name`` is ``"r{row}c{col}"``.
    """
    row_spans = _axis_slices(h, rows, overlap)
    col_spans = _axis_slices(w, cols, overlap)

    grid = []
    for r, ys in enumerate(row_spans):
        for c, xs in enumerate(col_spans):
            grid.append((f"r{r}c{c}", ys, xs))
    return grid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_owner_map(h: int, w: int, rows: int, cols: int, overlap: float) -> np.ndarray:
    """Return an (h, w) int32 map giving, per pixel, the index of its owning tile.

    Ownership is a Voronoi partition by tile centre: each pixel belongs to the
    tile whose centre is nearest. Because the tile grid is the product of
    independent row and column slices, the nearest centre can be found per
    axis and combined as ``row_index * cols + col_index``, which matches the
    row-major tile order. The per-axis search costs O((h + w) * tiles-per-axis)
    instead of comparing every pixel against every tile. Ties go to the
    lower index (``np.argmin`` returns the first minimum).
    """

    def _nearest_center(extent, axis_slices):
        # Index of the closest slice centre for every coordinate on one axis.
        centers = np.array(
            [(s.start + s.stop) / 2.0 for s in axis_slices], dtype=np.float32
        )
        coords = np.arange(extent, dtype=np.float32)
        dist = np.abs(coords[:, None] - centers[None, :])
        return np.argmin(dist, axis=1).astype(np.int32)

    row_of = _nearest_center(h, _axis_slices(h, rows, overlap))
    col_of = _nearest_center(w, _axis_slices(w, cols, overlap))

    owner = row_of[:, None] * cols + col_of[None, :]
    return owner.astype(np.int32)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
class TiledInfer:
|
|
class TiledInfer:
|
|
|
"""大图切片 PT 模型推理器。"""
|
|
"""大图切片 PT 模型推理器。"""
|
|
|
|
|
|
|
@@ -237,26 +262,24 @@ class TiledInfer:
|
|
|
大图切片推理。
|
|
大图切片推理。
|
|
|
overlap 区域只采用离像素最近的 tile 中心的预测,避免多 tile OR 合并导致细线膨胀成块。
|
|
overlap 区域只采用离像素最近的 tile 中心的预测,避免多 tile OR 合并导致细线膨胀成块。
|
|
|
"""
|
|
"""
|
|
|
|
|
+ t_all = time.perf_counter()
|
|
|
|
|
+
|
|
|
h = min(t.shape[1] for t in modal_imgs.values())
|
|
h = min(t.shape[1] for t in modal_imgs.values())
|
|
|
w = min(t.shape[2] for t in modal_imgs.values())
|
|
w = min(t.shape[2] for t in modal_imgs.values())
|
|
|
modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
|
|
modal_imgs = {m: t[:, :h, :w].contiguous() for m, t in modal_imgs.items()}
|
|
|
-
|
|
|
|
|
slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
|
|
slices = _grid_slices(h, w, tile_rows, tile_cols, overlap)
|
|
|
|
|
+ logger.info(f"[StitchFusion] 输入大图 {h}x{w}, 切片 {tile_rows}x{tile_cols}={len(slices)} 块, overlap={overlap}")
|
|
|
|
|
|
|
|
- yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
|
|
|
|
|
- min_dist2 = np.full((h, w), np.inf, dtype=np.float32)
|
|
|
|
|
- owner = np.full((h, w), -1, dtype=np.int32)
|
|
|
|
|
- for idx, (_, ys, xs) in enumerate(slices):
|
|
|
|
|
- cy = (ys.start + ys.stop) / 2.0
|
|
|
|
|
- cx = (xs.start + xs.stop) / 2.0
|
|
|
|
|
- d2 = (yy - cy) ** 2 + (xx - cx) ** 2
|
|
|
|
|
- upd = d2 < min_dist2
|
|
|
|
|
- min_dist2[upd] = d2[upd]
|
|
|
|
|
- owner[upd] = idx
|
|
|
|
|
|
|
+ t0 = time.perf_counter()
|
|
|
|
|
+ owner = _build_owner_map(h, w, tile_rows, tile_cols, overlap)
|
|
|
|
|
+ logger.info(f"[StitchFusion] Voronoi 归属图构建完成, 耗时 {time.perf_counter() - t0:.2f}s")
|
|
|
|
|
|
|
|
canvas = np.zeros((len(self.classes), h, w), dtype=bool)
|
|
canvas = np.zeros((len(self.classes), h, w), dtype=bool)
|
|
|
prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
|
|
prob_canvas = np.zeros((len(self.classes), h, w), dtype=np.float32)
|
|
|
for idx, (tname, ys, xs) in enumerate(slices):
|
|
for idx, (tname, ys, xs) in enumerate(slices):
|
|
|
|
|
+ t_tile = time.perf_counter()
|
|
|
|
|
+ logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} y={ys.start}-{ys.stop} x={xs.start}-{xs.stop} 推理开始")
|
|
|
|
|
+
|
|
|
tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
|
|
tile = {m: t[:, ys, xs].contiguous() for m, t in modal_imgs.items()}
|
|
|
mask, prob = self._forward_tile(tile)
|
|
mask, prob = self._forward_tile(tile)
|
|
|
mask = self._filter_edge_noise(mask)
|
|
mask = self._filter_edge_noise(mask)
|
|
@@ -264,10 +287,16 @@ class TiledInfer:
|
|
|
sel = own[np.newaxis, :, :]
|
|
sel = own[np.newaxis, :, :]
|
|
|
canvas[:, ys, xs] |= mask & sel
|
|
canvas[:, ys, xs] |= mask & sel
|
|
|
np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
|
|
np.copyto(prob_canvas[:, ys, xs], prob, where=sel)
|
|
|
- logger.info(f" tile {tname}: y={ys.start}-{ys.stop}, x={xs.start}-{xs.stop}")
|
|
|
|
|
|
|
+ logger.info(f"[StitchFusion] tile {idx + 1}/{len(slices)} {tname} 完成, 耗时 {time.perf_counter() - t_tile:.2f}s")
|
|
|
|
|
|
|
|
|
|
+ t0 = time.perf_counter()
|
|
|
canvas = self._dedup_scratch(canvas)
|
|
canvas = self._dedup_scratch(canvas)
|
|
|
|
|
+ logger.info(f"[StitchFusion] 划痕跨严重度去重完成, 耗时 {time.perf_counter() - t0:.2f}s")
|
|
|
|
|
+
|
|
|
|
|
+ t0 = time.perf_counter()
|
|
|
defects = self._extract_defects(canvas, prob_canvas)
|
|
defects = self._extract_defects(canvas, prob_canvas)
|
|
|
|
|
+ logger.info(f"[StitchFusion] 连通域提取完成, 缺陷数={len(defects)}, 耗时 {time.perf_counter() - t0:.2f}s")
|
|
|
|
|
+ logger.info(f"[StitchFusion] infer 总耗时 {time.perf_counter() - t_all:.2f}s")
|
|
|
return canvas, defects
|
|
return canvas, defects
|
|
|
|
|
|
|
|
def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
|
|
def make_json(self, stem: str, h: int, w: int, defects: list) -> dict:
|