key_point_test.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import cv2
  2. import os
  3. import time
  4. from pathlib import Path
  5. import re
  6. from tqdm import tqdm
  7. # 导入您提供的拼接器类
  8. from fry_project_classes.stitch_img_key_point import ImageStitcherKeyPoint
  9. def natural_sort_key(s):
  10. """
  11. 提供自然排序的键,例如 '2.jpg' 会排在 '10.jpg' 之前。
  12. """
  13. return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', str(s))]
  14. def stitch_img(IMAGE_DIR, OUTPUT_DIR, NUM_COLS: int, NUM_ROWS: int,
  15. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS: int, ESTIMATE_OVERLAP_VERTICAL_PIXELS: int,
  16. BLEND_TYPE: str, FeatureDetector: str,
  17. DEBUG_MODE: bool):
  18. OUTPUT_DIR.mkdir(exist_ok=True) # 创建输出文件夹
  19. # --- 2. 加载并排序图片 ---
  20. print("--- 图像拼接开始 ---")
  21. print(f"配置: {NUM_ROWS}行 x {NUM_COLS}列")
  22. print(f"图片目录: {IMAGE_DIR}")
  23. print(f"输出目录: {OUTPUT_DIR}")
  24. print(f"水平重叠预估: {ESTIMATE_OVERLAP_HORIZONTAL_PIXELS}px, 垂直重叠预估: {ESTIMATE_OVERLAP_VERTICAL_PIXELS}px")
  25. print(f"融合模式: {BLEND_TYPE}, 特征检测器类型: {FeatureDetector}")
  26. # --- 2. 加载并排序图片 ---
  27. image_paths = sorted(list(IMAGE_DIR.glob("*.jpg")), key=natural_sort_key)
  28. if len(image_paths) != NUM_COLS * NUM_ROWS:
  29. print(f"错误: 找到 {len(image_paths)} 张图片, 但预期需要 {NUM_COLS * NUM_ROWS} 张。")
  30. return
  31. # --- 3. 阶段一:水平拼接每一行 ---
  32. stitched_rows = []
  33. print("\n--- 阶段一: 水平拼接每一行 ---")
  34. for i in tqdm(range(NUM_ROWS), desc="处理行"):
  35. row_start_index = i * NUM_COLS
  36. row_image_paths = image_paths[row_start_index: row_start_index + NUM_COLS]
  37. # 加载行的第一张图片
  38. current_row_image = cv2.imread(str(row_image_paths[0]))
  39. if current_row_image is None:
  40. print(f"错误: 无法读取图片 {row_image_paths[0]}")
  41. continue
  42. # 依次将该行的后续图片拼接到右侧
  43. for j in range(1, NUM_COLS):
  44. # 为每次拼接实例化一个新的Stitcher对象,以隔离调试文件夹
  45. stitcher_h = ImageStitcherKeyPoint(
  46. estimate_overlap_pixels=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  47. stitch_type="horizontal",
  48. blend_type=BLEND_TYPE,
  49. feature_detector=FeatureDetector,
  50. blend_ratio=0.5,
  51. debug=DEBUG_MODE,
  52. debug_dir=str(OUTPUT_DIR / f'debug_h_row{i + 1}_col{j}vs{j + 1}')
  53. )
  54. next_image = cv2.imread(str(row_image_paths[j]))
  55. if next_image is None:
  56. print(f"错误: 无法读取图片 {row_image_paths[j]}")
  57. break
  58. current_row_image = stitcher_h.stitch_main(current_row_image, next_image)
  59. # 保存拼接好的行
  60. row_output_path = OUTPUT_DIR / f"stitched_row_{i + 1}.jpg"
  61. cv2.imwrite(str(row_output_path), current_row_image)
  62. stitched_rows.append(current_row_image)
  63. tqdm.write(f"第 {i + 1} 行拼接完成, 已保存至 {row_output_path}")
  64. # --- 4. 阶段二:垂直拼接所有行 ---
  65. print("\n--- 阶段二: 垂直拼接所有行 ---")
  66. if not stitched_rows:
  67. print("错误: 没有成功拼接的行,无法进行垂直拼接。")
  68. return
  69. final_image = stitched_rows[0]
  70. for i in tqdm(range(1, NUM_ROWS), desc="拼接行"):
  71. # 实例化垂直拼接器
  72. stitcher_v = ImageStitcherKeyPoint(
  73. estimate_overlap_pixels=ESTIMATE_OVERLAP_VERTICAL_PIXELS,
  74. stitch_type="vertical",
  75. blend_type=BLEND_TYPE,
  76. feature_detector=FeatureDetector,
  77. blend_ratio=0.5,
  78. debug=DEBUG_MODE,
  79. debug_dir=str(OUTPUT_DIR / f'debug_v_row{i}vs{i + 1}')
  80. )
  81. next_row_image = stitched_rows[i]
  82. final_image = stitcher_v.stitch_main(final_image, next_row_image)
  83. # --- 5. 保存最终结果 ---
  84. final_output_path = OUTPUT_DIR / "final_stitched_image.jpg"
  85. cv2.imwrite(str(final_output_path), final_image)
  86. print("\n--- 所有拼接任务完成!---")
  87. print(f"最终的全景图已保存至: {final_output_path}")
  88. def main():
  89. """
  90. 主执行函数
  91. """
  92. # --- 1. 配置参数 ---
  93. # 图片和输出目录设置
  94. IMAGE_DIR = Path(r"C:\Code\ML\Project\StitchImageServer\temp\Input\_250801_1141_0029")
  95. # OUTPUT_DIR = Path(r"C:\Code\ML\Project\StitchImageServer\temp\output")
  96. # 拼图网格设置
  97. NUM_COLS = 4
  98. NUM_ROWS = 6
  99. # !!!关键拼接参数,您可能需要根据实际图片进行调整!!!
  100. # 预估水平方向重叠的像素数。如果您的图片宽1920像素,重叠25%,则该值为 1920 * 0.25 ≈ 480
  101. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS = 500 * 4
  102. # 预估垂直方向重叠的像素数。如果您的图片高1080像素,重叠25%,则该值为 1080 * 0.25 ≈ 270
  103. ESTIMATE_OVERLAP_VERTICAL_PIXELS = 500 * 4
  104. blend_type_list = ["half_importance", "right_first", "left_first", "half_importance_add_weight"]
  105. # BLEND_TYPE = 'blend_half_importance_partial_HSV'
  106. # 是否开启调试模式(会生成大量中间过程图片,用于分析问题)
  107. DEBUG_MODE = False
  108. for i, BLEND_TYPE in enumerate(blend_type_list):
  109. base_dir_path = r"C:\Code\ML\Project\StitchImageServer\temp\output"
  110. img_dir_name = f"{i}_{BLEND_TYPE}"
  111. OUTPUT_DIR = Path(os.path.join(base_dir_path, img_dir_name))
  112. one_img_time = time.time()
  113. stitch_img(IMAGE_DIR=IMAGE_DIR, OUTPUT_DIR=OUTPUT_DIR, NUM_COLS=NUM_COLS, NUM_ROWS=NUM_ROWS,
  114. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  115. ESTIMATE_OVERLAP_VERTICAL_PIXELS=ESTIMATE_OVERLAP_VERTICAL_PIXELS,
  116. BLEND_TYPE=BLEND_TYPE, FeatureDetector="combine",
  117. DEBUG_MODE=DEBUG_MODE)
  118. print(f"{BLEND_TYPE}: {time.time() - one_img_time}")
  119. if __name__ == '__main__':
  120. start_time = time.time()
  121. main()
  122. end_time = time.time()
  123. print(f"\n总耗时: {end_time - start_time:.2f} 秒")