from fastapi import APIRouter, UploadFile, File, BackgroundTasks
import os
import zipfile
import hashlib
import shutil
import cv2
from uuid import uuid4
from app.core.config import settings
from app.core.models_loader import global_models
from app.db.milvus_client import milvus_collection
from app.utils.parser import parse_folder_name

router = APIRouter(prefix="/upload", tags=["Upload"])


def calculate_md5(file_path):
    """Compute the MD5 of a file, reading it in chunks to bound memory use."""
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def process_batch_import(zip_path: str):
    """Background task: extract the archive and import its images."""
    extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    try:
        # 1. Extract the archive
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_root)

        # Accumulators for batched processing
        batch_images = []      # numpy images, fed to inference
        batch_metadata = []    # parsed metadata per image
        batch_file_paths = []  # source paths (copied into permanent storage later)
        seen_md5 = set()       # MD5s seen in this run; the Milvus query alone
                               # would miss duplicates inside the same archive

        # 2. Walk the extracted tree
        # Expected layout: extract_root/folder_name/image.png
        # Walk recursively, since the archive may add an extra root directory.
        for root, _, files in os.walk(extract_root):
            folder_name = os.path.basename(root)
            meta = parse_folder_name(folder_name)
            # Skip directories whose names don't parse as metadata
            if not meta:
                continue

            for file in files:
                if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
                    full_path = os.path.join(root, file)

                    # 2.1 MD5 de-duplication
                    img_md5 = calculate_md5(full_path)
                    if img_md5 in seen_md5:
                        print(f"Skipping duplicate (same archive): {file}")
                        continue
                    # Ask Milvus whether this MD5 was imported previously
                    res = milvus_collection.query(
                        expr=f'img_md5 == "{img_md5}"',
                        output_fields=["id"]
                    )
                    if len(res) > 0:
                        print(f"Skipping duplicate: {file}")
                        continue
                    seen_md5.add(img_md5)

                    # Read the image with cv2 for YOLO; convert BGR -> RGB
                    img_cv = cv2.imread(full_path)
                    if img_cv is None:
                        continue
                    img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

                    batch_images.append(img_rgb)
                    batch_metadata.append({
                        "meta": meta,
                        "md5": img_md5,
                        "filename": f"{img_md5}_{file}"  # prefix to avoid name collisions
                    })
                    batch_file_paths.append(full_path)

                    # 3. When the batch is full, run inference and insert
                    if len(batch_images) >= settings.BATCH_SIZE:
                        _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
                        batch_images = []
                        batch_metadata = []
                        batch_file_paths = []

        # Flush the final partial batch
        if batch_images:
            _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)

        print("Batch import finished.")

    except Exception as e:
        print(f"Error during import: {e}")
    finally:
        # Clean up temporary files
        if os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        if os.path.exists(zip_path):
            os.remove(zip_path)
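
# For reference, the contract this module assumes for parse_folder_name
# (the real implementation lives in app.utils.parser; the folder-name format
# in this sketch is illustrative, NOT the actual parsing rule — only the
# returned keys are taken from how _execute_batch_insert uses them):
#
#   parse_folder_name("OP01_en_Luffy_OP01-001")
#       -> {"source_id": "OP01", "lang": "en",
#           "card_name": "Luffy", "card_num": "OP01-001"}
#   parse_folder_name("some_random_dir") -> None   # such directories are skipped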

def _execute_batch_insert(images, metadata_list, original_paths):
    """Helper: run inference on one batch and insert the results into Milvus."""
    if not images:
        return

    # A. YOLO batch prediction
    global_models.yolo.predict_batch(images)

    # B. Collect the cropped detections (assumes the card class is cls_id=0)
    cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)

    # C. ViT batch feature extraction (ViT expects a List[np.ndarray] or PIL images).
    # get_max_img_list falls back to the original image when nothing is detected,
    # so every entry should be a valid numpy array; if the detector ever returned
    # None instead, a fallback would be needed here.
    vectors = global_models.vit.run(cropped_imgs, normalize=True)

    # D. Prepare the insert payload (column-based layout)
    entities = [
        [],  # vector
        [],  # img_md5
        [],  # img_path
        [],  # source_id
        [],  # lang
        [],  # card_name
        [],  # card_num
    ]
    os.makedirs(settings.IMAGE_STORAGE_DIR, exist_ok=True)
    for i, meta_data in enumerate(metadata_list):
        vec = vectors[i].tolist()
        info = meta_data["meta"]
        new_filename = meta_data["filename"]

        # Copy the original image into permanent storage (static/images)
        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
        shutil.copy(original_paths[i], dest_path)

        entities[0].append(vec)
        entities[1].append(meta_data["md5"])
        entities[2].append(dest_path)  # stored path; absolute or relative, per config
        entities[3].append(info["source_id"])
        entities[4].append(info["lang"])
        entities[5].append(info["card_name"])
        entities[6].append(info["card_num"])

    # E. Insert into Milvus
    # pymilvus column-based insert format: [[vec1, vec2], [md5_1, md5_2], ...],
    # one inner list per field, in schema order.
    milvus_collection.insert(entities)
    milvus_collection.flush()  # in production, flush periodically instead of per batch
    print(f"Inserted {len(images)} records.")


@router.post("/batch")
async def upload_zip(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    if not file.filename or not file.filename.lower().endswith(".zip"):
        return {"error": "Only zip files are allowed."}

    os.makedirs(settings.TEMP_UPLOAD_DIR, exist_ok=True)
    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
    with open(file_location, "wb") as file_object:
        # Reads the whole upload into memory; fine for modest archives,
        # large ones would be better streamed to disk in chunks.
        file_object.write(await file.read())

    # Run in the background so the API call returns immediately
    background_tasks.add_task(process_batch_import, file_location)

    return {"message": "File uploaded. Processing started in background."}
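
# Example request against this router (a sketch; host, port, and archive name
# are placeholders that depend on how the app is deployed):
#
#   curl -X POST "http://localhost:8000/upload/batch" \
#        -F "file=@cards.zip"
#
# The endpoint responds immediately; import progress is only visible in the
# server logs printed by process_batch_import.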