Forráskód Böngészése

文件夹上传拼图

AnlaAnla 9 hónapja
szülő
commit
85a3c021ad
5 módosított fájl, 451 hozzáadás és 1 törlés
  1. 166 0
      Test/API测试_v2.py
  2. 5 1
      Test/test01.py
  3. 72 0
      Test/test02.py
  4. 206 0
      app/api/stitch_v2.py
  5. 2 0
      app/main.py

+ 166 - 0
Test/API测试_v2.py

@@ -0,0 +1,166 @@
+import time
+
+import requests
+import os
+import shutil
+import zipfile
+from PIL import Image, ImageDraw
+
+# --- 配置 ---
+# 请确保你的 FastAPI 服务器正在运行,并修改此处的地址和端口
+BASE_URL = "http://127.0.0.1:7745/api"  # 假设您的API前缀是 /api, 如果不是,请修改
+STITCH_API_PREFIX = "/stitch"
+SINGLE_FOLDER_URL = f"{BASE_URL}{STITCH_API_PREFIX}/from-folder"
+BATCH_FOLDER_URL = f"{BASE_URL}{STITCH_API_PREFIX}/batch/from-folder"
+
+# 用于存放自动生成的测试图片的临时目录
+TEST_DATA_DIR = "temp_api_test_data"
+
+
+# --- 测试函数 1:新的单个拼图接口 (从文件夹上传) ---
+def single_puzzle_from_folder_api(image_folder_path: str):
+    """
+    测试 /stitch/from-folder 接口 (单个拼图)
+    """
+    print(f"--- 1. 开始测试: 单个拼图接口 (从文件夹上传) ---")
+    print(f"使用文件夹: {image_folder_path}")
+
+    # 1. 准备请求数据
+    # 从文件夹路径推断输出文件名
+    output_filename_base = os.path.basename(image_folder_path)
+    form_data = {
+        'output_filename_base': output_filename_base,
+        'method': 'template_match',
+        'num_cols': 4,
+        'num_rows': 6,
+        'overlap_h': 405,
+        'overlap_v': 440,
+        'tm_blend_type': 'half_importance_add_weight',
+        'tm_light_compensation': True,
+    }
+
+    # 2. 准备要上传的文件列表
+    # 'files' 是 FastAPI 接口中定义的参数名
+    # requests 要求格式为: [('field_name', ('filename', file_object, 'content_type')), ...]
+    files_to_send = []
+    file_objects = []
+    try:
+        for filename in os.listdir(image_folder_path):
+            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
+                file_path = os.path.join(image_folder_path, filename)
+                f = open(file_path, 'rb')
+                file_objects.append(f)
+                files_to_send.append(('files', (filename, f, 'image/jpeg')))
+
+        if not files_to_send:
+            print("❌ 错误: 在文件夹中未找到可上传的图片。")
+            return
+
+        # 3. 发送 POST 请求
+        print(f"向服务器 {SINGLE_FOLDER_URL} 发送 {len(files_to_send)} 个文件...")
+        response = requests.post(SINGLE_FOLDER_URL, data=form_data, files=files_to_send, timeout=60)
+
+        # 4. 处理响应
+        print(f"服务器响应状态码: {response.status_code}")
+        if response.status_code == 200:
+            content_type = response.headers.get('content-type')
+            print(f"响应内容类型: {content_type}")
+            if 'image/jpeg' in content_type:
+                output_filename = f"stitched_single_{output_filename_base}.jpg"
+                with open(output_filename, "wb") as f:
+                    f.write(response.content)
+                print(f"✅ 成功! 拼接后的大图已保存为: {output_filename}")
+            else:
+                print(f"❌ 失败! 期望得到 'image/jpeg',但收到了 '{content_type}'")
+        else:
+            print(f"❌ 请求失败! 错误信息: {response.text}")
+
+    except requests.exceptions.RequestException as e:
+        print(f"❌ 请求异常! 无法连接到服务器: {e}")
+    finally:
+        # 确保所有打开的文件都被关闭
+        for f in file_objects:
+            f.close()
+
+    print("-" * 50 + "\n")
+
+
+# --- 测试函数 2:新的批量拼图接口 (从文件夹上传, ZIP返回) ---
+def batch_puzzle_from_folder_api(image_folder_path: str):
+    """
+    测试 /stitch/batch/from-folder 接口 (批量拼图)
+    """
+    print(f"--- 2. 开始测试: 批量拼图接口 (单文件夹上传, ZIP返回) ---")
+    print(f"使用文件夹: {image_folder_path}")
+
+    # 1. 准备请求数据
+    output_filename_base = os.path.basename(image_folder_path)
+    form_data = {
+        'output_filename_base': output_filename_base,
+        'method': 'template_match',
+        'num_cols': 4,
+        'num_rows': 6,
+        'overlap_h': 405,
+        'overlap_v': 440,
+    }
+
+    # 2. 准备文件列表 (与单个接口的逻辑相同)
+    files_to_send = []
+    file_objects = []
+    try:
+        for filename in os.listdir(image_folder_path):
+            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
+                file_path = os.path.join(image_folder_path, filename)
+                f = open(file_path, 'rb')
+                file_objects.append(f)
+                files_to_send.append(('files', (filename, f, 'image/jpeg')))
+
+        if not files_to_send:
+            print("❌ 错误: 在文件夹中未找到可上传的图片。")
+            return
+
+        # 3. 发送 POST 请求
+        print(f"向服务器 {BATCH_FOLDER_URL} 发送 {len(files_to_send)} 个文件...")
+        response = requests.post(BATCH_FOLDER_URL, data=form_data, files=files_to_send, timeout=60)
+
+        # 4. 处理响应
+        print(f"服务器响应状态码: {response.status_code}")
+        if response.status_code == 200:
+            content_type = response.headers.get('content-type')
+            print(f"响应内容类型: {content_type}")
+            if 'application/zip' in content_type:
+                output_filename = f"stitched_batch_{output_filename_base}.zip"
+                with open(output_filename, "wb") as f:
+                    f.write(response.content)
+                print(f"✅ 成功! 包含拼接结果的ZIP包已保存为: {output_filename}")
+                # (可选) 解压并检查结果
+                try:
+                    extract_dir = "batch_results_unzipped"
+                    if os.path.exists(extract_dir):
+                        shutil.rmtree(extract_dir)
+                    os.makedirs(extract_dir, exist_ok=True)
+                    with zipfile.ZipFile(output_filename, 'r') as zf:
+                        zf.extractall(extract_dir)
+                    print(f"  - 结果已自动解压到 '{extract_dir}' 文件夹,包含文件: {os.listdir(extract_dir)}")
+                except Exception as e:
+                    print(f"  - 解压返回的ZIP文件时出错: {e}")
+            else:
+                print(f"❌ 失败! 期望得到 'application/zip',但收到了 '{content_type}'")
+        else:
+            print(f"❌ 请求失败! 错误信息: {response.text}")
+
+    except requests.exceptions.RequestException as e:
+        print(f"❌ 请求异常! 无法连接到服务器: {e}")
+    finally:
+        for f in file_objects:
+            f.close()
+
+    print("-" * 50 + "\n")
+
+
+if __name__ == "__main__":
+    t1 = time.time()
+    single_puzzle_from_folder_api(r"C:\Code\ML\Project\StitchImageServer\temp\Input\_250801_1043_0001")
+    # batch_puzzle_from_folder_api(test_folder_2)
+    t2 = time.time()
+    print("cost: ", t2 - t1)

