stitcher_keypoint.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import cv2
  2. from pathlib import Path
  3. import concurrent.futures
  4. import logging
  5. from tqdm import tqdm
  6. from fry_project_classes.stitch_img_key_point import ImageStitcherKeyPoint
  7. from utils.utils import natural_sort_key
  8. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  9. def stitch_single_row_keypoint(row_index, row_image_paths, stitch_params):
  10. NUM_COLS = len(row_image_paths)
  11. OUTPUT_DIR = stitch_params['OUTPUT_DIR']
  12. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS = stitch_params['ESTIMATE_OVERLAP_HORIZONTAL_PIXELS']
  13. BLEND_TYPE = stitch_params['BLEND_TYPE']
  14. FeatureDetector = stitch_params['FeatureDetector']
  15. DEBUG_MODE = stitch_params['DEBUG_MODE']
  16. current_row_image = cv2.imread(str(row_image_paths[0]))
  17. if current_row_image is None:
  18. logging.error(f"错误: 无法读取图片 {row_image_paths[0]}")
  19. return row_index, None
  20. for j in range(1, NUM_COLS):
  21. stitcher_h = ImageStitcherKeyPoint(
  22. estimate_overlap_pixels=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  23. stitch_type="horizontal",
  24. blend_type=BLEND_TYPE,
  25. feature_detector=FeatureDetector,
  26. blend_ratio=0.5,
  27. debug=DEBUG_MODE,
  28. debug_dir=str(OUTPUT_DIR / f'debug_h_row{row_index + 1}_col{j}vs{j + 1}')
  29. )
  30. next_image = cv2.imread(str(row_image_paths[j]))
  31. if next_image is None:
  32. logging.error(f"错误: 无法读取图片 {row_image_paths[j]}")
  33. return row_index, current_row_image
  34. current_row_image = stitcher_h.stitch_main(current_row_image, next_image)
  35. return row_index, current_row_image
  36. def stitch_img(IMAGE_DIR: Path, OUTPUT_DIR: Path, NUM_COLS: int, NUM_ROWS: int,
  37. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS: int, ESTIMATE_OVERLAP_VERTICAL_PIXELS: int,
  38. BLEND_TYPE: str, FeatureDetector: str,
  39. DEBUG_MODE: bool) -> Path | None:
  40. """
  41. 基于关键点的图像拼接函数。
  42. 成功时返回最终图像的路径,失败时返回 None。
  43. """
  44. OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
  45. logging.info("--- [关键点] 图像拼接开始 ---")
  46. logging.info(f"配置: {NUM_ROWS}行 x {NUM_COLS}列, 图片目录: {IMAGE_DIR}")
  47. image_paths = sorted(list(IMAGE_DIR.glob("*.jpg")), key=natural_sort_key)
  48. if len(image_paths) != NUM_COLS * NUM_ROWS:
  49. logging.error(f"错误: 找到 {len(image_paths)} 张图片, 但预期需要 {NUM_COLS * NUM_ROWS} 张。")
  50. return None
  51. # --- 阶段一:并行水平拼接每一行 ---
  52. logging.info("--- 阶段一: 并行水平拼接每一行 ---")
  53. stitch_params = {
  54. 'OUTPUT_DIR': OUTPUT_DIR,
  55. 'ESTIMATE_OVERLAP_HORIZONTAL_PIXELS': ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  56. 'BLEND_TYPE': BLEND_TYPE,
  57. 'FeatureDetector': FeatureDetector,
  58. 'DEBUG_MODE': DEBUG_MODE
  59. }
  60. stitched_rows = [None] * NUM_ROWS
  61. with concurrent.futures.ProcessPoolExecutor() as executor:
  62. futures = [executor.submit(stitch_single_row_keypoint, i, image_paths[i * NUM_COLS: i * NUM_COLS + NUM_COLS],
  63. stitch_params) for i in range(NUM_ROWS)]
  64. for future in tqdm(concurrent.futures.as_completed(futures), total=NUM_ROWS, desc="[关键点]处理行"):
  65. try:
  66. row_index, result_image = future.result()
  67. if result_image is not None:
  68. stitched_rows[row_index] = result_image
  69. else:
  70. logging.warning(f"第 {row_index + 1} 行拼接失败。")
  71. except Exception as exc:
  72. logging.error(f"一个行拼接任务生成了异常: {exc}")
  73. if any(row is None for row in stitched_rows):
  74. logging.error("错误: 存在拼接失败的行,无法进行垂直拼接。")
  75. return None
  76. # --- 阶段二:垂直拼接所有行 ---
  77. logging.info("--- 阶段二: 垂直拼接所有行 ---")
  78. final_image = stitched_rows[0]
  79. for i in tqdm(range(1, NUM_ROWS), desc="[关键点]拼接行"):
  80. stitcher_v = ImageStitcherKeyPoint(
  81. estimate_overlap_pixels=ESTIMATE_OVERLAP_VERTICAL_PIXELS, stitch_type="vertical",
  82. blend_type=BLEND_TYPE, feature_detector=FeatureDetector, blend_ratio=0.5,
  83. debug=DEBUG_MODE, debug_dir=str(OUTPUT_DIR / f'debug_v_row{i}vs{i + 1}')
  84. )
  85. next_row_image = stitched_rows[i]
  86. final_image = stitcher_v.stitch_main(final_image, next_row_image)
  87. # --- 保存并返回结果 ---
  88. final_output_path = OUTPUT_DIR / "final_stitched_image.jpg"
  89. cv2.imwrite(str(final_output_path), final_image)
  90. logging.info(f"--- [关键点] 拼接任务完成!最终图已暂存至: {final_output_path} ---")
  91. return final_output_path