| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- from fastapi import APIRouter, UploadFile, File, BackgroundTasks
- import os
- import zipfile
- import hashlib
- import shutil
- import cv2
- import glob
- from uuid import uuid4
- from PIL import Image
- from app.core.config import settings
- from app.core.models_loader import global_models
- from app.db.milvus_client import milvus_collection
- from app.utils.parser import parse_folder_name
# Router for batch-upload endpoints; all routes below are mounted under /upload.
router = APIRouter(prefix="/upload", tags=["Upload"])
def calculate_md5(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    The file is read in 4 KiB chunks so arbitrarily large files can be
    hashed without loading them into memory.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        while True:
            chunk = fh.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def process_batch_import(zip_path: str):
    """Background task: extract the uploaded zip and import its images.

    Walks the extracted tree, parses each directory name for card metadata,
    skips images whose MD5 already exists in Milvus, and feeds the rest to
    _execute_batch_insert in chunks of settings.BATCH_SIZE. The extraction
    directory and the source zip are always removed, even on failure.
    """
    extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    try:
        # 1. Extract the archive.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_root)

        batch_images = []      # numpy RGB images queued for inference
        batch_metadata = []    # parsed metadata, parallel to batch_images
        batch_file_paths = []  # source paths, parallel to batch_images

        # 2. Walk the tree. Expected layout: extract_root/<folder_name>/image.png,
        # possibly with an extra root directory added by the archiver — os.walk
        # recurses, so we just skip any directory whose name does not parse.
        for root, _dirs, files in os.walk(extract_root):
            meta = parse_folder_name(os.path.basename(root))
            if not meta:
                # Not a data directory (e.g. the extraction root or a wrapper dir).
                continue
            for file in files:
                if not file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
                    continue
                full_path = os.path.join(root, file)

                # 2.1 Dedup by MD5 against Milvus before doing any inference work.
                img_md5 = calculate_md5(full_path)
                res = milvus_collection.query(
                    expr=f'img_md5 == "{img_md5}"',
                    output_fields=["id"]
                )
                if res:
                    print(f"Skipping duplicate: {file}")
                    continue

                # Read with cv2 for YOLO; convert BGR -> RGB for the models.
                img_cv = cv2.imread(full_path)
                if img_cv is None:
                    # Unreadable/corrupt image — skip it.
                    continue
                batch_images.append(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
                batch_metadata.append({
                    "meta": meta,
                    "md5": img_md5,
                    "filename": f"{img_md5}_{file}"  # rename to avoid collisions
                })
                batch_file_paths.append(full_path)

                # 3. Flush a full batch through inference + insert.
                if len(batch_images) >= settings.BATCH_SIZE:
                    _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
                    batch_images, batch_metadata, batch_file_paths = [], [], []

        # Flush whatever is left over from the last partial batch.
        if batch_images:
            _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
        print("Batch import finished.")
    except Exception as e:
        print(f"Error during import: {e}")
    finally:
        # Always clean up temporary artifacts.
        if os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        if os.path.exists(zip_path):
            os.remove(zip_path)
def _execute_batch_insert(images, metadata_list, original_paths):
    """Run YOLO + ViT inference on one batch and insert the results into Milvus.

    images         -- list of RGB numpy arrays
    metadata_list  -- dicts with keys "meta" (parsed folder info), "md5", "filename"
    original_paths -- source file paths, parallel to *images*; each file is
                      copied into settings.IMAGE_STORAGE_DIR under its new name
    """
    if not images:
        return

    # A. YOLO batch prediction, then take the largest crop per image
    # (cls_id=0 is assumed to be the card class — TODO confirm against the
    # model config; the helper is expected to fall back to the original image
    # when nothing is detected, so entries should always be valid arrays).
    global_models.yolo.predict_batch(images)
    cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)

    # B. ViT batch feature extraction (accepts a list of numpy arrays / PIL images).
    vectors = global_models.vit.run(cropped_imgs, normalize=True)

    # C. Build the column-oriented entity lists pymilvus expects:
    # [ [vec1, vec2, ...], [md5_1, md5_2, ...], ... ]
    entities = [
        [],  # vector
        [],  # img_md5
        [],  # img_path
        [],  # source_id
        [],  # lang
        [],  # card_name
        [],  # card_num
    ]
    for i, meta_data in enumerate(metadata_list):
        info = meta_data["meta"]
        new_filename = meta_data["filename"]

        # Persist the original image into permanent storage.
        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
        shutil.copy(original_paths[i], dest_path)

        entities[0].append(vectors[i].tolist())
        entities[1].append(meta_data["md5"])
        entities[2].append(dest_path)  # stored path as written (absolute or relative)
        entities[3].append(info["source_id"])
        entities[4].append(info["lang"])
        entities[5].append(info["card_name"])
        entities[6].append(info["card_num"])

    # D. Insert and flush. Flushing every batch is simple but costly; a
    # periodic flush would suffice in production.
    milvus_collection.insert(entities)
    milvus_collection.flush()
    print(f"Inserted {len(images)} records.")
@router.post("/batch")
async def upload_zip(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """Accept a .zip of card images and kick off a background import.

    The archive is streamed to TEMP_UPLOAD_DIR in chunks (never fully loaded
    into memory) and then processed asynchronously by process_batch_import.
    """
    # Case-insensitive extension check: "CARDS.ZIP" is a valid archive too.
    if not file.filename.lower().endswith(".zip"):
        return {"error": "Only zip files are allowed."}
    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
    # Stream to disk instead of `await file.read()`, which would hold the
    # entire archive in memory for large uploads.
    with open(file_location, "wb") as out:
        shutil.copyfileobj(file.file, out)
    # Hand off to a background task so the API call returns immediately.
    background_tasks.add_task(process_batch_import, file_location)
    return {"message": "File uploaded. Processing started in background."}
|