template_match_多线程_test.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. import cv2
  2. import os
  3. import time
  4. from pathlib import Path
  5. import re
  6. from tqdm import tqdm
  7. import concurrent.futures
  8. from fry_project_classes.stitch_img_template_match import ImageStitcherTemplateMatch
  9. def natural_sort_key(s):
  10. return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', str(s))]
  11. # --- 新增:用于并行处理的"任务单元"函数 ---
  12. def stitch_single_row(row_index, row_image_paths, stitch_params):
  13. """
  14. 负责拼接单一一行的图片。这个函数将在独立的进程中运行。
  15. Args:
  16. row_index (int): 当前行的索引(从0开始),用于日志和调试文件命名。
  17. row_image_paths (list): 这一行所有图片的路径列表。
  18. stitch_params (dict): 包含所有拼接所需参数的字典。
  19. Returns:
  20. tuple: 包含行索引和拼接完成的图像 (row_index, stitched_row_image)。
  21. """
  22. # 从参数字典中解包
  23. NUM_COLS = len(row_image_paths)
  24. OUTPUT_DIR = stitch_params['OUTPUT_DIR']
  25. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS = stitch_params['ESTIMATE_OVERLAP_HORIZONTAL_PIXELS']
  26. BLEND_TYPE = stitch_params['BLEND_TYPE']
  27. LIGHT_COMPENSATION = stitch_params['LIGHT_COMPENSATION']
  28. DEBUG_MODE = stitch_params['DEBUG_MODE']
  29. # 加载行的第一张图片
  30. current_row_image = cv2.imread(str(row_image_paths[0]))
  31. if current_row_image is None:
  32. print(f"错误: 无法读取图片 {row_image_paths[0]}")
  33. return row_index, None
  34. # 依次将该行的后续图片拼接到右侧
  35. for j in range(1, NUM_COLS):
  36. stitcher_h = ImageStitcherTemplateMatch(
  37. estimate_overlap_pixels=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  38. stitch_type="horizontal",
  39. blend_type=BLEND_TYPE,
  40. light_uniformity_compensation_enabled=LIGHT_COMPENSATION,
  41. light_uniformity_compensation_width=30,
  42. debug=DEBUG_MODE,
  43. # 注意调试目录的命名,确保不同进程不会写入同一个文件夹
  44. debug_dir=str(OUTPUT_DIR / f'debug_h_row{row_index + 1}_col{j}vs{j + 1}')
  45. )
  46. next_image = cv2.imread(str(row_image_paths[j]))
  47. if next_image is None:
  48. print(f"错误: 无法读取图片 {row_image_paths[j]}")
  49. # 如果中间一张图片读取失败,返回当前已拼接的部分
  50. return row_index, current_row_image
  51. current_row_image = stitcher_h.stitch_main(current_row_image, next_image)
  52. # 返回拼接结果和行索引,以便主进程能按正确顺序排列
  53. return row_index, current_row_image
  54. # --- 优化后的主拼接函数 ---
  55. def stitch_img(IMAGE_DIR, OUTPUT_DIR, NUM_COLS: int, NUM_ROWS: int,
  56. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS: int, ESTIMATE_OVERLAP_VERTICAL_PIXELS: int,
  57. BLEND_TYPE: str, LIGHT_COMPENSATION: bool,
  58. DEBUG_MODE: bool):
  59. OUTPUT_DIR.mkdir(exist_ok=True)
  60. print("--- 图像拼接开始 ---")
  61. print(f"配置: {NUM_ROWS}行 x {NUM_COLS}列")
  62. print(f"图片目录: {IMAGE_DIR}")
  63. print(f"输出目录: {OUTPUT_DIR}")
  64. print(f"水平重叠预估: {ESTIMATE_OVERLAP_HORIZONTAL_PIXELS}px, 垂直重叠预估: {ESTIMATE_OVERLAP_VERTICAL_PIXELS}px")
  65. print(f"融合模式: {BLEND_TYPE}, 光照补偿: {'启用' if LIGHT_COMPENSATION else '禁用'}")
  66. # --- 2. 加载并排序图片 ---
  67. image_paths = sorted(list(IMAGE_DIR.glob("*.jpg")), key=natural_sort_key)
  68. if len(image_paths) != NUM_COLS * NUM_ROWS:
  69. print(f"错误: 找到 {len(image_paths)} 张图片, 但预期需要 {NUM_COLS * NUM_ROWS} 张。")
  70. return
  71. # --- 3. 阶段一:并行水平拼接每一行 (核心优化点) ---
  72. print("\n--- 阶段一: 并行水平拼接每一行 ---")
  73. # 准备传递给每个进程的参数
  74. stitch_params = {
  75. 'OUTPUT_DIR': OUTPUT_DIR,
  76. 'ESTIMATE_OVERLAP_HORIZONTAL_PIXELS': ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  77. 'BLEND_TYPE': BLEND_TYPE,
  78. 'LIGHT_COMPENSATION': LIGHT_COMPENSATION,
  79. 'DEBUG_MODE': DEBUG_MODE
  80. }
  81. stitched_rows = [None] * NUM_ROWS # 预先分配列表,用于按顺序存放结果
  82. # 使用进程池执行器
  83. with concurrent.futures.ProcessPoolExecutor() as executor:
  84. # 提交所有行的拼接任务
  85. futures = []
  86. for i in range(NUM_ROWS):
  87. row_start_index = i * NUM_COLS
  88. row_image_paths = image_paths[row_start_index: row_start_index + NUM_COLS]
  89. # 提交任务到进程池
  90. future = executor.submit(stitch_single_row, i, row_image_paths, stitch_params)
  91. futures.append(future)
  92. # 使用tqdm来显示进度条,并收集结果
  93. # as_completed会在任务完成时立即返回,这比直接等待所有任务更具响应性
  94. for future in tqdm(concurrent.futures.as_completed(futures), total=NUM_ROWS, desc="处理行"):
  95. try:
  96. row_index, result_image = future.result()
  97. if result_image is not None:
  98. stitched_rows[row_index] = result_image
  99. # 保存拼接好的行
  100. row_output_path = OUTPUT_DIR / f"stitched_row_{row_index + 1}.jpg"
  101. cv2.imwrite(str(row_output_path), result_image)
  102. tqdm.write(f"第 {row_index + 1} 行拼接完成, 已保存至 {row_output_path}")
  103. else:
  104. tqdm.write(f"第 {row_index + 1} 行拼接失败。")
  105. except Exception as exc:
  106. tqdm.write(f"一个行拼接任务生成了异常: {exc}")
  107. # 检查是否有失败的行
  108. if any(row is None for row in stitched_rows):
  109. print("错误: 存在拼接失败的行,无法进行垂直拼接。")
  110. return
  111. # --- 4. 阶段二:垂直拼接所有行 (这部分保持串行) ---
  112. print("\n--- 阶段二: 垂直拼接所有行 ---")
  113. final_image = stitched_rows[0]
  114. for i in tqdm(range(1, NUM_ROWS), desc="拼接行"):
  115. stitcher_v = ImageStitcherTemplateMatch(
  116. estimate_overlap_pixels=ESTIMATE_OVERLAP_VERTICAL_PIXELS,
  117. stitch_type="vertical",
  118. blend_type=BLEND_TYPE,
  119. light_uniformity_compensation_enabled=LIGHT_COMPENSATION,
  120. light_uniformity_compensation_width=30,
  121. debug=DEBUG_MODE,
  122. debug_dir=str(OUTPUT_DIR / f'debug_v_row{i}vs{i + 1}')
  123. )
  124. next_row_image = stitched_rows[i]
  125. final_image = stitcher_v.stitch_main(final_image, next_row_image)
  126. # --- 5. 保存最终结果 ---
  127. final_output_path = OUTPUT_DIR / "final_stitched_image.jpg"
  128. cv2.imwrite(str(final_output_path), final_image)
  129. print("\n--- 所有拼接任务完成!---")
  130. print(f"最终的全景图已保存至: {final_output_path}")
  131. def main():
  132. """
  133. 主执行函数
  134. """
  135. # --- 1. 配置参数 ---
  136. # 图片和输出目录设置
  137. IMAGE_DIR = Path(r"C:\Code\ML\Project\StitchImageServer\temp\Input\_250801_1146_0034")
  138. # 拼图网格设置
  139. NUM_COLS = 4
  140. NUM_ROWS = 6
  141. # 预估重叠像素
  142. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS = 405
  143. ESTIMATE_OVERLAP_VERTICAL_PIXELS = 440
  144. # 融合模式列表
  145. # 默认 half_importance_add_weight
  146. blend_type_list = ["half_importance_add_weight",
  147. "half_importance_global_brightness", "half_importance_partial_brightness",
  148. "blend_half_importance_partial_HV", "blend_half_importance_partial_SV",
  149. "blend_half_importance_partial_HSV", "blend_half_importance_partial_brightness_add_weight"]
  150. LIGHT_COMPENSATION = True
  151. DEBUG_MODE = False
  152. for i, BLEND_TYPE in enumerate(blend_type_list):
  153. base_dir_path = r"C:\Code\ML\Project\StitchImageServer\temp\output"
  154. img_dir_name = f"{i}_{BLEND_TYPE}"
  155. OUTPUT_DIR = Path(os.path.join(base_dir_path, img_dir_name))
  156. one_img_time = time.time()
  157. stitch_img(IMAGE_DIR=IMAGE_DIR, OUTPUT_DIR=OUTPUT_DIR, NUM_COLS=NUM_COLS, NUM_ROWS=NUM_ROWS,
  158. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=ESTIMATE_OVERLAP_HORIZONTAL_PIXELS,
  159. ESTIMATE_OVERLAP_VERTICAL_PIXELS=ESTIMATE_OVERLAP_VERTICAL_PIXELS,
  160. BLEND_TYPE=BLEND_TYPE, LIGHT_COMPENSATION=LIGHT_COMPENSATION,
  161. DEBUG_MODE=DEBUG_MODE)
  162. print()
  163. print("_" * 20)
  164. print(f"单个用时: {img_dir_name}: {time.time() - one_img_time}")
  165. print("_" * 20)
  166. if __name__ == '__main__':
  167. start_time = time.time()
  168. main()
  169. end_time = time.time()
  170. print(f"\n总耗时: {end_time - start_time:.2f} 秒")