stitch_v2.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import os
  2. import shutil
  3. import uuid
  4. import zipfile
  5. import logging
  6. from pathlib import Path
  7. from typing import List # 导入 List 类型
  8. from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
  9. from fastapi.responses import FileResponse, JSONResponse
  10. # 导入我们的核心逻辑和数据模型
  11. from app.core import stitcher_keypoint, stitcher_template
  12. from app.schemas import StitchingMethod, KeypointFeatureDetector, KeypointBlendType, TemplateBlendType
  13. from utils.utils import cleanup_temp_folder
  14. router = APIRouter(prefix="/stitch", tags=['拼图'])
  15. TEMP_DIR = Path("_temp_work")
  16. TEMP_DIR.mkdir(exist_ok=True)
  17. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  18. # --- 内部辅助函数,用于处理单个拼图任务,避免代码重复 ---
  19. def _process_single_puzzle(
  20. image_dir: Path,
  21. output_dir: Path,
  22. method: StitchingMethod,
  23. params: dict
  24. ) -> Path | None:
  25. """
  26. 处理单个拼图任务的核心逻辑。
  27. :param image_dir: 包含所有小图的输入目录。
  28. :param output_dir: 存放拼接结果的输出目录。
  29. :param method: 拼图方法。
  30. :param params: 包含所有拼图参数的字典。
  31. :return: 成功则返回拼接后图片的路径,否则返回 None。
  32. """
  33. output_dir.mkdir(exist_ok=True)
  34. stitched_image_path = None
  35. try:
  36. if method == StitchingMethod.KEY_POINT:
  37. stitched_image_path = stitcher_keypoint.stitch_img(
  38. IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
  39. NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
  40. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
  41. ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
  42. BLEND_TYPE=params["kp_blend_type"].value,
  43. FeatureDetector=params["kp_feature_detector"].value,
  44. DEBUG_MODE=False
  45. )
  46. elif method == StitchingMethod.TEMPLATE_MATCH:
  47. stitched_image_path = stitcher_template.stitch_img(
  48. IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
  49. NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
  50. ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
  51. ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
  52. BLEND_TYPE=params["tm_blend_type"].value,
  53. LIGHT_COMPENSATION=params["tm_light_compensation"],
  54. DEBUG_MODE=False
  55. )
  56. except Exception as e:
  57. logging.error(f"处理文件夹 {image_dir.name} 时发生错误: {e}")
  58. return None
  59. return stitched_image_path
  60. # --- 新的单个拼图接口 (直接上传图片文件夹) ---
  61. @router.post("/from-folder", response_class=FileResponse, summary="单个拼图接口 (从文件夹上传)")
  62. async def stitch_single_from_folder(
  63. background_tasks: BackgroundTasks,
  64. files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
  65. output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
  66. # --- 通用参数 ---
  67. method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
  68. num_cols: int = Form(4, description="拼图的列数"),
  69. num_rows: int = Form(6, description="拼图的行数"),
  70. overlap_h: int = Form(405, description="预估的水平重叠像素"),
  71. overlap_v: int = Form(440, description="预估的垂直重叠像素"),
  72. # --- 点匹配法 (key_point) 特定参数 ---
  73. kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
  74. kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
  75. description="[点匹配] 特征检测器"),
  76. # --- 模板匹配法 (template_match) 特定参数 ---
  77. tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  78. description="[模板匹配] 融合模式"),
  79. tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿")
  80. ):
  81. """
  82. 上传一个文件夹内的所有图片进行拼接,直接返回拼接好的单张大图。
  83. - **files**: 选择一个文件夹中的所有图片进行上传。
  84. - **output_filename_base**: 为你的拼图任务命名,这个名字将作为返回图片的文件名。
  85. - **返回**: 拼接成功后,返回拼接好的图片文件。
  86. """
  87. if not files:
  88. raise HTTPException(status_code=400, detail="没有上传任何文件。")
  89. request_id = str(uuid.uuid4())
  90. session_dir = TEMP_DIR / request_id
  91. session_dir.mkdir()
  92. background_tasks.add_task(cleanup_temp_folder, session_dir)
  93. # 创建用于存放上传图片的临时目录
  94. image_dir = session_dir / "images"
  95. image_dir.mkdir()
  96. # 保存所有上传的文件
  97. for upload_file in files:
  98. file_path = image_dir / upload_file.filename
  99. with open(file_path, "wb") as buffer:
  100. shutil.copyfileobj(upload_file.file, buffer)
  101. output_dir = session_dir / "output"
  102. # 将所有参数打包到一个字典中,方便传递
  103. params = locals()
  104. # 调用核心处理函数
  105. stitched_image_path = _process_single_puzzle(image_dir, output_dir, method, params)
  106. if not stitched_image_path or not stitched_image_path.exists():
  107. raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
  108. # 使用用户提供的基础名称命名输出图片
  109. final_filename = f"{output_filename_base}.jpg"
  110. final_filepath = stitched_image_path.rename(stitched_image_path.parent / final_filename)
  111. return FileResponse(
  112. path=final_filepath,
  113. filename=final_filename,
  114. media_type='image/jpeg'
  115. )
  116. # --- 新的批量拼图接口 (上传一个拼图文件夹,返回ZIP) ---
  117. @router.post("/batch/from-folder", response_class=FileResponse, summary="批量拼图接口 (单文件夹上传, ZIP返回)")
  118. async def stitch_batch_from_folder(
  119. background_tasks: BackgroundTasks,
  120. files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
  121. output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
  122. # --- 参数与单个拼图接口相同 ---
  123. method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
  124. num_cols: int = Form(4, description="拼图的列数"),
  125. num_rows: int = Form(6, description="拼图的行数"),
  126. overlap_h: int = Form(405, description="预估的水平重叠像素"),
  127. overlap_v: int = Form(440, description="预估的垂直重叠像素"),
  128. kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
  129. kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
  130. description="[点匹配] 特征检测器"),
  131. tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
  132. description="[模板匹配] 融合模式"),
  133. tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿"),
  134. ):
  135. """
  136. 上传一个文件夹内的所有图片进行拼接,将结果打包成一个ZIP压缩文件返回。
  137. - **files**: 选择一个文件夹中的所有图片进行上传。
  138. - **output_filename_base**: 为你的拼图任务命名,这个名字将作为ZIP包内图片的文件名。
  139. - **返回**: 一个ZIP压缩包,里面包含拼接好的单张图片。
  140. """
  141. if not files:
  142. raise HTTPException(status_code=400, detail="没有上传任何文件。")
  143. request_id = str(uuid.uuid4())
  144. session_dir = TEMP_DIR / request_id
  145. session_dir.mkdir()
  146. background_tasks.add_task(cleanup_temp_folder, session_dir)
  147. image_dir = session_dir / "images"
  148. image_dir.mkdir()
  149. for upload_file in files:
  150. file_path = image_dir / upload_file.filename
  151. with open(file_path, "wb") as buffer:
  152. shutil.copyfileobj(upload_file.file, buffer)
  153. single_output_dir = session_dir / "single_output"
  154. params = locals()
  155. stitched_image_path = _process_single_puzzle(image_dir, single_output_dir, method, params)
  156. if not stitched_image_path or not stitched_image_path.exists():
  157. raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
  158. # --- 将单个结果打包成ZIP ---
  159. batch_output_dir = session_dir / "batch_output"
  160. batch_output_dir.mkdir()
  161. # 将成功的结果移到最终的批量输出目录
  162. target_path = batch_output_dir / f"{output_filename_base}.jpg"
  163. shutil.move(str(stitched_image_path), str(target_path))
  164. # 将最终结果打包成ZIP
  165. output_zip_path = session_dir / "stitched_result.zip"
  166. shutil.make_archive(str(output_zip_path.with_suffix('')), 'zip', batch_output_dir)
  167. return FileResponse(
  168. path=output_zip_path,
  169. filename=f"{output_filename_base}_stitched.zip",
  170. media_type='application/zip'
  171. )