+ 5 - 1
Test/test01.py

@@ -1,9 +1,13 @@
 import shutil
+import time
 from pathlib import Path
 
 path = Path(__file__).parent.absolute()
 save_path = path.joinpath('123.zip')
 print(path)
 
-shutil.make_archive(str(save_path.with_suffix('')), 'zip', r"C:\Code\ML\Project\StitchImageServer\temp\output")
+t1 = time.time()
+shutil.make_archive(str(save_path.with_suffix('')), 'zip', r"C:\Code\ML\Project\StitchImageServer\temp\Input\_250801_1043_0001")
+t2 = time.time()
+print(t2-t1)
 print('end')

+ 72 - 0
Test/test02.py

@@ -0,0 +1,72 @@
+import os
+import shutil
+from typing import List
+
+import uvicorn
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import HTMLResponse
+
+# 创建一个目标文件夹来存放上传的文件
+UPLOAD_DIRECTORY = "./uploads"
+if not os.path.exists(UPLOAD_DIRECTORY):
+    os.makedirs(UPLOAD_DIRECTORY)
+
+app = FastAPI()
+
+
+@app.post("/upload-folder/")
+async def upload_folder(files: List[UploadFile] = File(...)):
+    """
+    接收通过 webkitdirectory 上传的整个文件夹
+    """
+    saved_files = []
+    for file in files:
+        # file.filename 会包含从选定目录开始的相对路径
+        # 例如: "my_folder/data.csv" 或 "my_folder/images/pic.png"
+
+        # 安全性检查:防止路径遍历攻击 (e.g., "my_folder/../../etc/passwd")
+        if ".." in file.filename:
+            raise HTTPException(status_code=400, detail=f"Invalid filename: {file.filename}. Contains '..'")
+
+        # 在服务器上创建完整的目标路径
+        # os.path.join 会正确处理不同操作系统的路径分隔符
+        destination_path = os.path.join(UPLOAD_DIRECTORY, file.filename)
+
+        # 获取目标文件的目录路径
+        destination_dir = os.path.dirname(destination_path)
+
+        # 如果目录不存在,则创建它
+        if not os.path.exists(destination_dir):
+            os.makedirs(destination_dir)
+
+        try:
+            # 异步地将文件内容写入目标路径
+            with open(destination_path, "wb") as buffer:
+                shutil.copyfileobj(file.file, buffer)
+
+            saved_files.append(file.filename)
+        finally:
+            # 确保关闭文件
+            await file.close()
+
+    return {"message": f"Successfully uploaded {len(saved_files)} files", "filenames": saved_files}
+
+
+# 提供一个简单的 HTML 上传页面用于测试
+@app.get("/")
+async def main():
+    content = """
+    <body>
+    <h2>上传整个文件夹</h2>
+    <p>选择一个文件夹,其中的所有文件(包括子目录中的文件)都将被上传。</p>
+    <form action="/upload-folder/" enctype="multipart/form-data" method="post">
+    <input name="files" type="file" webkitdirectory directory multiple>
+    <input type="submit">
+    </form>
+    </body>
+    """
+    return HTMLResponse(content=content)
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)

