| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
- import os
- import shutil
- import uuid
- import zipfile
- import logging
- from pathlib import Path
- from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
- from fastapi.responses import FileResponse, JSONResponse
- # 导入我们的核心逻辑和数据模型
- from app.core import stitcher_keypoint, stitcher_template
- from app.schemas import StitchingMethod, KeypointFeatureDetector, KeypointBlendType, TemplateBlendType
- from utils.utils import cleanup_temp_folder
- router = APIRouter(prefix="/stitch", tags=['拼图'])
- TEMP_DIR = Path("_temp_work")
- TEMP_DIR.mkdir(exist_ok=True)
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- @router.post("/", response_class=FileResponse, summary="单个拼图接口")
- async def stitch_single_puzzle(
- background_tasks: BackgroundTasks,
- zip_file: UploadFile = File(..., description="包含一个文件夹的ZIP压缩包,文件夹内有24张小图。"),
- # --- 通用参数 ---
- method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
- num_cols: int = Form(4, description="拼图的列数"),
- num_rows: int = Form(6, description="拼图的行数"),
- overlap_h: int = Form(405, description="预估的水平重叠像素"),
- overlap_v: int = Form(440, description="预估的垂直重叠像素"),
- # --- 点匹配法 (key_point) 特定参数 ---
- kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
- kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
- description="[点匹配] 特征检测器"),
- # --- 模板匹配法 (template_match) 特定参数 ---
- tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
- description="[模板匹配] 融合模式"),
- tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿")
- ):
- """
- 上传一个包含24张小图的文件夹的ZIP压缩包,接口会将其拼接成一张大图并返回。
- - **zip_file**: 必须是.zip格式,内部应仅包含一个文件夹,该文件夹内含所有待拼接的.jpg图片。
- - **返回**: 拼接成功后,返回拼接好的图片文件,文件名与ZIP包内的文件夹名相同。
- """
- request_id = str(uuid.uuid4())
- session_dir = TEMP_DIR / request_id
- session_dir.mkdir()
- background_tasks.add_task(cleanup_temp_folder, session_dir)
- zip_path = session_dir / zip_file.filename
- with open(zip_path, "wb") as buffer:
- shutil.copyfileobj(zip_file.file, buffer)
- extracted_dir = session_dir / "extracted"
- extracted_dir.mkdir()
- try:
- with zipfile.ZipFile(zip_path, 'r') as zf:
- zf.extractall(extracted_dir)
- except zipfile.BadZipFile:
- raise HTTPException(status_code=400, detail="上传的文件不是有效的ZIP格式。")
- image_dir = extracted_dir
- output_dir = session_dir / "output"
- # 根据选择的方法调用不同的拼接函数
- stitched_image_path = None
- if method == StitchingMethod.KEY_POINT:
- stitched_image_path = stitcher_keypoint.stitch_img(
- IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir, NUM_COLS=num_cols, NUM_ROWS=num_rows,
- ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h, ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
- BLEND_TYPE=kp_blend_type.value, FeatureDetector=kp_feature_detector.value,
- DEBUG_MODE=False
- )
- elif method == StitchingMethod.TEMPLATE_MATCH:
- stitched_image_path = stitcher_template.stitch_img(
- IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir, NUM_COLS=num_cols, NUM_ROWS=num_rows,
- ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h, ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
- BLEND_TYPE=tm_blend_type.value, LIGHT_COMPENSATION=tm_light_compensation,
- DEBUG_MODE=False
- )
- if not stitched_image_path or not stitched_image_path.exists():
- raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
- # 使用原始文件夹名命名输出图片
- final_filename = f"{image_dir.name}.jpg"
- final_filepath = stitched_image_path.rename(stitched_image_path.parent / final_filename)
- return FileResponse(
- path=final_filepath,
- filename=final_filename,
- media_type='image/jpeg'
- )
- @router.post("/batch", response_class=FileResponse, summary="批量拼图接口")
- async def stitch_batch_puzzles(
- background_tasks: BackgroundTasks,
- zip_file: UploadFile = File(..., description="包含多个拼图文件夹的ZIP压缩包。"),
- # 参数与单个拼图接口相同
- method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
- num_cols: int = Form(4, description="拼图的列数"),
- num_rows: int = Form(6, description="拼图的行数"),
- overlap_h: int = Form(405, description="预估的水平重叠像素"),
- overlap_v: int = Form(440, description="预估的垂直重叠像素"),
- kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
- kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
- description="[点匹配] 特征检测器"),
- tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
- description="[模板匹配] 融合模式"),
- tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿"),
- ):
- """
- 上传一个包含多个拼图文件夹的ZIP压缩包,接口会处理所有文件夹,并将结果打包成一个新的ZIP返回。
- - **zip_file**: 必须是.zip格式,内部可以有多个文件夹,每个文件夹都包含待拼接的图片。
- - **返回**: 一个ZIP压缩包,里面是所有拼接好的图片,每张图片以其对应的原文件夹名命名。
- """
- request_id = str(uuid.uuid4())
- session_dir = TEMP_DIR / request_id
- session_dir.mkdir()
- background_tasks.add_task(cleanup_temp_folder, session_dir)
- zip_path = session_dir / zip_file.filename
- with open(zip_path, "wb") as buffer:
- shutil.copyfileobj(zip_file.file, buffer)
- extracted_dir = session_dir / "extracted"
- extracted_dir.mkdir()
- try:
- with zipfile.ZipFile(zip_path, 'r') as zf:
- zf.extractall(extracted_dir)
- except zipfile.BadZipFile:
- raise HTTPException(status_code=400, detail="上传的文件不是有效的ZIP格式。")
- puzzle_folders = [d for d in extracted_dir.iterdir() if d.is_dir()]
- if not puzzle_folders:
- raise HTTPException(status_code=400, detail="ZIP包中未找到任何拼图文件夹。")
- batch_output_dir = session_dir / "batch_output"
- batch_output_dir.mkdir()
- processed_count = 0
- failed_folders = []
- for image_dir in puzzle_folders:
- logging.info(f"--- 开始处理批量任务中的文件夹: {image_dir.name} ---")
- # 为每个子任务创建一个独立的输出目录
- single_output_dir = session_dir / "single_output"
- if single_output_dir.exists():
- shutil.rmtree(single_output_dir) # 清理上一次循环的输出
- stitched_image_path = None
- try:
- if method == StitchingMethod.KEY_POINT:
- stitched_image_path = stitcher_keypoint.stitch_img(
- IMAGE_DIR=image_dir, OUTPUT_DIR=single_output_dir,
- NUM_COLS=num_cols, NUM_ROWS=num_rows,
- ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h,
- ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
- BLEND_TYPE=kp_blend_type.value, FeatureDetector=kp_feature_detector.value,
- DEBUG_MODE=False
- )
- elif method == StitchingMethod.TEMPLATE_MATCH:
- stitched_image_path = stitcher_template.stitch_img(
- IMAGE_DIR=image_dir, OUTPUT_DIR=single_output_dir,
- NUM_COLS=num_cols, NUM_ROWS=num_rows,
- ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=overlap_h,
- ESTIMATE_OVERLAP_VERTICAL_PIXELS=overlap_v,
- BLEND_TYPE=tm_blend_type.value, LIGHT_COMPENSATION=tm_light_compensation,
- DEBUG_MODE=False
- )
- if stitched_image_path and stitched_image_path.exists():
- # 将成功的结果移到最终的批量输出目录
- target_path = batch_output_dir / f"{image_dir.name}.jpg"
- shutil.move(str(stitched_image_path), str(target_path))
- processed_count += 1
- else:
- logging.error(f"文件夹 {image_dir.name} 拼接失败。")
- failed_folders.append(image_dir.name)
- except Exception as e:
- logging.error(f"处理文件夹 {image_dir.name} 时发生严重错误: {e}")
- failed_folders.append(image_dir.name)
- if processed_count == 0:
- detail_msg = f"所有 {len(puzzle_folders)} 个文件夹都拼接失败。失败列表: {failed_folders}"
- raise HTTPException(status_code=500, detail=detail_msg)
- # 将最终结果打包成ZIP
- output_zip_path = session_dir / "stitched_results.zip"
- shutil.make_archive(str(output_zip_path.with_suffix('')), 'zip', batch_output_dir)
- if failed_folders:
- logging.warning(f"批量任务完成,但有 {len(failed_folders)} 个文件夹失败: {failed_folders}")
- # 可以在响应头中添加自定义信息来通知客户端部分失败
- # response.headers["X-Failed-Folders"] = ",".join(failed_folders)
- return FileResponse(
- path=output_zip_path,
- filename="stitched_results.zip",
- media_type='application/zip'
- )
|