stitcher_template.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import cv2
  2. from pathlib import Path
  3. import concurrent.futures
  4. import logging
  5. from tqdm import tqdm
  6. from fry_project_classes.stitch_img_template_match import ImageStitcherTemplateMatch
  7. from utils.utils import natural_sort_key
  8. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  9. def stitch_single_row(row_index, row_image_paths, stitch_params):
  10. NUM_COLS = len(row_image_paths)
  11. OUTPUT_DIR = stitch_params['OUTPUT_DIR']
  12. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS = stitch_params['ESTIMATE_OVERLAP_HORIZONTAL_PIXELS']
  13. BLEND_TYPE = stitch_params['BLEND_TYPE']
  14. LIGHT_COMPENSATION = stitch_params['LIGHT_COMPENSATION']
  15. DEBUG_MODE = stitch_params['DEBUG_MODE']
  16. current_row_image = cv2.imread(str(row_image_paths[0]))
  17. if current_row_image is None:
  18. logging.error(f"错误: 无法读取图片 {row_image_paths[0]}")
  19. return row_index, None
  20. for j in range(1, NUM_COLS):
  21. stitcher_h = ImageStitcherTemplateMatch(
  22. estimate_overlap_pixels=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  23. stitch_type="horizontal",
  24. blend_type=BLEND_TYPE,
  25. light_uniformity_compensation_enabled=LIGHT_COMPENSATION,
  26. debug=DEBUG_MODE,
  27. debug_dir=str(OUTPUT_DIR / f'debug_h_row{row_index + 1}_col{j}vs{j + 1}')
  28. )
  29. next_image = cv2.imread(str(row_image_paths[j]))
  30. if next_image is None:
  31. logging.error(f"错误: 无法读取图片 {row_image_paths[j]}")
  32. return row_index, current_row_image
  33. current_row_image = stitcher_h.stitch_main(current_row_image, next_image)
  34. return row_index, current_row_image
  35. def stitch_img(IMAGE_DIR: Path, OUTPUT_DIR: Path, NUM_COLS: int, NUM_ROWS: int,
  36. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS: int, ESTIMATE_OVERLAP_VERTICAL_PIXELS: int,
  37. BLEND_TYPE: str, LIGHT_COMPENSATION: bool,
  38. DEBUG_MODE: bool) -> Path | None:
  39. """
  40. 基于模板匹配的图像拼接函数。
  41. 成功时返回最终图像的路径,失败时返回 None。
  42. """
  43. OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
  44. logging.info("--- [模板匹配] 图像拼接开始 ---")
  45. logging.info(f"配置: {NUM_ROWS}行 x {NUM_COLS}列, 图片目录: {IMAGE_DIR}")
  46. image_paths = sorted(list(IMAGE_DIR.glob("*.jpg")), key=natural_sort_key)
  47. if len(image_paths) != NUM_COLS * NUM_ROWS:
  48. logging.error(f"错误: 找到 {len(image_paths)} 张图片, 但预期需要 {NUM_COLS * NUM_ROWS} 张。")
  49. return None
  50. # --- 阶段一:并行水平拼接每一行 ---
  51. logging.info("--- 阶段一: 并行水平拼接每一行 ---")
  52. stitch_params = {
  53. 'OUTPUT_DIR': OUTPUT_DIR,
  54. 'ESTIMATE_OVERLAP_HORIZONTAL_PIXELS': ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  55. 'BLEND_TYPE': BLEND_TYPE,
  56. 'LIGHT_COMPENSATION': LIGHT_COMPENSATION,
  57. 'DEBUG_MODE': DEBUG_MODE
  58. }
  59. stitched_rows = [None] * NUM_ROWS
  60. with concurrent.futures.ProcessPoolExecutor() as executor:
  61. futures = [
  62. executor.submit(stitch_single_row, i, image_paths[i * NUM_COLS: i * NUM_COLS + NUM_COLS], stitch_params) for
  63. i in range(NUM_ROWS)]
  64. for future in tqdm(concurrent.futures.as_completed(futures), total=NUM_ROWS, desc="[模板]处理行"):
  65. try:
  66. row_index, result_image = future.result()
  67. if result_image is not None:
  68. stitched_rows[row_index] = result_image
  69. else:
  70. logging.warning(f"第 {row_index + 1} 行拼接失败。")
  71. except Exception as exc:
  72. logging.error(f"一个行拼接任务生成了异常: {exc}")
  73. if any(row is None for row in stitched_rows):
  74. logging.error("错误: 存在拼接失败的行,无法进行垂直拼接。")
  75. return None
  76. # --- 阶段二:垂直拼接所有行 ---
  77. logging.info("--- 阶段二: 垂直拼接所有行 ---")
  78. final_image = stitched_rows[0]
  79. for i in tqdm(range(1, NUM_ROWS), desc="[模板]拼接行"):
  80. stitcher_v = ImageStitcherTemplateMatch(
  81. estimate_overlap_pixels=ESTIMATE_OVERLAP_VERTICAL_PIXELS, stitch_type="vertical",
  82. blend_type=BLEND_TYPE, light_uniformity_compensation_enabled=LIGHT_COMPENSATION,
  83. debug=DEBUG_MODE, debug_dir=str(OUTPUT_DIR / f'debug_v_row{i}vs{i + 1}')
  84. )
  85. next_row_image = stitched_rows[i]
  86. final_image = stitcher_v.stitch_main(final_image, next_row_image)
  87. # --- 保存并返回结果 ---
  88. final_output_path = OUTPUT_DIR / "final_stitched_image.jpg"
  89. cv2.imwrite(str(final_output_path), final_image)
  90. logging.info(f"--- [模板匹配] 拼接任务完成!最终图已暂存至: {final_output_path} ---")
  91. return final_output_path