+ 206 - 0
app/api/stitch_v2.py

@@ -0,0 +1,206 @@
+import os
+import shutil
+import uuid
+import zipfile
+import logging
+from pathlib import Path
+from typing import List  # 导入 List 类型
+
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
+from fastapi.responses import FileResponse, JSONResponse
+
+# 导入我们的核心逻辑和数据模型
+from app.core import stitcher_keypoint, stitcher_template
+from app.schemas import StitchingMethod, KeypointFeatureDetector, KeypointBlendType, TemplateBlendType
+
+from utils.utils import cleanup_temp_folder
+
+
+router = APIRouter(prefix="/stitch", tags=['拼图'])
+
+TEMP_DIR = Path("_temp_work")
+TEMP_DIR.mkdir(exist_ok=True)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+
+# --- 内部辅助函数,用于处理单个拼图任务,避免代码重复 ---
+def _process_single_puzzle(
+    image_dir: Path,
+    output_dir: Path,
+    method: StitchingMethod,
+    params: dict
+) -> Path | None:
+    """
+    处理单个拼图任务的核心逻辑。
+
+    :param image_dir: 包含所有小图的输入目录。
+    :param output_dir: 存放拼接结果的输出目录。
+    :param method: 拼图方法。
+    :param params: 包含所有拼图参数的字典。
+    :return: 成功则返回拼接后图片的路径,否则返回 None。
+    """
+    output_dir.mkdir(exist_ok=True)
+    stitched_image_path = None
+    try:
+        if method == StitchingMethod.KEY_POINT:
+            stitched_image_path = stitcher_keypoint.stitch_img(
+                IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
+                NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
+                ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
+                ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
+                BLEND_TYPE=params["kp_blend_type"].value,
+                FeatureDetector=params["kp_feature_detector"].value,
+                DEBUG_MODE=False
+            )
+        elif method == StitchingMethod.TEMPLATE_MATCH:
+            stitched_image_path = stitcher_template.stitch_img(
+                IMAGE_DIR=image_dir, OUTPUT_DIR=output_dir,
+                NUM_COLS=params["num_cols"], NUM_ROWS=params["num_rows"],
+                ESTIMATE_OVERLAP_HORIZONTAL_PIXELS=params["overlap_h"],
+                ESTIMATE_OVERLAP_VERTICAL_PIXELS=params["overlap_v"],
+                BLEND_TYPE=params["tm_blend_type"].value,
+                LIGHT_COMPENSATION=params["tm_light_compensation"],
+                DEBUG_MODE=False
+            )
+    except Exception as e:
+        logging.error(f"处理文件夹 {image_dir.name} 时发生错误: {e}")
+        return None
+
+    return stitched_image_path
+
+
+# --- 新的单个拼图接口 (直接上传图片文件夹) ---
+@router.post("/from-folder", response_class=FileResponse, summary="单个拼图接口 (从文件夹上传)")
+async def stitch_single_from_folder(
+        background_tasks: BackgroundTasks,
+        files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
+        output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
+        # --- 通用参数 ---
+        method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
+        num_cols: int = Form(4, description="拼图的列数"),
+        num_rows: int = Form(6, description="拼图的行数"),
+        overlap_h: int = Form(405, description="预估的水平重叠像素"),
+        overlap_v: int = Form(440, description="预估的垂直重叠像素"),
+        # --- 点匹配法 (key_point) 特定参数 ---
+        kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
+        kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
+                                                            description="[点匹配] 特征检测器"),
+        # --- 模板匹配法 (template_match) 特定参数 ---
+        tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
+                                                description="[模板匹配] 融合模式"),
+        tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿")
+):
+    """
+    上传一个文件夹内的所有图片进行拼接,直接返回拼接好的单张大图。
+
+    - **files**: 选择一个文件夹中的所有图片进行上传。
+    - **output_filename_base**: 为你的拼图任务命名,这个名字将作为返回图片的文件名。
+    - **返回**: 拼接成功后,返回拼接好的图片文件。
+    """
+    if not files:
+        raise HTTPException(status_code=400, detail="没有上传任何文件。")
+
+    request_id = str(uuid.uuid4())
+    session_dir = TEMP_DIR / request_id
+    session_dir.mkdir()
+    background_tasks.add_task(cleanup_temp_folder, session_dir)
+
+    # 创建用于存放上传图片的临时目录
+    image_dir = session_dir / "images"
+    image_dir.mkdir()
+
+    # 保存所有上传的文件
+    for upload_file in files:
+        file_path = image_dir / upload_file.filename
+        with open(file_path, "wb") as buffer:
+            shutil.copyfileobj(upload_file.file, buffer)
+
+    output_dir = session_dir / "output"
+
+    # 将所有参数打包到一个字典中,方便传递
+    params = locals()
+
+    # 调用核心处理函数
+    stitched_image_path = _process_single_puzzle(image_dir, output_dir, method, params)
+
+    if not stitched_image_path or not stitched_image_path.exists():
+        raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
+
+    # 使用用户提供的基础名称命名输出图片
+    final_filename = f"{output_filename_base}.jpg"
+    final_filepath = stitched_image_path.rename(stitched_image_path.parent / final_filename)
+
+    return FileResponse(
+        path=final_filepath,
+        filename=final_filename,
+        media_type='image/jpeg'
+    )
+
+
+# --- 新的批量拼图接口 (上传一个拼图文件夹,返回ZIP) ---
+@router.post("/batch/from-folder", response_class=FileResponse, summary="批量拼图接口 (单文件夹上传, ZIP返回)")
+async def stitch_batch_from_folder(
+        background_tasks: BackgroundTasks,
+        files: List[UploadFile] = File(..., description="一个文件夹中的所有待拼接图片。"),
+        output_filename_base: str = Form(..., description="输出图片的基础名称(不含扩展名),例如 'puzzle_A'。"),
+        # --- 参数与单个拼图接口相同 ---
+        method: StitchingMethod = Form(StitchingMethod.TEMPLATE_MATCH, description="选择拼图方法"),
+        num_cols: int = Form(4, description="拼图的列数"),
+        num_rows: int = Form(6, description="拼图的行数"),
+        overlap_h: int = Form(405, description="预估的水平重叠像素"),
+        overlap_v: int = Form(440, description="预估的垂直重叠像素"),
+        kp_blend_type: KeypointBlendType = Form(KeypointBlendType.COMBINE, description="[点匹配] 融合模式"),
+        kp_feature_detector: KeypointFeatureDetector = Form(KeypointFeatureDetector.SIFT,
+                                                            description="[点匹配] 特征检测器"),
+        tm_blend_type: TemplateBlendType = Form(TemplateBlendType.HALF_IMPORTANCE_ADD_WEIGHT,
+                                                description="[模板匹配] 融合模式"),
+        tm_light_compensation: bool = Form(True, description="[模板匹配] 是否启用光照补偿"),
+):
+    """
+    上传一个文件夹内的所有图片进行拼接,将结果打包成一个ZIP压缩文件返回。
+
+    - **files**: 选择一个文件夹中的所有图片进行上传。
+    - **output_filename_base**: 为你的拼图任务命名,这个名字将作为ZIP包内图片的文件名。
+    - **返回**: 一个ZIP压缩包,里面包含拼接好的单张图片。
+    """
+    if not files:
+        raise HTTPException(status_code=400, detail="没有上传任何文件。")
+
+    request_id = str(uuid.uuid4())
+    session_dir = TEMP_DIR / request_id
+    session_dir.mkdir()
+    background_tasks.add_task(cleanup_temp_folder, session_dir)
+
+    image_dir = session_dir / "images"
+    image_dir.mkdir()
+
+    for upload_file in files:
+        file_path = image_dir / upload_file.filename
+        with open(file_path, "wb") as buffer:
+            shutil.copyfileobj(upload_file.file, buffer)
+
+    single_output_dir = session_dir / "single_output"
+    params = locals()
+
+    stitched_image_path = _process_single_puzzle(image_dir, single_output_dir, method, params)
+
+    if not stitched_image_path or not stitched_image_path.exists():
+        raise HTTPException(status_code=500, detail=f"图片拼接失败,请检查服务器日志(请求ID: {request_id})。")
+
+    # --- 将单个结果打包成ZIP ---
+    batch_output_dir = session_dir / "batch_output"
+    batch_output_dir.mkdir()
+
+    # 将成功的结果移到最终的批量输出目录
+    target_path = batch_output_dir / f"{output_filename_base}.jpg"
+    shutil.move(str(stitched_image_path), str(target_path))
+
+    # 将最终结果打包成ZIP
+    output_zip_path = session_dir / "stitched_result.zip"
+    shutil.make_archive(str(output_zip_path.with_suffix('')), 'zip', batch_output_dir)
+
+    return FileResponse(
+        path=output_zip_path,
+        filename=f"{output_filename_base}_stitched.zip",
+        media_type='application/zip'
+    )

+ 2 - 0
app/main.py

@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
 import logging
 
 from app.api.stitch import router as stitch_router
+from app.api.stitch_v2 import router as stitch_router_v2
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
@@ -23,3 +24,4 @@ app.add_middleware(
 
 
 app.include_router(stitch_router, prefix="/api")
+app.include_router(stitch_router_v2, prefix="/api")