key_point_多线程_test.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import cv2
  2. import os
  3. import time
  4. from pathlib import Path
  5. import re
  6. from tqdm import tqdm
  7. import concurrent.futures # 导入并发库
  8. # 导入您提供的拼接器类
  9. from fry_project_classes.stitch_img_key_point import ImageStitcherKeyPoint
  10. def natural_sort_key(s):
  11. """
  12. 提供自然排序的键,例如 '2.jpg' 会排在 '10.jpg' 之前。
  13. """
  14. return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', str(s))]
  15. # --- 新增:用于并行处理的"任务单元"函数 ---
  16. def stitch_single_row_keypoint(row_index, row_image_paths, stitch_params):
  17. """
  18. 负责使用基于关键点的方法拼接单一一行的图片。这个函数将在独立的进程中运行。
  19. Args:
  20. row_index (int): 当前行的索引(从0开始)。
  21. row_image_paths (list): 这一行所有图片的路径列表。
  22. stitch_params (dict): 包含所有拼接所需参数的字典。
  23. Returns:
  24. tuple: 包含行索引和拼接完成的图像 (row_index, stitched_row_image)。
  25. """
  26. # 从参数字典中解包
  27. NUM_COLS = len(row_image_paths)
  28. OUTPUT_DIR = stitch_params['OUTPUT_DIR']
  29. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS = stitch_params['ESTIMATE_OVERLAP_HORIZONTAL_PIXELS']
  30. BLEND_TYPE = stitch_params['BLEND_TYPE']
  31. FeatureDetector = stitch_params['FeatureDetector']
  32. DEBUG_MODE = stitch_params['DEBUG_MODE']
  33. # 加载行的第一张图片
  34. current_row_image = cv2.imread(str(row_image_paths[0]))
  35. if current_row_image is None:
  36. print(f"错误: 无法读取图片 {row_image_paths[0]}")
  37. return row_index, None
  38. # 依次将该行的后续图片拼接到右侧
  39. for j in range(1, NUM_COLS):
  40. stitcher_h = ImageStitcherKeyPoint(
  41. estimate_overlap_pixels=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  42. stitch_type="horizontal",
  43. blend_type=BLEND_TYPE,
  44. feature_detector=FeatureDetector,
  45. blend_ratio=0.5,
  46. debug=DEBUG_MODE,
  47. debug_dir=str(OUTPUT_DIR / f'debug_h_row{row_index + 1}_col{j}vs{j + 1}')
  48. )
  49. next_image = cv2.imread(str(row_image_paths[j]))
  50. if next_image is None:
  51. print(f"错误: 无法读取图片 {row_image_paths[j]}")
  52. return row_index, current_row_image
  53. current_row_image = stitcher_h.stitch_main(current_row_image, next_image)
  54. # 返回拼接结果和行索引,以便主进程能按正确顺序排列
  55. return row_index, current_row_image
  56. # --- 优化后的主拼接函数 ---
  57. def stitch_img(IMAGE_DIR, OUTPUT_DIR, NUM_COLS: int, NUM_ROWS: int,
  58. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS: int, ESTIMATE_OVERLAP_VERTICAL_PIXELS: int,
  59. BLEND_TYPE: str, FeatureDetector: str,
  60. DEBUG_MODE: bool):
  61. OUTPUT_DIR.mkdir(exist_ok=True)
  62. print("--- 图像拼接开始 ---")
  63. print(f"配置: {NUM_ROWS}行 x {NUM_COLS}列")
  64. print(f"图片目录: {IMAGE_DIR}")
  65. print(f"输出目录: {OUTPUT_DIR}")
  66. print(f"水平重叠预估: {ESTIMATE_OVERLAP_HORIZONTAL_PIXELS}px, 垂直重叠预估: {ESTIMATE_OVERLAP_VERTICAL_PIXELS}px")
  67. print(f"融合模式: {BLEND_TYPE}, 特征检测器类型: {FeatureDetector}")
  68. image_paths = sorted(list(IMAGE_DIR.glob("*.jpg")), key=natural_sort_key)
  69. if len(image_paths) != NUM_COLS * NUM_ROWS:
  70. print(f"错误: 找到 {len(image_paths)} 张图片, 但预期需要 {NUM_COLS * NUM_ROWS} 张。")
  71. return
  72. # --- 3. 阶段一:并行水平拼接每一行 (核心优化点) ---
  73. print("\n--- 阶段一: 并行水平拼接每一行 ---")
  74. # 将所有固定参数打包成字典,方便传递给子进程
  75. stitch_params = {
  76. 'OUTPUT_DIR': OUTPUT_DIR,
  77. 'ESTIMATE_OVERLAP_HORIZONTAL_PIXELS': ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  78. 'BLEND_TYPE': BLEND_TYPE,
  79. 'FeatureDetector': FeatureDetector,
  80. 'DEBUG_MODE': DEBUG_MODE
  81. }
  82. stitched_rows = [None] * NUM_ROWS # 预分配列表,用于按顺序存放结果
  83. with concurrent.futures.ProcessPoolExecutor() as executor:
  84. futures = []
  85. for i in range(NUM_ROWS):
  86. row_start_index = i * NUM_COLS
  87. row_image_paths = image_paths[row_start_index: row_start_index + NUM_COLS]
  88. future = executor.submit(stitch_single_row_keypoint, i, row_image_paths, stitch_params)
  89. futures.append(future)
  90. for future in tqdm(concurrent.futures.as_completed(futures), total=NUM_ROWS, desc="处理行"):
  91. try:
  92. row_index, result_image = future.result()
  93. if result_image is not None:
  94. stitched_rows[row_index] = result_image
  95. row_output_path = OUTPUT_DIR / f"stitched_row_{row_index + 1}.jpg"
  96. cv2.imwrite(str(row_output_path), result_image)
  97. tqdm.write(f"第 {row_index + 1} 行拼接完成, 已保存至 {row_output_path}")
  98. else:
  99. tqdm.write(f"第 {row_index + 1} 行拼接失败。")
  100. except Exception as exc:
  101. tqdm.write(f"一个行拼接任务生成了异常: {exc}")
  102. if any(row is None for row in stitched_rows):
  103. print("错误: 存在拼接失败的行,无法进行垂直拼接。")
  104. return
  105. # --- 4. 阶段二:垂直拼接所有行 (保持串行) ---
  106. print("\n--- 阶段二: 垂直拼接所有行 ---")
  107. final_image = stitched_rows[0]
  108. for i in tqdm(range(1, NUM_ROWS), desc="拼接行"):
  109. stitcher_v = ImageStitcherKeyPoint(
  110. estimate_overlap_pixels=ESTIMATE_OVERLAP_VERTICAL_PIXELS,
  111. stitch_type="vertical",
  112. blend_type=BLEND_TYPE,
  113. feature_detector=FeatureDetector,
  114. blend_ratio=0.5,
  115. debug=DEBUG_MODE,
  116. debug_dir=str(OUTPUT_DIR / f'debug_v_row{i}vs{i + 1}')
  117. )
  118. next_row_image = stitched_rows[i]
  119. final_image = stitcher_v.stitch_main(final_image, next_row_image)
  120. # --- 5. 保存最终结果 ---
  121. final_output_path = OUTPUT_DIR / "final_stitched_image.jpg"
  122. cv2.imwrite(str(final_output_path), final_image)
  123. print("\n--- 所有拼接任务完成!---")
  124. print(f"最终的全景图已保存至: {final_output_path}")
  125. def main():
  126. """
  127. 主执行函数
  128. """
  129. # --- 1. 配置参数 ---
  130. IMAGE_DIR = Path(r"C:\Code\ML\Project\StitchImageServer\temp\Input\_250801_1146_0034")
  131. NUM_COLS = 4
  132. NUM_ROWS = 6
  133. # 预估重叠像素
  134. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS = 405
  135. ESTIMATE_OVERLAP_VERTICAL_PIXELS = 440
  136. # 默认为 half_importance_add_weight 和 combine
  137. blend_type_list = ['half_importance', 'right_first', "half_importance_add_weight"]
  138. feature_list = ['sift', 'orb', 'brisk', 'combine']
  139. DEBUG_MODE = False
  140. for BLEND_TYPE in blend_type_list:
  141. for i, feature_type in enumerate(feature_list):
  142. base_dir_path = r"C:\Code\ML\Project\StitchImageServer\temp\key_output" # 建议为keypoint方法用一个新目录
  143. img_dir_name = f"{i}_{BLEND_TYPE}_{feature_type}"
  144. OUTPUT_DIR = Path(os.path.join(base_dir_path, img_dir_name))
  145. print("\n" + "=" * 80)
  146. print(f"开始测试配置: {img_dir_name}")
  147. one_img_time = time.time()
  148. stitch_img(IMAGE_DIR=IMAGE_DIR, OUTPUT_DIR=OUTPUT_DIR, NUM_COLS=NUM_COLS, NUM_ROWS=NUM_ROWS,
  149. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  150. ESTIMATE_OVERLAP_VERTICAL_PIXELS=ESTIMATE_OVERLAP_VERTICAL_PIXELS,
  151. BLEND_TYPE=BLEND_TYPE, FeatureDetector=feature_type,
  152. DEBUG_MODE=DEBUG_MODE)
  153. print(f"\n--- 单次配置完成 ---")
  154. print(f"用时: {time.time() - one_img_time:.2f} 秒, 配置: {img_dir_name}")
  155. print("=" * 80)
  156. if __name__ == '__main__':
  157. start_time = time.time()
  158. main()
  159. end_time = time.time()
  160. print(f"\n总耗时: {end_time - start_time:.2f} 秒")