stitch.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import os
  2. import shutil
  3. import uuid
  4. import zipfile
  5. import logging
  6. from pathlib import Path
  7. from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
  8. from fastapi.responses import FileResponse, JSONResponse
  9. # 导入我们的核心逻辑和数据模型
  10. from app.core import stitcher_keypoint, stitcher_template
  11. from app.schemas import StitchingMethod, KeypointFeatureDetector, KeypointBlendType, TemplateBlendType
  12. from utils.utils import cleanup_temp_folder
  13. router = APIRouter(prefix="/stitch", tags=['拼图'])
  14. TEMP_DIR = Path("_temp_work")
  15. TEMP_DIR.mkdir(exist_ok=True)
  16. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  17. @router.post("/", response_class=FileResponse, summary="单个拼图接口")
  18. async def stitch_single_puzzle(
  19. background_tasks: BackgroundTasks,
  20. zip_file: UploadFile = File(..., description="包含一个文件夹的ZIP压缩包,文件夹内有24张小图。"),
  21. # --- 通用参数 ---
  22. method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
  23. num_cols: int = Form(4, description="拼图的列数"),
  24. num_rows: int = Form(6, description="拼图的行数"),
  25. overlap_h: int = Form(405, description="预估的水平重叠像素"),
  26. overlap_v: int = Form(440, description="预估的垂直重叠像素"),
  27. # --- 点匹配法 (key_point) 特定参数 ---
  28. kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
  29. kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
  30. description="[点匹配] 特征检测器"),
  31. # --- 模板匹配法 (template_match) 特定参数 ---
  32. tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  33. description="[模板匹配] 融合模式"),
  34. tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿")
  35. ):
  36. """
  37. 上传一个包含24张小图的文件夹的ZIP压缩包,接口会将其拼接成一张大图并返回。
  38. - **zip_file**: 必须是.zip格式,内部应仅包含一个文件夹,该文件夹内含所有待拼接的.jpg图片。
  39. - **返回**: 拼接成功后,返回拼接好的图片文件,文件名与ZIP包内的文件夹名相同。
  40. """
  41. request_id = str(uuid.uuid4())
  42. session_dir = TEMP_DIR / request_id
  43. session_dir.mkdir()
  44. background_tasks.add_task(cleanup_temp_folder, session_dir)
  45. zip_path = session_dir / zip_file.filename
  46. with open(zip_path, "wb") as buffer:
  47. shutil.copyfileobj(zip_file.file, buffer)
  48. extracted_dir = session_dir / "extracted"
  49. extracted_dir.mkdir()
  50. try:
  51. with zipfile.ZipFile(zip_path, 'r') as zf:
  52. zf.extractall(extracted_dir)
  53. except zipfile.BadZipFile:
  54. raise HTTPException(status_code=400, detail="上传的文件不是有效的ZIP格式。")
  55. image_dir = extracted_dir
  56. output_dir = session_dir / "output"
  57. # 根据选择的方法调用不同的拼接函数
  58. stitched_image_path = None
  59. if method == StitchingMethod.KEY_POINT:
  60. stitched_image_path = stitcher_keypoint.stitch_img(
  61. IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir, NUM_COLS=num_cols, NUM_ROWS=num_rows,
  62. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h, ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  63. BLEND_TYPE=kp_blend_type.value, FeatureDetector=kp_feature_detector.value,
  64. DEBUG_MODE=False
  65. )
  66. elif method == StitchingMethod.TEMPLATE_MATCH:
  67. stitched_image_path = stitcher_template.stitch_img(
  68. IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir, NUM_COLS=num_cols, NUM_ROWS=num_rows,
  69. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h, ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  70. BLEND_TYPE=tm_blend_type.value, LIGHT_COMPENSATION=tm_light_compensation,
  71. DEBUG_MODE=False
  72. )
  73. if not stitched_image_path or not stitched_image_path.exists():
  74. raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
  75. # 使用原始文件夹名命名输出图片
  76. final_filename = f"{image_dir.name}.jpg"
  77. final_filepath = stitched_image_path.rename(stitched_image_path.parent / final_filename)
  78. return FileResponse(
  79. path=final_filepath,
  80. filename=final_filename,
  81. media_type='image/jpeg'
  82. )
  83. @router.post("/batch", response_class=FileResponse, summary="批量拼图接口")
  84. async def stitch_batch_puzzles(
  85. background_tasks: BackgroundTasks,
  86. zip_file: UploadFile = File(..., description="包含多个拼图文件夹的ZIP压缩包。"),
  87. # 参数与单个拼图接口相同
  88. method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
  89. num_cols: int = Form(4, description="拼图的列数"),
  90. num_rows: int = Form(6, description="拼图的行数"),
  91. overlap_h: int = Form(405, description="预估的水平重叠像素"),
  92. overlap_v: int = Form(440, description="预估的垂直重叠像素"),
  93. kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
  94. kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
  95. description="[点匹配] 特征检测器"),
  96. tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  97. description="[模板匹配] 融合模式"),
  98. tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿"),
  99. ):
  100. """
  101. 上传一个包含多个拼图文件夹的ZIP压缩包,接口会处理所有文件夹,并将结果打包成一个新的ZIP返回。
  102. - **zip_file**: 必须是.zip格式,内部可以有多个文件夹,每个文件夹都包含待拼接的图片。
  103. - **返回**: 一个ZIP压缩包,里面是所有拼接好的图片,每张图片以其对应的原文件夹名命名。
  104. """
  105. request_id = str(uuid.uuid4())
  106. session_dir = TEMP_DIR / request_id
  107. session_dir.mkdir()
  108. background_tasks.add_task(cleanup_temp_folder, session_dir)
  109. zip_path = session_dir / zip_file.filename
  110. with open(zip_path, "wb") as buffer:
  111. shutil.copyfileobj(zip_file.file, buffer)
  112. extracted_dir = session_dir / "extracted"
  113. extracted_dir.mkdir()
  114. try:
  115. with zipfile.ZipFile(zip_path, 'r') as zf:
  116. zf.extractall(extracted_dir)
  117. except zipfile.BadZipFile:
  118. raise HTTPException(status_code=400, detail="上传的文件不是有效的ZIP格式。")
  119. puzzle_folders = [d for d in extracted_dir.iterdir() if d.is_dir()]
  120. if not puzzle_folders:
  121. raise HTTPException(status_code=400, detail="ZIP包中未找到任何拼图文件夹。")
  122. batch_output_dir = session_dir / "batch_output"
  123. batch_output_dir.mkdir()
  124. processed_count = 0
  125. failed_folders = []
  126. for image_dir in puzzle_folders:
  127. logging.info(f"--- 开始处理批量任务中的文件夹: {image_dir.name} ---")
  128. # 为每个子任务创建一个独立的输出目录
  129. single_output_dir = session_dir / "single_output"
  130. if single_output_dir.exists():
  131. shutil.rmtree(single_output_dir) # 清理上一次循环的输出
  132. stitched_image_path = None
  133. try:
  134. if method == StitchingMethod.KEY_POINT:
  135. stitched_image_path = stitcher_keypoint.stitch_img(
  136. IMAGE_DIR=image_dir, OUTPUT_DIR=single_output_dir,
  137. NUM_COLS=num_cols, NUM_ROWS=num_rows,
  138. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h,
  139. ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  140. BLEND_TYPE=kp_blend_type.value, FeatureDetector=kp_feature_detector.value,
  141. DEBUG_MODE=False
  142. )
  143. elif method == StitchingMethod.TEMPLATE_MATCH:
  144. stitched_image_path = stitcher_template.stitch_img(
  145. IMAGE_DIR=image_dir, OUTPUT_DIR=single_output_dir,
  146. NUM_COLS=num_cols, NUM_ROWS=num_rows,
  147. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h,
  148. ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  149. BLEND_TYPE=tm_blend_type.value, LIGHT_COMPENSATION=tm_light_compensation,
  150. DEBUG_MODE=False
  151. )
  152. if stitched_image_path and stitched_image_path.exists():
  153. # 将成功的结果移到最终的批量输出目录
  154. target_path = batch_output_dir / f"{image_dir.name}.jpg"
  155. shutil.move(str(stitched_image_path), str(target_path))
  156. processed_count += 1
  157. else:
  158. logging.error(f"文件夹 {image_dir.name} 拼接失败。")
  159. failed_folders.append(image_dir.name)
  160. except Exception as e:
  161. logging.error(f"处理文件夹 {image_dir.name} 时发生严重错误: {e}")
  162. failed_folders.append(image_dir.name)
  163. if processed_count == 0:
  164. detail_msg = f"所有 {len(puzzle_folders)} 个文件夹都拼接失败。失败列表: {failed_folders}"
  165. raise HTTPException(status_code=500, detail=detail_msg)
  166. # 将最终结果打包成ZIP
  167. output_zip_path = session_dir / "stitched_results.zip"
  168. shutil.make_archive(str(output_zip_path.with_suffix('')), 'zip', batch_output_dir)
  169. if failed_folders:
  170. logging.warning(f"批量任务完成,但有 {len(failed_folders)} 个文件夹失败: {failed_folders}")
  171. # 可以在响应头中添加自定义信息来通知客户端部分失败
  172. # response.headers["X-Failed-Folders"] = ",".join(failed_folders)
  173. return FileResponse(
  174. path=output_zip_path,
  175. filename="stitched_results.zip",
  176. media_type='application/zip'
  177. )