from fastapi import APIRouter, UploadFile, File, HTTPException
import os
import re
import zipfile
import hashlib
import shutil
import cv2
from uuid import uuid4

from app.core.config import settings
from app.core.models_loader import global_models
from app.db.milvus_client import milvus_collection
from app.utils.parser import parse_folder_name

router = APIRouter(prefix="/upload", tags=["Upload"])

def calculate_md5(file_path):
    """Compute the MD5 hash of a file, reading it in chunks."""
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()

def _execute_batch_insert(images, metadata_list, temp_file_paths):
    if not images:
        return 0

    # 1. YOLO prediction
    try:
        global_models.yolo.predict_batch(images)
        cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return 0

    # 2. ViT feature extraction
    try:
        vectors = global_models.vit.run(cropped_imgs, normalize=True)
    except Exception as e:
        print(f"Error in ViT extraction: {e}")
        return 0

    # 3. Build the Milvus data columns
    # Schema: img_md5, img_path, source_id, lang, card_name, card_num, vector
    entities = [
        [],  # 0: img_md5 (VARCHAR)
        [],  # 1: img_path (VARCHAR)
        [],  # 2: source_id (VARCHAR)
        [],  # 3: lang (VARCHAR)
        [],  # 4: card_name (VARCHAR)
        [],  # 5: card_num (INT64)
        [],  # 6: vector (FLOAT_VECTOR)
    ]

    for i, meta_data in enumerate(metadata_list):
        vec = vectors[i].tolist()
        info = meta_data["meta"]
        new_filename = meta_data["new_filename"]

        # Move the image into permanent storage
        src_path = temp_file_paths[i]
        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
        shutil.move(src_path, dest_path)
        db_img_path = f"images/{new_filename}"

        # --- Normalize card_num ---
        raw_num = info.get("card_num", "")
        final_num = -1  # default value, stands for "no number"
        if raw_num:
            raw_num_str = str(raw_num).strip()
            # Treat the literal string "None" or an empty string as missing
            if raw_num_str.lower() == "none" or raw_num_str == "":
                final_num = -1
            else:
                # Keep only the digits (e.g. "004" -> 4, "No.004" -> 4)
                digits = re.findall(r"\d+", raw_num_str)
                final_num = int(digits[0]) if digits else -1
        # ---------------------------

        entities[0].append(meta_data["md5"])   # img_md5
        entities[1].append(db_img_path)        # img_path
        entities[2].append(info["source_id"])  # source_id
        entities[3].append(info["lang"])       # lang
        entities[4].append(info["card_name"])  # card_name
        entities[5].append(final_num)          # card_num (INT)
        entities[6].append(vec)                # vector (last column)

    # 4. Insert into Milvus
    if entities[0]:
        try:
            milvus_collection.insert(entities)
            return len(entities[0])
        except Exception as e:
            print(f"Milvus Insert Error: {e}")
            raise
    return 0
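
# --- Hypothetical schema sketch (assumption, for reference only) ---
# The column order used in _execute_batch_insert assumes the Milvus
# collection was created elsewhere with roughly the fields below. The
# field lengths and dim=768 are guesses, not taken from this repo:
#
#   from pymilvus import Collection, CollectionSchema, FieldSchema, DataType
#   schema = CollectionSchema([
#       FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
#       FieldSchema(name="img_md5", dtype=DataType.VARCHAR, max_length=64),
#       FieldSchema(name="img_path", dtype=DataType.VARCHAR, max_length=512),
#       FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=128),
#       FieldSchema(name="lang", dtype=DataType.VARCHAR, max_length=16),
#       FieldSchema(name="card_name", dtype=DataType.VARCHAR, max_length=256),
#       FieldSchema(name="card_num", dtype=DataType.INT64),
#       FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=768),
#   ])
#   milvus_collection = Collection("cards", schema)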

def process_zip_file(zip_path: str):
    """
    Synchronously process an uploaded zip file.
    """
    extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    total_inserted = 0
    total_skipped = 0

    try:
        # 1. Extract the archive
        print(f"Unzipping {zip_path}...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_root)

        # Batch containers
        batch_images = []
        batch_metadata = []
        batch_temp_paths = []

        # 2. Walk the extracted folders
        print("Scanning files...")
        for root, dirs, files in os.walk(extract_root):
            folder_name = os.path.basename(root)
            meta = parse_folder_name(folder_name)
            # Only process folders whose name matches the expected format
            if not meta:
                continue

            for file in files:
                if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
                    full_temp_path = os.path.join(root, file)

                    # --- A. MD5 hashing and deduplication ---
                    img_md5 = calculate_md5(full_temp_path)
                    # Check whether this MD5 already exists in Milvus
                    res = milvus_collection.query(
                        expr=f'img_md5 == "{img_md5}"',
                        output_fields=["id"]
                    )
                    if len(res) > 0:
                        total_skipped += 1
                        # Nothing else to do here: the whole temp folder is removed at the end
                        continue

                    # --- B. Build the new filename (MD5 + original extension) ---
                    file_ext = os.path.splitext(file)[1].lower()
                    new_filename = f"{img_md5}{file_ext}"

                    # --- C. Read the image for inference ---
                    img_cv = cv2.imread(full_temp_path)
                    if img_cv is None:
                        print(f"Warning: Failed to read image {full_temp_path}")
                        continue
                    img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

                    # --- D. Queue for batch processing ---
                    batch_images.append(img_rgb)
                    batch_temp_paths.append(full_temp_path)
                    batch_metadata.append({
                        "meta": meta,
                        "md5": img_md5,
                        "new_filename": new_filename
                    })

                    # --- E. Batch is full: run inference and insert ---
                    if len(batch_images) >= settings.BATCH_SIZE:
                        count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
                        total_inserted += count
                        print(f"Processed batch, inserted: {count}")
                        # Reset the batch containers
                        batch_images = []
                        batch_metadata = []
                        batch_temp_paths = []

        # Process the final, partially filled batch
        if batch_images:
            count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
            total_inserted += count

        # Flush Milvus so the inserted data is visible to queries
        milvus_collection.flush()

        print(f"Import finished. Inserted: {total_inserted}, Skipped: {total_skipped}")
        return total_inserted, total_skipped

    except Exception as e:
        print(f"Critical Error during import: {e}")
        raise
    finally:
        # Clean up the extraction directory
        if os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        # Clean up the uploaded zip file
        if os.path.exists(zip_path):
            os.remove(zip_path)

@router.post("/batch")
def upload_zip(file: UploadFile = File(...)):
    """
    Upload a ZIP archive and import its images into the vector database.
    """
    if not file.filename.endswith(".zip"):
        raise HTTPException(status_code=400, detail="Only zip files are allowed.")

    # Save the uploaded zip to a temporary location
    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
    try:
        with open(file_location, "wb+") as file_object:
            shutil.copyfileobj(file.file, file_object)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save zip file: {e}")

    try:
        # Process the zip synchronously (the request waits for completion)
        inserted, skipped = process_zip_file(file_location)
        return {
            "message": "Batch import completed successfully.",
            "data": {
                "inserted_count": inserted,
                "skipped_count": skipped,
                "total_processed": inserted + skipped
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing zip: {str(e)}")
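
# --- Hypothetical client usage (sketch, not part of the router) ---
# Assuming the app mounts this router and runs at http://localhost:8000
# and that a local file named cards.zip exists (both are assumptions),
# a client could call the endpoint with the `requests` library like this:
#
#   import requests
#   with open("cards.zip", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/upload/batch",
#           files={"file": ("cards.zip", f, "application/zip")},
#       )
#   print(resp.json())  # -> {"message": ..., "data": {"inserted_count": ...}}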