stitch.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import os
  2. import shutil
  3. import uuid
  4. import zipfile
  5. import logging
  6. from typing import List
  7. from pathlib import Path
  8. from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
  9. from fastapi.responses import FileResponse, JSONResponse
  10. # 导入我们的核心逻辑和数据模型
  11. from app.core import stitcher_keypoint, stitcher_template
  12. from app.schemas import StitchingMethod, KeypointFeatureDetector, KeypointBlendType, TemplateBlendType
  13. from utils.utils import cleanup_temp_folder
  14. router = APIRouter(prefix="/stitch", tags=['拼图'])
  15. TEMP_DIR = Path("_temp_work")
  16. TEMP_DIR.mkdir(exist_ok=True)
  17. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  18. @router.post("/", summary="通用拼图接口")
  19. async def stitch_puzzle(
  20. background_tasks: BackgroundTasks,
  21. zip_file: UploadFile = File(..., description="包含一个或多个拼图文件夹的ZIP压缩包。"),
  22. # --- 通用参数 ---
  23. method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
  24. num_cols: int = Form(4, description="拼图的列数"),
  25. num_rows: int = Form(6, description="拼图的行数"),
  26. overlap_h: int = Form(405, description="预估的水平重叠像素"),
  27. overlap_v: int = Form(440, description="预估的垂直重叠像素"),
  28. # --- 关键点匹配法 (key_point) 特定参数 ---
  29. kp_blend_type: KeypointBlendType = Form(KeypointBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  30. description="[关键点] 融合模式"),
  31. kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
  32. description="[关键点] 特征检测器"),
  33. kp_blend_ratio: float = Form(0.5, description="[关键点] 融合权重 (0.0-1.0)"),
  34. # --- 模板匹配法 (template_match) 特定参数 ---
  35. tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  36. description="[模板匹配] 融合模式"),
  37. tm_blend_ratio: float = Form(0.5, description="[模板匹配] 融合权重 (0.0-1.0)"),
  38. tm_light_compensation: bool = Form(True, description="[模板/关键点] 是否启用光照补偿"),
  39. tm_light_compensation_width: int = Form(15, description="[模板/关键点] 光照补偿宽度 (像素)")
  40. ) -> FileResponse:
  41. """
  42. 上传一个包含拼图文件夹的ZIP压缩包,接口会处理所有文件夹。
  43. - **如果只处理了一个文件夹**:直接返回拼接好的图片文件。
  44. - **如果处理了多个文件夹**:将结果打包成一个新的ZIP压缩包返回。
  45. """
  46. request_id = str(uuid.uuid4())
  47. session_dir = TEMP_DIR / request_id
  48. session_dir.mkdir()
  49. # 确保无论成功或失败,临时文件夹最终都会被清理
  50. background_tasks.add_task(cleanup_temp_folder, session_dir)
  51. zip_path = session_dir / zip_file.filename
  52. with open(zip_path, "wb") as buffer:
  53. shutil.copyfileobj(zip_file.file, buffer)
  54. extracted_dir = session_dir / "extracted"
  55. extracted_dir.mkdir()
  56. try:
  57. with zipfile.ZipFile(zip_path, 'r') as zf:
  58. zf.extractall(extracted_dir)
  59. except zipfile.BadZipFile:
  60. raise HTTPException(status_code=400, detail="上传的文件不是有效的ZIP格式。")
  61. # 智能判断是单文件夹还是多文件夹模式
  62. sub_items = list(extracted_dir.iterdir())
  63. if len(sub_items) == 1 and sub_items[0].is_dir():
  64. # Case 1: ZIP内只有一个顶层文件夹,任务在其子文件夹中
  65. puzzle_folders = [d for d in sub_items[0].iterdir() if d.is_dir()]
  66. # 如果子文件夹为空,则认为顶层文件夹本身就是任务
  67. if not puzzle_folders:
  68. puzzle_folders = [sub_items[0]]
  69. else:
  70. # Case 2: ZIP内有多个文件/文件夹,任务是其中的文件夹
  71. puzzle_folders = [d for d in sub_items if d.is_dir()]
  72. # 如果没有子目录,但有图片,则认为整个解压目录是一个任务
  73. if not puzzle_folders and any(f.suffix.lower() in ['.jpg', '.png', '.jpeg'] for f in sub_items if f.is_file()):
  74. puzzle_folders = [extracted_dir]
  75. if not puzzle_folders:
  76. raise HTTPException(status_code=400, detail="ZIP包中未找到任何包含图片的拼图文件夹。")
  77. batch_output_dir = session_dir / "batch_output"
  78. batch_output_dir.mkdir()
  79. processed_count = 0
  80. failed_folders = []
  81. for image_dir in puzzle_folders:
  82. logging.info(f"--- 开始处理文件夹: {image_dir.name} ---")
  83. single_output_dir = session_dir / f"output_{image_dir.name}"
  84. stitched_image_path = None
  85. try:
  86. if method == StitchingMethod.KEY_POINT:
  87. stitched_image_path = stitcher_keypoint.stitch_img(
  88. IMAGE_DIR=image_dir, OUTPUT_DIR=single_output_dir,
  89. NUM_COLS=num_cols, NUM_ROWS=num_rows,
  90. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h,
  91. ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  92. BLEND_TYPE=kp_blend_type.value, FeatureDetector=kp_feature_detector.value,
  93. BLEND_RATIO=kp_blend_ratio,
  94. LIGHT_COMPENSATION=tm_light_compensation,
  95. LIGHT_COMPENSATION_WIDTH=tm_light_compensation_width,
  96. DEBUG_MODE=False
  97. )
  98. elif method == StitchingMethod.TEMPLATE_MATCH:
  99. stitched_image_path = stitcher_template.stitch_img(
  100. IMAGE_DIR=image_dir, OUTPUT_DIR=single_output_dir,
  101. NUM_COLS=num_cols, NUM_ROWS=num_rows,
  102. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h,
  103. ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  104. BLEND_TYPE=tm_blend_type.value, LIGHT_COMPENSATION=tm_light_compensation,
  105. BLEND_RATIO=tm_blend_ratio,
  106. LIGHT_COMPENSATION_WIDTH=tm_light_compensation_width,
  107. DEBUG_MODE=False
  108. )
  109. if stitched_image_path and stitched_image_path.exists():
  110. target_path = batch_output_dir / f"{image_dir.name}.jpg"
  111. shutil.move(str(stitched_image_path), str(target_path))
  112. processed_count += 1
  113. else:
  114. logging.error(f"文件夹 {image_dir.name} 拼接失败。")
  115. failed_folders.append(image_dir.name)
  116. except Exception as e:
  117. logging.error(f"处理文件夹 {image_dir.name} 时发生严重错误: {e}")
  118. failed_folders.append(image_dir.name)
  119. if processed_count == 0:
  120. detail_msg = f"所有 {len(puzzle_folders)} 个文件夹都拼接失败。失败列表: {failed_folders}"
  121. raise HTTPException(status_code=500, detail=detail_msg)
  122. # --- 核心修改点:根据处理成功的数量决定返回类型 ---
  123. successful_files = list(batch_output_dir.iterdir())
  124. # 如果只有一个成功的拼图结果
  125. if len(successful_files) == 1:
  126. single_result_path = successful_files[0]
  127. logging.info(f"检测到单个成功结果,直接返回图片: {single_result_path.name}")
  128. return FileResponse(
  129. path=single_result_path,
  130. filename=single_result_path.name,
  131. media_type='image/jpeg'
  132. )
  133. # 如果有多个成功的结果,或者即使只有一个成功但原始任务也是多个(为了保持一致性)
  134. else:
  135. output_zip_name = "stitched_results.zip"
  136. output_zip_path = session_dir / output_zip_name
  137. shutil.make_archive(str(output_zip_path.with_suffix('')), 'zip', batch_output_dir)
  138. logging.info(f"检测到多个成功结果,返回ZIP包: {output_zip_name}")
  139. if failed_folders:
  140. logging.warning(f"批量任务完成,但有 {len(failed_folders)} 个文件夹失败: {failed_folders}")
  141. return FileResponse(
  142. path=output_zip_path,
  143. filename=output_zip_name,
  144. media_type='application/zip'
  145. )
  146. @router.post("/folder", response_class=FileResponse, summary="单个拼图接口 (直接上传文件)")
  147. async def stitch_puzzle_from_folder(
  148. background_tasks: BackgroundTasks,
  149. files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
  150. output_filename_base: str = Form("stitched_result", description="输出图片的基础名称(不含扩展名)。"),
  151. # --- 参数与ZIP接口完全相同 ---
  152. method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
  153. num_cols: int = Form(4, description="拼图的列数"),
  154. num_rows: int = Form(6, description="拼图的行数"),
  155. overlap_h: int = Form(405, description="预估的水平重叠像素"),
  156. overlap_v: int = Form(440, description="预估的垂直重叠像素"),
  157. kp_blend_type: KeypointBlendType = Form(KeypointBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  158. description="[关键点] 融合模式"),
  159. kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
  160. description="[关键点] 特征检测器"),
  161. kp_blend_ratio: float = Form(0.5, description="[关键点] 融合权重 (0.0-1.0)"),
  162. tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  163. description="[模板匹配] 融合模式"),
  164. tm_blend_ratio: float = Form(0.5, description="[模板匹配] 融合权重 (0.0-1.0)"),
  165. tm_light_compensation: bool = Form(True, description="[模板/关键点] 是否启用光照补偿"),
  166. tm_light_compensation_width: int = Form(15, description="[模板/关键点] 光照补偿宽度 (像素)")
  167. ):
  168. """
  169. 上传一个文件夹内的所有图片进行拼接,直接返回拼接好的单张大图。
  170. 此接口专为无法或不便在客户端进行ZIP压缩的场景设计。
  171. """
  172. if not files:
  173. raise HTTPException(status_code=400, detail="没有上传任何文件。")
  174. request_id = str(uuid.uuid4())
  175. session_dir = TEMP_DIR / request_id
  176. session_dir.mkdir()
  177. background_tasks.add_task(cleanup_temp_folder, session_dir)
  178. # 在会话目录中创建一个子目录来存放上传的图片,模拟一个文件夹
  179. image_dir = session_dir / "images"
  180. image_dir.mkdir()
  181. for upload_file in files:
  182. file_path = image_dir / upload_file.filename
  183. with open(file_path, "wb") as buffer:
  184. shutil.copyfileobj(upload_file.file, buffer)
  185. output_dir = session_dir / "output"
  186. stitched_image_path = None
  187. try:
  188. if method == StitchingMethod.KEY_POINT:
  189. stitched_image_path = stitcher_keypoint.stitch_img(
  190. IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir, NUM_COLS=num_cols, NUM_ROWS=num_rows,
  191. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h, ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  192. BLEND_TYPE=kp_blend_type.value, FeatureDetector=kp_feature_detector.value, BLEND_RATIO=kp_blend_ratio,
  193. LIGHT_COMPENSATION=tm_light_compensation, LIGHT_COMPENSATION_WIDTH=tm_light_compensation_width,
  194. DEBUG_MODE=False
  195. )
  196. elif method == StitchingMethod.TEMPLATE_MATCH:
  197. stitched_image_path = stitcher_template.stitch_img(
  198. IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir, NUM_COLS=num_cols, NUM_ROWS=num_rows,
  199. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h, ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
  200. BLEND_TYPE=tm_blend_type.value, LIGHT_COMPENSATION=tm_light_compensation, BLEND_RATIO=tm_blend_ratio,
  201. LIGHT_COMPENSATION_WIDTH=tm_light_compensation_width, DEBUG_MODE=False
  202. )
  203. except Exception as e:
  204. raise HTTPException(status_code=500, detail=f"图片拼接过程中发生内部错误: {e}")
  205. if not stitched_image_path or not stitched_image_path.exists():
  206. raise HTTPException(status_code=500, detail=f"图片拼接失败,未能生成结果文件。")
  207. final_filename = f"{output_filename_base}.jpg"
  208. # 我们直接从最终的输出路径返回,不需要移动文件
  209. return FileResponse(
  210. path=stitched_image_path,
  211. filename=final_filename,
  212. media_type='image/jpeg'
  213. )