|
|
@@ -0,0 +1,206 @@
|
|
|
+import os
|
|
|
+import shutil
|
|
|
+import uuid
|
|
|
+import zipfile
|
|
|
+import logging
|
|
|
+from pathlib import Path
|
|
|
+from typing import List # 导入 List 类型
|
|
|
+
|
|
|
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
|
|
|
+from fastapi.responses import FileResponse, JSONResponse
|
|
|
+
|
|
|
+# 导入我们的核心逻辑和数据模型
|
|
|
+from app.core import stitcher_keypoint, stitcher_template
|
|
|
+from app.schemas import StitchingMethod, KeypointFeatureDetector, KeypointBlendType, TemplateBlendType
|
|
|
+
|
|
|
+from utils.utils import cleanup_temp_folder
|
|
|
+
|
|
|
+
|
|
|
+router = APIRouter(prefix="/stitch", tags=['拼图'])
|
|
|
+
|
|
|
+TEMP_DIR = Path("_temp_work")
|
|
|
+TEMP_DIR.mkdir(exist_ok=True)
|
|
|
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
+
|
|
|
+
|
|
|
+# --- 内部辅助函数,用于处理单个拼图任务,避免代码重复 ---
|
|
|
+def _process_single_puzzle(
|
|
|
+ image_dir: Path,
|
|
|
+ output_dir: Path,
|
|
|
+ method: StitchingMethod,
|
|
|
+ params: dict
|
|
|
+) -> Path | None:
|
|
|
+ """
|
|
|
+ 处理单个拼图任务的核心逻辑。
|
|
|
+
|
|
|
+ :param image_dir: 包含所有小图的输入目录。
|
|
|
+ :param output_dir: 存放拼接结果的输出目录。
|
|
|
+ :param method: 拼图方法。
|
|
|
+ :param params: 包含所有拼图参数的字典。
|
|
|
+ :return: 成功则返回拼接后图片的路径,否则返回 None。
|
|
|
+ """
|
|
|
+ output_dir.mkdir(exist_ok=True)
|
|
|
+ stitched_image_path = None
|
|
|
+ try:
|
|
|
+ if method == StitchingMethod.KEY_POINT:
|
|
|
+ stitched_image_path = stitcher_keypoint.stitch_img(
|
|
|
+ IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
|
|
|
+ NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
|
|
|
+ ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
|
|
|
+ ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
|
|
|
+ BLEND_TYPE=params["kp_blend_type"].value,
|
|
|
+ FeatureDetector=params["kp_feature_detector"].value,
|
|
|
+ DEBUG_MODE=False
|
|
|
+ )
|
|
|
+ elif method == StitchingMethod.TEMPLATE_MATCH:
|
|
|
+ stitched_image_path = stitcher_template.stitch_img(
|
|
|
+ IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
|
|
|
+ NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
|
|
|
+ ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
|
|
|
+ ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
|
|
|
+ BLEND_TYPE=params["tm_blend_type"].value,
|
|
|
+ LIGHT_COMPENSATION=params["tm_light_compensation"],
|
|
|
+ DEBUG_MODE=False
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"处理文件夹 {image_dir.name} 时发生错误: {e}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ return stitched_image_path
|
|
|
+
|
|
|
+
|
|
|
+# --- 新的单个拼图接口 (直接上传图片文件夹) ---
|
|
|
+@router.post("/from-folder", response_class=FileResponse, summary="单个拼图接口 (从文件夹上传)")
|
|
|
+async def stitch_single_from_folder(
|
|
|
+ background_tasks: BackgroundTasks,
|
|
|
+ files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
|
|
|
+ output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
|
|
|
+ # --- 通用参数 ---
|
|
|
+ method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
|
|
|
+ num_cols: int = Form(4, description="拼图的列数"),
|
|
|
+ num_rows: int = Form(6, description="拼图的行数"),
|
|
|
+ overlap_h: int = Form(405, description="预估的水平重叠像素"),
|
|
|
+ overlap_v: int = Form(440, description="预估的垂直重叠像素"),
|
|
|
+ # --- 点匹配法 (key_point) 特定参数 ---
|
|
|
+ kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
|
|
|
+ kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
|
|
|
+ description="[点匹配] 特征检测器"),
|
|
|
+ # --- 模板匹配法 (template_match) 特定参数 ---
|
|
|
+ tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
|
|
|
+ description="[模板匹配] 融合模式"),
|
|
|
+ tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿")
|
|
|
+):
|
|
|
+ """
|
|
|
+ 上传一个文件夹内的所有图片进行拼接,直接返回拼接好的单张大图。
|
|
|
+
|
|
|
+ - **files**: 选择一个文件夹中的所有图片进行上传。
|
|
|
+ - **output_filename_base**: 为你的拼图任务命名,这个名字将作为返回图片的文件名。
|
|
|
+ - **返回**: 拼接成功后,返回拼接好的图片文件。
|
|
|
+ """
|
|
|
+ if not files:
|
|
|
+ raise HTTPException(status_code=400, detail="没有上传任何文件。")
|
|
|
+
|
|
|
+ request_id = str(uuid.uuid4())
|
|
|
+ session_dir = TEMP_DIR / request_id
|
|
|
+ session_dir.mkdir()
|
|
|
+ background_tasks.add_task(cleanup_temp_folder, session_dir)
|
|
|
+
|
|
|
+ # 创建用于存放上传图片的临时目录
|
|
|
+ image_dir = session_dir / "images"
|
|
|
+ image_dir.mkdir()
|
|
|
+
|
|
|
+ # 保存所有上传的文件
|
|
|
+ for upload_file in files:
|
|
|
+ file_path = image_dir / upload_file.filename
|
|
|
+ with open(file_path, "wb") as buffer:
|
|
|
+ shutil.copyfileobj(upload_file.file, buffer)
|
|
|
+
|
|
|
+ output_dir = session_dir / "output"
|
|
|
+
|
|
|
+ # 将所有参数打包到一个字典中,方便传递
|
|
|
+ params = locals()
|
|
|
+
|
|
|
+ # 调用核心处理函数
|
|
|
+ stitched_image_path = _process_single_puzzle(image_dir, output_dir, method, params)
|
|
|
+
|
|
|
+ if not stitched_image_path or not stitched_image_path.exists():
|
|
|
+ raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
|
|
|
+
|
|
|
+ # 使用用户提供的基础名称命名输出图片
|
|
|
+ final_filename = f"{output_filename_base}.jpg"
|
|
|
+ final_filepath = stitched_image_path.rename(stitched_image_path.parent / final_filename)
|
|
|
+
|
|
|
+ return FileResponse(
|
|
|
+ path=final_filepath,
|
|
|
+ filename=final_filename,
|
|
|
+ media_type='image/jpeg'
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+# --- 新的批量拼图接口 (上传一个拼图文件夹,返回ZIP) ---
|
|
|
+@router.post("/batch/from-folder", response_class=FileResponse, summary="批量拼图接口 (单文件夹上传, ZIP返回)")
|
|
|
+async def stitch_batch_from_folder(
|
|
|
+ background_tasks: BackgroundTasks,
|
|
|
+ files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
|
|
|
+ output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
|
|
|
+ # --- 参数与单个拼图接口相同 ---
|
|
|
+ method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
|
|
|
+ num_cols: int = Form(4, description="拼图的列数"),
|
|
|
+ num_rows: int = Form(6, description="拼图的行数"),
|
|
|
+ overlap_h: int = Form(405, description="预估的水平重叠像素"),
|
|
|
+ overlap_v: int = Form(440, description="预估的垂直重叠像素"),
|
|
|
+ kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
|
|
|
+ kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
|
|
|
+ description="[点匹配] 特征检测器"),
|
|
|
+ tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
|
|
|
+ description="[模板匹配] 融合模式"),
|
|
|
+ tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿"),
|
|
|
+):
|
|
|
+ """
|
|
|
+ 上传一个文件夹内的所有图片进行拼接,将结果打包成一个ZIP压缩文件返回。
|
|
|
+
|
|
|
+ - **files**: 选择一个文件夹中的所有图片进行上传。
|
|
|
+ - **output_filename_base**: 为你的拼图任务命名,这个名字将作为ZIP包内图片的文件名。
|
|
|
+ - **返回**: 一个ZIP压缩包,里面包含拼接好的单张图片。
|
|
|
+ """
|
|
|
+ if not files:
|
|
|
+ raise HTTPException(status_code=400, detail="没有上传任何文件。")
|
|
|
+
|
|
|
+ request_id = str(uuid.uuid4())
|
|
|
+ session_dir = TEMP_DIR / request_id
|
|
|
+ session_dir.mkdir()
|
|
|
+ background_tasks.add_task(cleanup_temp_folder, session_dir)
|
|
|
+
|
|
|
+ image_dir = session_dir / "images"
|
|
|
+ image_dir.mkdir()
|
|
|
+
|
|
|
+ for upload_file in files:
|
|
|
+ file_path = image_dir / upload_file.filename
|
|
|
+ with open(file_path, "wb") as buffer:
|
|
|
+ shutil.copyfileobj(upload_file.file, buffer)
|
|
|
+
|
|
|
+ single_output_dir = session_dir / "single_output"
|
|
|
+ params = locals()
|
|
|
+
|
|
|
+ stitched_image_path = _process_single_puzzle(image_dir, single_output_dir, method, params)
|
|
|
+
|
|
|
+ if not stitched_image_path or not stitched_image_path.exists():
|
|
|
+ raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
|
|
|
+
|
|
|
+ # --- 将单个结果打包成ZIP ---
|
|
|
+ batch_output_dir = session_dir / "batch_output"
|
|
|
+ batch_output_dir.mkdir()
|
|
|
+
|
|
|
+ # 将成功的结果移到最终的批量输出目录
|
|
|
+ target_path = batch_output_dir / f"{output_filename_base}.jpg"
|
|
|
+ shutil.move(str(stitched_image_path), str(target_path))
|
|
|
+
|
|
|
+ # 将最终结果打包成ZIP
|
|
|
+ output_zip_path = session_dir / "stitched_result.zip"
|
|
|
+ shutil.make_archive(str(output_zip_path.with_suffix('')), 'zip', batch_output_dir)
|
|
|
+
|
|
|
+ return FileResponse(
|
|
|
+ path=output_zip_path,
|
|
|
+ filename=f"{output_filename_base}_stitched.zip",
|
|
|
+ media_type='application/zip'
|
|
|
+ )
|