from fastapi import APIRouter, UploadFile, File, HTTPException
import os
import re
import zipfile
import hashlib
import shutil
import cv2
import glob
from uuid import uuid4

from app.core.config import settings
from app.core.models_loader import global_models
from app.db.milvus_client import milvus_collection
from app.utils.parser import parse_folder_name

router = APIRouter()


def calculate_md5(file_path):
    """Compute the MD5 hash of a file."""
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def fix_text_encoding(text: str) -> str:
    """
    Repair Chinese names garbled during zip extraction.

    Rationale: Python's zipfile decodes non-UTF-8 entry names with cp437,
    which turns Chinese into mojibake. Re-encoding with cp437 recovers the
    original bytes, which can then be decoded with GBK or Big5.
    """
    try:
        # Try GBK first (covers Simplified and most Traditional Chinese; Windows default)
        return text.encode('cp437').decode('gbk')
    except UnicodeError:
        try:
            # Zips created on a Traditional-Chinese system may use Big5
            return text.encode('cp437').decode('big5')
        except UnicodeError:
            try:
                # The name may actually be UTF-8 that was misidentified
                return text.encode('cp437').decode('utf-8')
            except UnicodeError:
                # Unable to repair; return the original string
                return text


def _execute_batch_insert(images, metadata_list, temp_file_paths):
    if not images:
        return 0

    # 1. YOLO prediction
    try:
        global_models.yolo.predict_batch(images)
        cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return 0

    # 2. ViT feature extraction
    try:
        vectors = global_models.vit.run(cropped_imgs, normalize=True)
    except Exception as e:
        print(f"Error in ViT extraction: {e}")
        return 0

    # 3. Prepare the Milvus data columns
    # Schema: img_md5, img_path, source_id, lang, card_name, card_num, vector
    entities = [
        [],  # 0: img_md5 (VARCHAR)
        [],  # 1: img_path (VARCHAR)
        [],  # 2: source_id (VARCHAR)
        [],  # 3: lang (VARCHAR)
        [],  # 4: card_name (VARCHAR)
        [],  # 5: card_num (INT64)
        [],  # 6: vector (FLOAT_VECTOR)
    ]

    for i, meta_data in enumerate(metadata_list):
        vec = vectors[i].tolist()
        info = meta_data["meta"]
        new_filename = meta_data["new_filename"]

        # Move the image into permanent storage
        src_path = temp_file_paths[i]
        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
        shutil.move(src_path, dest_path)
        db_img_path = f"images/{new_filename}"

        # --- Normalize card_num ---
        raw_num = info.get("card_num", "")
        final_num = -1  # default, meaning "no number"

        if raw_num:
            # Strip any surrounding whitespace
            raw_num_str = str(raw_num).strip()
            # Treat the literal string "None" or an empty string as "no number"
            if raw_num_str.lower() == "none" or raw_num_str == "":
                final_num = -1
            else:
                # Extract the digit portion, ignoring other characters
                # (e.g. "004" -> 4, "No.004" -> 4)
                digits = re.findall(r'\d+', raw_num_str)
                final_num = int(digits[0]) if digits else -1
        # ----------------------------

        entities[0].append(meta_data["md5"])   # img_md5
        entities[1].append(db_img_path)        # img_path
        entities[2].append(info["source_id"])  # source_id
        entities[3].append(info["lang"])       # lang
        entities[4].append(info["card_name"])  # card_name
        entities[5].append(final_num)          # card_num (INT)
        entities[6].append(vec)                # vector (last)

    # 4. Insert into Milvus
    if entities[0]:
        try:
            milvus_collection.insert(entities)
            return len(entities[0])
        except Exception as e:
            print(f"Milvus Insert Error: {e}")
            raise e
    return 0

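
# For reference, a pymilvus CollectionSchema that matches the column order used
# by _execute_batch_insert above. This is only a sketch under assumptions: the
# real schema is defined wherever milvus_collection is created
# (app.db.milvus_client), and the vector dimension and VARCHAR lengths below
# are placeholders, not values taken from this project. The auto_id primary key
# "id" is omitted from column-based inserts, which is why only seven columns
# are built above.
#
#   from pymilvus import CollectionSchema, FieldSchema, DataType
#
#   assumed_schema = CollectionSchema([
#       FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
#       FieldSchema("img_md5", DataType.VARCHAR, max_length=64),
#       FieldSchema("img_path", DataType.VARCHAR, max_length=512),
#       FieldSchema("source_id", DataType.VARCHAR, max_length=64),
#       FieldSchema("lang", DataType.VARCHAR, max_length=16),
#       FieldSchema("card_name", DataType.VARCHAR, max_length=256),
#       FieldSchema("card_num", DataType.INT64),
#       FieldSchema("vector", DataType.FLOAT_VECTOR, dim=768),
#   ])
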
def process_zip_file(zip_path: str):
    """
    Synchronous zip-file processing logic.
    """
    extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)

    total_inserted = 0
    total_skipped = 0

    try:
        # 1. Extract the archive
        print(f"Unzipping {zip_path}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_root)

        # Batch processing containers
        batch_images = []
        batch_metadata = []
        batch_temp_paths = []

        # 2. Walk the extracted folders
        print("Scanning files...")
        for root, dirs, files in os.walk(extract_root):
            raw_folder_name = os.path.basename(root)
            # Repair a garbled (mojibake) folder name
            folder_name = fix_text_encoding(raw_folder_name)
            meta = parse_folder_name(folder_name)

            # Only process folders whose names match the expected format
            # (and which contain files)
            if not meta:
                continue

            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp')):
                    full_temp_path = os.path.join(root, file)

                    # --- A. Compute MD5 and deduplicate ---
                    img_md5 = calculate_md5(full_temp_path)

                    # Check whether this MD5 already exists in Milvus
                    res = milvus_collection.query(
                        expr=f'img_md5 == "{img_md5}"',
                        output_fields=["id"]
                    )
                    if len(res) > 0:
                        total_skipped += 1
                        # No special cleanup needed when skipping; the whole
                        # temp folder is removed at the end
                        continue

                    # --- B. Build the new filename (MD5 + extension) ---
                    # Keep the original extension (e.g. .png)
                    file_ext = os.path.splitext(file)[1].lower()
                    # New filename: plain MD5 + extension
                    new_filename = f"{img_md5}{file_ext}"

                    # --- C. Read the image for inference ---
                    img_cv = cv2.imread(full_temp_path)
                    if img_cv is None:
                        print(f"Warning: Failed to read image {full_temp_path}")
                        continue
                    img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

                    # --- D. Add to the batch queue ---
                    batch_images.append(img_rgb)
                    batch_temp_paths.append(full_temp_path)
                    batch_metadata.append({
                        "meta": meta,
                        "md5": img_md5,
                        "new_filename": new_filename
                    })

                    # --- E. Batch size reached: run the batch ---
                    if len(batch_images) >= settings.BATCH_SIZE:
                        count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
                        total_inserted += count
                        print(f"Processed batch, inserted: {count}")
                        # Reset the queues
                        batch_images = []
                        batch_metadata = []
                        batch_temp_paths = []

        # Process the remaining, partially filled batch
        if batch_images:
            count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
            total_inserted += count

        # Flush Milvus so the new data becomes visible
        milvus_collection.flush()

        print(f"Import finished. Inserted: {total_inserted}, Skipped: {total_skipped}")
        return total_inserted, total_skipped

    except Exception as e:
        print(f"Critical Error during import: {e}")
        raise e
    finally:
        # Remove the extracted temp folder
        if os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        # Remove the uploaded zip file
        if os.path.exists(zip_path):
            os.remove(zip_path)


@router.post("/batch", summary="Upload a Pokémon card data archive")
def upload_zip(file: UploadFile = File(...)):
    """
    Upload a ZIP and store its contents in the vector database.

    archive.zip
        Each sub-folder must follow this format: ('129873', {'us'}, 'Swadloon'), 2
            images...
    """
    if not file.filename or not file.filename.endswith(".zip"):
        raise HTTPException(status_code=400, detail="Only zip files are allowed.")

    # Save the uploaded zip
    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
    try:
        with open(file_location, "wb+") as file_object:
            shutil.copyfileobj(file.file, file_object)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save zip file: {e}")

    try:
        # Call the processing function directly (synchronous, blocks until done)
        inserted, skipped = process_zip_file(file_location)
        return {
            "message": "Batch import completed successfully.",
            "data": {
                "inserted_count": inserted,
                "skipped_count": skipped,
                "total_processed": inserted + skipped
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing zip: {str(e)}")
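

# --- Example client call (illustrative sketch) ---
# Assumes the API runs on localhost:8000 and that this router is included
# without a prefix, so the endpoint is reachable at /batch; adjust the URL to
# match how the router is actually mounted. "cards.zip" is a hypothetical
# archive whose sub-folders follow the naming format described in upload_zip.
# Requires the `requests` package; harmless on import, only runs when this
# module is executed directly.
if __name__ == "__main__":
    import requests

    with open("cards.zip", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/batch",
            files={"file": ("cards.zip", f, "application/zip")},
        )
    print(resp.status_code, resp.json())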