
Changes to the upload endpoint

AnlaAnla, 6 days ago
parent commit 2e8f2ad454
2 changed files with 159 additions and 98 deletions
  1. app/main.py (+4, -1)
  2. app/routers/upload.py (+155, -97)

+ 4 - 1
app/main.py

@@ -6,12 +6,15 @@ from contextlib import asynccontextmanager
 from app.core.models_loader import global_models
 from app.routers import search, upload, view
 from app.core.config import settings
-
+import os
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Load models at startup
     global_models.load_models()
+
+    # Create the image storage directory
+    os.makedirs(settings.IMAGE_STORAGE_DIR, exist_ok=True)
     yield
     # Clean up resources on shutdown
     pass
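
Note: the startup hook references settings.IMAGE_STORAGE_DIR, but app/core/config.py is not part of this commit. A minimal sketch of what that config might look like (the field names appear in this diff; the defaults and the pydantic-settings base class are assumptions):

    # app/core/config.py (illustrative sketch only; not part of this commit)
    from pydantic_settings import BaseSettings

    class Settings(BaseSettings):
        TEMP_UPLOAD_DIR: str = "temp_uploads"       # assumed staging dir for uploaded zips
        IMAGE_STORAGE_DIR: str = "static/images"    # created at startup by the lifespan hook above
        BATCH_SIZE: int = 32                        # assumed images per inference/insert batch

    settings = Settings()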

+ 155 - 97
app/routers/upload.py

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, UploadFile, File, BackgroundTasks
+from fastapi import APIRouter, UploadFile, File, HTTPException
 import os
 import zipfile
 import hashlib
@@ -6,8 +6,7 @@ import shutil
 import cv2
 import glob
 from uuid import uuid4
-from PIL import Image
-
+import shutil
 from app.core.config import settings
 from app.core.models_loader import global_models
 from app.db.milvus_client import milvus_collection
@@ -17,6 +16,7 @@ router = APIRouter(prefix="/upload", tags=["Upload"])
 
 
 def calculate_md5(file_path):
+    """Compute the MD5 hash of a file."""
     hash_md5 = hashlib.md5()
     with open(file_path, "rb") as f:
         for chunk in iter(lambda: f.read(4096), b""):
@@ -24,41 +24,114 @@ def calculate_md5(file_path):
     return hash_md5.hexdigest()
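
Aside: the chunked read keeps memory bounded regardless of file size. On Python 3.11+ the standard library can do the same in one call; an equivalent sketch:

    import hashlib

    def calculate_md5(file_path):
        # hashlib.file_digest streams the file internally (Python 3.11+)
        with open(file_path, "rb") as f:
            return hashlib.file_digest(f, "md5").hexdigest()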
 
 
-def process_batch_import(zip_path: str):
-    """Background task: handle extraction and import."""
+def _execute_batch_insert(images, metadata_list, temp_file_paths):
+    """
+    Run YOLO -> ViT -> Milvus insert -> move images to permanent storage.
+    Args:
+        images: list of RGB numpy images
+        metadata_list: dicts containing the parsed meta info, md5, and new_filename
+        temp_file_paths: temporary file paths under the extraction directory
+    """
+    if not images: return 0
+
+    # 1. YOLO batch prediction
+    try:
+        global_models.yolo.predict_batch(images)
+        # Get the cropped images (if nothing is detected, the internal logic returns the original image)
+        cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
+    except Exception as e:
+        print(f"Error in YOLO prediction: {e}")
+        return 0
+
+    # 2. ViT batch feature extraction
+    try:
+        vectors = global_models.vit.run(cropped_imgs, normalize=True)
+    except Exception as e:
+        print(f"Error in ViT extraction: {e}")
+        return 0
+
+    # 3. Prepare the Milvus data columns
+    entities = [
+        [],  # vector
+        [],  # img_md5
+        [],  # img_path (relative path or filename, so the frontend can display it easily)
+        [],  # source_id
+        [],  # lang
+        [],  # card_name
+        [],  # card_num
+    ]
+
+    for i, meta_data in enumerate(metadata_list):
+        vec = vectors[i].tolist()
+        info = meta_data["meta"]
+        new_filename = meta_data["new_filename"]
+
+        # --- Move the image into permanent storage (named by its MD5) ---
+        # Source file location
+        src_path = temp_file_paths[i]
+        # Destination: static/images/{md5}.ext
+        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
+
+        # Move the file (faster than copy, since the source is a temporary extraction deleted later anyway)
+        # If the destination already exists (it should not after MD5 dedup, but just in case), overwrite it
+        shutil.move(src_path, dest_path)
+
+        # The database stores a relative path (or just the filename), depending on how the frontend joins URLs
+        # Here we store the relative path images/xxxx.png
+        db_img_path = f"images/{new_filename}"
+
+        entities[0].append(vec)
+        entities[1].append(meta_data["md5"])
+        entities[2].append(db_img_path)
+        entities[3].append(info["source_id"])
+        entities[4].append(info["lang"])
+        entities[5].append(info["card_name"])
+        entities[6].append(info["card_num"])
+
+    # 4. Insert into Milvus
+    if entities[0]:
+        milvus_collection.insert(entities)
+        return len(entities[0])
+    return 0
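
Note on the columnar format: with pymilvus, each inner list of entities maps positionally to a schema field (an auto-id primary key is omitted), so the order above must match the collection definition exactly. The schema itself is not in this commit; a sketch of what the seven-column order implies, with dim and max_length values as pure assumptions:

    from pymilvus import CollectionSchema, FieldSchema, DataType

    # Illustrative schema matching the seven entity columns above (values assumed)
    fields = [
        FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema("vector", DataType.FLOAT_VECTOR, dim=768),      # ViT embedding size assumed
        FieldSchema("img_md5", DataType.VARCHAR, max_length=32),
        FieldSchema("img_path", DataType.VARCHAR, max_length=512),
        FieldSchema("source_id", DataType.VARCHAR, max_length=64),
        FieldSchema("lang", DataType.VARCHAR, max_length=16),
        FieldSchema("card_name", DataType.VARCHAR, max_length=256),
        FieldSchema("card_num", DataType.VARCHAR, max_length=64),
    ]
    schema = CollectionSchema(fields)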
+
+
+def process_zip_file(zip_path: str):
+    """
+    Synchronously process a zip file.
+    """
     extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
     os.makedirs(extract_root, exist_ok=True)
 
+    total_inserted = 0
+    total_skipped = 0
+
     try:
         # 1. Extract
+        print(f"Unzipping {zip_path}...")
         with zipfile.ZipFile(zip_path, 'r') as zip_ref:
             zip_ref.extractall(extract_root)
 
-        # Prepare the batch lists
-        batch_images = []  # numpy images for inference
-        batch_metadata = []  # metadata
-        batch_file_paths = []  # source file paths (used for the move)
+        # Prepare the batch containers
+        batch_images = []
+        batch_metadata = []
+        batch_temp_paths = []
 
         # 2. Walk the directory tree
-        # Assumed structure after extraction: extract_root/folder_name/image.png
-        # A recursive walk is needed because the archive may contain an extra root directory
+        print("Scanning files...")
         for root, dirs, files in os.walk(extract_root):
             folder_name = os.path.basename(root)
             meta = parse_folder_name(folder_name)
 
-            # Skip if the current directory is not a target data directory
-            if not meta and files:
-                # Try the parent directory (some archives extract with one extra level)
+            # Only process folders whose name matches the expected format
+            if not meta:
                 continue
 
-            if not meta: continue
-
             for file in files:
-                if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
-                    full_path = os.path.join(root, file)
+                if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp')):
+                    full_temp_path = os.path.join(root, file)
 
-                    # 2.1 MD5 check
-                    img_md5 = calculate_md5(full_path)
+                    # --- A. MD5 computation and deduplication ---
+                    img_md5 = calculate_md5(full_temp_path)
 
                     # Query Milvus to check whether this MD5 already exists
                     res = milvus_collection.query(
@@ -66,111 +139,96 @@ def process_batch_import(zip_path: str):
                         output_fields=["id"]
                     )
                     if len(res) > 0:
-                        print(f"Skipping duplicate: {file}")
+                        total_skipped += 1
+                        # Even for skipped files the temp file is cleaned up (no special handling needed; the whole folder is deleted at the end)
+                        continue
+
+                    # --- B. Build the new filename (MD5 + extension) ---
+                    # Get the original extension (e.g. .png)
+                    file_ext = os.path.splitext(file)[1].lower()
+                    # New filename: the bare MD5 plus the extension
+                    new_filename = f"{img_md5}{file_ext}"
+
+                    # --- C. Read the image for inference ---
+                    img_cv = cv2.imread(full_temp_path)
+                    if img_cv is None:
+                        print(f"Warning: Failed to read image {full_temp_path}")
                         continue
 
-                    # Read the image
-                    # cv2 read for YOLO; note the BGR -> RGB conversion
-                    img_cv = cv2.imread(full_path)
-                    if img_cv is None: continue
                     img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
 
+                    # --- D. Add to the batch queue ---
                     batch_images.append(img_rgb)
+                    batch_temp_paths.append(full_temp_path)
                     batch_metadata.append({
                         "meta": meta,
                         "md5": img_md5,
-                        "filename": f"{img_md5}_{file}"  # rename to avoid overwrites
+                        "new_filename": new_filename
                     })
-                    batch_file_paths.append(full_path)
 
-                    # 3. Batch size reached: run inference and insert
+                    # --- E. Batch size reached: process the batch ---
                     if len(batch_images) >= settings.BATCH_SIZE:
-                        _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
-                        # Reset
+                        count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
+                        total_inserted += count
+                        print(f"Processed batch, inserted: {count}")
+
+                        # Clear the queues
                         batch_images = []
                         batch_metadata = []
-                        batch_file_paths = []
+                        batch_temp_paths = []
 
-        # Process the remainder
+        # Process leftover data that did not fill a full batch
         if batch_images:
-            _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
+            count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
+            total_inserted += count
 
-        print("Batch import finished.")
+        # Force a Milvus flush so the data becomes visible
+        milvus_collection.flush()
+        print(f"Import finished. Inserted: {total_inserted}, Skipped: {total_skipped}")
+        return total_inserted, total_skipped
 
     except Exception as e:
-        print(f"Error during import: {e}")
+        print(f"Critical Error during import: {e}")
+        raise e
     finally:
-        # Clean up temporary files
+        # Clean up the extracted temporary files
         if os.path.exists(extract_root):
             shutil.rmtree(extract_root)
+        # Clean up the uploaded zip file
         if os.path.exists(zip_path):
             os.remove(zip_path)
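
Review note on step A: the dedup check issues one Milvus query per image, which can dominate runtime on large archives. A hedged alternative is to batch the existence check, e.g. once per inference batch; find_existing_md5s below is a hypothetical helper, not part of this commit (the in operator is standard Milvus boolean-expression syntax):

    def find_existing_md5s(md5s):
        # One round-trip for a whole batch instead of one query per image
        quoted = ", ".join(f'"{m}"' for m in md5s)
        res = milvus_collection.query(
            expr=f"img_md5 in [{quoted}]",
            output_fields=["img_md5"],
        )
        return {row["img_md5"] for row in res}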
 
 
-def _execute_batch_insert(images, metadata_list, original_paths):
+@router.post("/batch")
+def upload_zip(file: UploadFile = File(...)):
     """
-    Helper: runs the actual inference and database insert
+    Upload a ZIP and wait synchronously for processing to complete.
+    Note: this is a plain def rather than async def, so
+    FastAPI automatically runs it in a thread pool and the main event loop is not blocked,
+    but the current HTTP request stays open until processing finishes.
     """
-    if not images: return
-
-    # A. YOLO batch prediction
-    global_models.yolo.predict_batch(images)
-
-    # B. Get the list of cropped images
-    cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)  # assumes cards are cls_id=0
-
-    # C. ViT batch feature extraction (note: ViT expects List[np.ndarray] or PIL)
-    # Handle the None that get_max_img_list might return (though the original code returns the full image when nothing is detected)
-    valid_crop_imgs = []
-    # If MyBatchOnnxYolo returned None when nothing was detected, a fallback would be needed here,
-    # but according to the code it returns orig_img, so these should be safe numpy arrays
-    vectors = global_models.vit.run(cropped_imgs, normalize=True)
-
-    # D. Prepare the insert data
-    entities = [
-        [],  # vector
-        [],  # img_md5
-        [],  # img_path
-        [],  # source_id
-        [],  # lang
-        [],  # card_name
-        [],  # card_num
-    ]
-
-    for i, meta_data in enumerate(metadata_list):
-        vec = vectors[i].tolist()
-        info = meta_data["meta"]
-        new_filename = meta_data["filename"]
-
-        # Save the original image to static/images
-        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
-        shutil.copy(original_paths[i], dest_path)
-
-        entities[0].append(vec)
-        entities[1].append(meta_data["md5"])
-        entities[2].append(dest_path)  # stores an absolute or relative path
-        entities[3].append(info["source_id"])
-        entities[4].append(info["lang"])
-        entities[5].append(info["card_name"])
-        entities[6].append(info["card_num"])
-
-    # E. Insert into Milvus
-    # Note the pymilvus insert format: [ [vec1, vec2], [md5_1, md5_2], ... ]
-    milvus_collection.insert(entities)
-    milvus_collection.flush()  # in production, flush periodically rather than after every insert
-    print(f"Inserted {len(images)} records.")
-
-
-@router.post("/batch")
-async def upload_zip(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
     if not file.filename.endswith(".zip"):
-        return {"error": "Only zip files are allowed."}
+        raise HTTPException(status_code=400, detail="Only zip files are allowed.")
 
+    # Save the uploaded zip
     file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
-    with open(file_location, "wb+") as file_object:
-        file_object.write(await file.read())
-
-    # Run in the background so the API is not blocked
-    background_tasks.add_task(process_batch_import, file_location)
+    try:
+        with open(file_location, "wb+") as file_object:
+            shutil.copyfileobj(file.file, file_object)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to save zip file: {e}")
 
-    return {"message": "File uploaded. Processing started in background."}
+    try:
+        # Call the processing function directly (synchronous wait)
+        inserted, skipped = process_zip_file(file_location)
+
+        return {
+            "message": "Batch import completed successfully.",
+            "data": {
+                "inserted_count": inserted,
+                "skipped_count": skipped,
+                "total_processed": inserted + skipped
+            }
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing zip: {str(e)}")
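
Since the endpoint now blocks until the import finishes, callers must set a timeout sized to the archive rather than the usual few seconds. A minimal client sketch (host, port, and filename are assumptions):

    import httpx

    with open("cards.zip", "rb") as f:
        resp = httpx.post(
            "http://localhost:8000/upload/batch",
            files={"file": ("cards.zip", f, "application/zip")},
            timeout=600.0,  # the request hangs until the whole import completes
        )
    print(resp.json())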