| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- import os
- import shutil
- import uuid
- import zipfile
- import logging
- from pathlib import Path
- from typing import List # 导入 List 类型
- from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
- from fastapi.responses import FileResponse, JSONResponse
- # 导入我们的核心逻辑和数据模型
- from app.core import stitcher_keypoint, stitcher_template
- from app.schemas import StitchingMethod, KeypointFeatureDetector, KeypointBlendType, TemplateBlendType
- from utils.utils import cleanup_temp_folder
- router = APIRouter(prefix="/stitch", tags=['拼图'])
- TEMP_DIR = Path("_temp_work")
- TEMP_DIR.mkdir(exist_ok=True)
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- # --- 内部辅助函数,用于处理单个拼图任务,避免代码重复 ---
- def _process_single_puzzle(
- image_dir: Path,
- output_dir: Path,
- method: StitchingMethod,
- params: dict
- ) -> Path | None:
- """
- 处理单个拼图任务的核心逻辑。
- :param image_dir: 包含所有小图的输入目录。
- :param output_dir: 存放拼接结果的输出目录。
- :param method: 拼图方法。
- :param params: 包含所有拼图参数的字典。
- :return: 成功则返回拼接后图片的路径,否则返回 None。
- """
- output_dir.mkdir(exist_ok=True)
- stitched_image_path = None
- try:
- if method == StitchingMethod.KEY_POINT:
- stitched_image_path = stitcher_keypoint.stitch_img(
- IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
- NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
- ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
- ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
- BLEND_TYPE=params["kp_blend_type"].value,
- FeatureDetector=params["kp_feature_detector"].value,
- DEBUG_MODE=False
- )
- elif method == StitchingMethod.TEMPLATE_MATCH:
- stitched_image_path = stitcher_template.stitch_img(
- IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
- NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
- ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
- ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
- BLEND_TYPE=params["tm_blend_type"].value,
- LIGHT_COMPENSATION=params["tm_light_compensation"],
- DEBUG_MODE=False
- )
- except Exception as e:
- logging.error(f"处理文件夹 {image_dir.name} 时发生错误: {e}")
- return None
- return stitched_image_path
- # --- 新的单个拼图接口 (直接上传图片文件夹) ---
- @router.post("/from-folder", response_class=FileResponse, summary="单个拼图接口 (从文件夹上传)")
- async def stitch_single_from_folder(
- background_tasks: BackgroundTasks,
- files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
- output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
- # --- 通用参数 ---
- method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
- num_cols: int = Form(4, description="拼图的列数"),
- num_rows: int = Form(6, description="拼图的行数"),
- overlap_h: int = Form(405, description="预估的水平重叠像素"),
- overlap_v: int = Form(440, description="预估的垂直重叠像素"),
- # --- 点匹配法 (key_point) 特定参数 ---
- kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
- kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
- description="[点匹配] 特征检测器"),
- # --- 模板匹配法 (template_match) 特定参数 ---
- tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
- description="[模板匹配] 融合模式"),
- tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿")
- ):
- """
- 上传一个文件夹内的所有图片进行拼接,直接返回拼接好的单张大图。
- - **files**: 选择一个文件夹中的所有图片进行上传。
- - **output_filename_base**: 为你的拼图任务命名,这个名字将作为返回图片的文件名。
- - **返回**: 拼接成功后,返回拼接好的图片文件。
- """
- if not files:
- raise HTTPException(status_code=400, detail="没有上传任何文件。")
- request_id = str(uuid.uuid4())
- session_dir = TEMP_DIR / request_id
- session_dir.mkdir()
- background_tasks.add_task(cleanup_temp_folder, session_dir)
- # 创建用于存放上传图片的临时目录
- image_dir = session_dir / "images"
- image_dir.mkdir()
- # 保存所有上传的文件
- for upload_file in files:
- file_path = image_dir / upload_file.filename
- with open(file_path, "wb") as buffer:
- shutil.copyfileobj(upload_file.file, buffer)
- output_dir = session_dir / "output"
- # 将所有参数打包到一个字典中,方便传递
- params = locals()
- # 调用核心处理函数
- stitched_image_path = _process_single_puzzle(image_dir, output_dir, method, params)
- if not stitched_image_path or not stitched_image_path.exists():
- raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
- # 使用用户提供的基础名称命名输出图片
- final_filename = f"{output_filename_base}.jpg"
- final_filepath = stitched_image_path.rename(stitched_image_path.parent / final_filename)
- return FileResponse(
- path=final_filepath,
- filename=final_filename,
- media_type='image/jpeg'
- )
- # --- 新的批量拼图接口 (上传一个拼图文件夹,返回ZIP) ---
- @router.post("/batch/from-folder", response_class=FileResponse, summary="批量拼图接口 (单文件夹上传, ZIP返回)")
- async def stitch_batch_from_folder(
- background_tasks: BackgroundTasks,
- files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
- output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
- # --- 参数与单个拼图接口相同 ---
- method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
- num_cols: int = Form(4, description="拼图的列数"),
- num_rows: int = Form(6, description="拼图的行数"),
- overlap_h: int = Form(405, description="预估的水平重叠像素"),
- overlap_v: int = Form(440, description="预估的垂直重叠像素"),
- kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
- kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
- description="[点匹配] 特征检测器"),
- tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
- description="[模板匹配] 融合模式"),
- tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿"),
- ):
- """
- 上传一个文件夹内的所有图片进行拼接,将结果打包成一个ZIP压缩文件返回。
- - **files**: 选择一个文件夹中的所有图片进行上传。
- - **output_filename_base**: 为你的拼图任务命名,这个名字将作为ZIP包内图片的文件名。
- - **返回**: 一个ZIP压缩包,里面包含拼接好的单张图片。
- """
- if not files:
- raise HTTPException(status_code=400, detail="没有上传任何文件。")
- request_id = str(uuid.uuid4())
- session_dir = TEMP_DIR / request_id
- session_dir.mkdir()
- background_tasks.add_task(cleanup_temp_folder, session_dir)
- image_dir = session_dir / "images"
- image_dir.mkdir()
- for upload_file in files:
- file_path = image_dir / upload_file.filename
- with open(file_path, "wb") as buffer:
- shutil.copyfileobj(upload_file.file, buffer)
- single_output_dir = session_dir / "single_output"
- params = locals()
- stitched_image_path = _process_single_puzzle(image_dir, single_output_dir, method, params)
- if not stitched_image_path or not stitched_image_path.exists():
- raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
- # --- 将单个结果打包成ZIP ---
- batch_output_dir = session_dir / "batch_output"
- batch_output_dir.mkdir()
- # 将成功的结果移到最终的批量输出目录
- target_path = batch_output_dir / f"{output_filename_base}.jpg"
- shutil.move(str(stitched_image_path), str(target_path))
- # 将最终结果打包成ZIP
- output_zip_path = session_dir / "stitched_result.zip"
- shutil.make_archive(str(output_zip_path.with_suffix('')), 'zip', batch_output_dir)
- return FileResponse(
- path=output_zip_path,
- filename=f"{output_filename_base}_stitched.zip",
- media_type='application/zip'
- )
|