|
|
@@ -1,4 +1,4 @@
|
|
|
-from fastapi import APIRouter, UploadFile, File, BackgroundTasks
|
|
|
+from fastapi import APIRouter, UploadFile, File, HTTPException
|
|
|
import os
|
|
|
import zipfile
|
|
|
import hashlib
|
|
|
@@ -6,8 +6,7 @@ import shutil
|
|
|
import cv2
|
|
|
import glob
|
|
|
from uuid import uuid4
|
|
|
-from PIL import Image
|
|
|
-
|
|
|
+import shutil
|
|
|
from app.core.config import settings
|
|
|
from app.core.models_loader import global_models
|
|
|
from app.db.milvus_client import milvus_collection
|
|
|
@@ -17,6 +16,7 @@ router = APIRouter(prefix="/upload", tags=["Upload"])
|
|
|
|
|
|
|
|
|
def calculate_md5(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    Reads in fixed-size chunks so arbitrarily large image files can be
    hashed without loading them fully into memory.
    """
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        # Walrus loop instead of iter(lambda: f.read(4096), b""): identical
        # behavior, easier to read. 64 KiB chunks reduce syscall overhead.
        while chunk := f.read(65536):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
|
|
|
|
|
|
|
|
|
-def process_batch_import(zip_path: str):
|
|
|
- """后台任务:处理解压和导入"""
|
|
|
def _execute_batch_insert(images, metadata_list, temp_file_paths):
    """Run YOLO -> ViT -> Milvus insert -> move images to permanent storage.

    Args:
        images: list of RGB numpy images to run inference on.
        metadata_list: per-image dicts with keys "meta" (parsed folder info),
            "md5" (file hash) and "new_filename" (md5-based target name).
        temp_file_paths: per-image source paths inside the temp extract dir,
            index-aligned with *images* and *metadata_list*.

    Returns:
        Number of records actually inserted into Milvus (0 on failure).
    """
    if not images:
        return 0

    # 1. Batched YOLO detection. get_max_img_list falls back to the original
    #    image when nothing is detected, so the list stays aligned with inputs.
    try:
        global_models.yolo.predict_batch(images)
        cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return 0

    # 2. Batched ViT feature extraction (normalized embeddings).
    try:
        vectors = global_models.vit.run(cropped_imgs, normalize=True)
    except Exception as e:
        print(f"Error in ViT extraction: {e}")
        return 0

    # 3. Assemble Milvus column-oriented entities.
    entities = [
        [],  # vector
        [],  # img_md5
        [],  # img_path (relative path, convenient for frontend display)
        [],  # source_id
        [],  # lang
        [],  # card_name
        [],  # card_num
    ]

    for i, meta_data in enumerate(metadata_list):
        vec = vectors[i].tolist()
        info = meta_data["meta"]
        new_filename = meta_data["new_filename"]

        # --- Move the image into permanent storage, named by MD5 ---
        src_path = temp_file_paths[i]
        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)

        # shutil.move is platform-dependent when the destination already
        # exists (raises on Windows), so remove it first to get the overwrite
        # semantics the MD5-naming scheme expects. move (not copy) because the
        # source is a temporary extract that gets deleted afterwards anyway.
        if os.path.exists(dest_path):
            os.remove(dest_path)
        shutil.move(src_path, dest_path)

        # Store the relative path images/<md5>.<ext>; the frontend prepends
        # its own static prefix.
        db_img_path = f"images/{new_filename}"

        entities[0].append(vec)
        entities[1].append(meta_data["md5"])
        entities[2].append(db_img_path)
        entities[3].append(info["source_id"])
        entities[4].append(info["lang"])
        entities[5].append(info["card_name"])
        entities[6].append(info["card_num"])

    # NOTE(review): duplicate MD5s within the SAME batch still produce
    # duplicate Milvus rows — the caller's dedup query only sees data that
    # was already inserted/flushed. Confirm whether in-batch dedup is needed.

    # 4. Insert into Milvus (column-based entity format).
    if entities[0]:
        milvus_collection.insert(entities)
        return len(entities[0])
    return 0
|
|
|
+
|
|
|
+
|
|
|
+def process_zip_file(zip_path: str):
|
|
|
+ """
|
|
|
+ 同步处理 Zip 文件逻辑
|
|
|
+ """
|
|
|
extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
|
|
|
os.makedirs(extract_root, exist_ok=True)
|
|
|
|
|
|
+ total_inserted = 0
|
|
|
+ total_skipped = 0
|
|
|
+
|
|
|
try:
|
|
|
# 1. 解压
|
|
|
+ print(f"Unzipping {zip_path}...")
|
|
|
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
|
|
zip_ref.extractall(extract_root)
|
|
|
|
|
|
- # 准备批处理列表
|
|
|
- batch_images = [] # 存放 numpy 图片用于推理
|
|
|
- batch_metadata = [] # 存放元数据
|
|
|
- batch_file_paths = [] # 存放源文件路径 (用于移动)
|
|
|
+ # 准备批处理容器
|
|
|
+ batch_images = []
|
|
|
+ batch_metadata = []
|
|
|
+ batch_temp_paths = []
|
|
|
|
|
|
# 2. 遍历文件夹
|
|
|
- # 假设解压后结构: extract_root/folder_name/image.png
|
|
|
- # 需要递归查找,因为压缩包内可能有一层根目录
|
|
|
+ print("Scanning files...")
|
|
|
for root, dirs, files in os.walk(extract_root):
|
|
|
folder_name = os.path.basename(root)
|
|
|
meta = parse_folder_name(folder_name)
|
|
|
|
|
|
- # 如果当前目录不是目标数据目录,跳过
|
|
|
- if not meta and files:
|
|
|
- # 尝试看上一级目录(有的压缩包解压会多一层)
|
|
|
+ # 只有当文件夹名称符合格式,且里面有文件时才处理
|
|
|
+ if not meta:
|
|
|
continue
|
|
|
|
|
|
- if not meta: continue
|
|
|
-
|
|
|
for file in files:
|
|
|
- if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
|
|
|
- full_path = os.path.join(root, file)
|
|
|
+ if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp')):
|
|
|
+ full_temp_path = os.path.join(root, file)
|
|
|
|
|
|
- # 2.1 MD5 检查
|
|
|
- img_md5 = calculate_md5(full_path)
|
|
|
+ # --- A. MD5 计算与去重 ---
|
|
|
+ img_md5 = calculate_md5(full_temp_path)
|
|
|
|
|
|
# 查询 Milvus 是否已存在该 MD5
|
|
|
res = milvus_collection.query(
|
|
|
@@ -66,111 +139,96 @@ def process_batch_import(zip_path: str):
|
|
|
output_fields=["id"]
|
|
|
)
|
|
|
if len(res) > 0:
|
|
|
- print(f"Skipping duplicate: {file}")
|
|
|
+ total_skipped += 1
|
|
|
+ # 即使跳过,也要把临时文件标记清理(不需要特殊操作,最后删整个文件夹)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # --- B. 构造新文件名 (MD5 + 后缀) ---
|
|
|
+ # 获取原始后缀 (如 .png)
|
|
|
+ file_ext = os.path.splitext(file)[1].lower()
|
|
|
+ # 新文件名: 纯净的MD5 + 后缀
|
|
|
+ new_filename = f"{img_md5}{file_ext}"
|
|
|
+
|
|
|
+ # --- C. 读取图片用于推理 ---
|
|
|
+ img_cv = cv2.imread(full_temp_path)
|
|
|
+ if img_cv is None:
|
|
|
+ print(f"Warning: Failed to read image {full_temp_path}")
|
|
|
continue
|
|
|
|
|
|
- # 读取图片
|
|
|
- # cv2 读取用于 YOLO,注意转换 RGB
|
|
|
- img_cv = cv2.imread(full_path)
|
|
|
- if img_cv is None: continue
|
|
|
img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
+ # --- D. 加入批处理队列 ---
|
|
|
batch_images.append(img_rgb)
|
|
|
+ batch_temp_paths.append(full_temp_path)
|
|
|
batch_metadata.append({
|
|
|
"meta": meta,
|
|
|
"md5": img_md5,
|
|
|
- "filename": f"{img_md5}_{file}" # 重命名防止覆盖
|
|
|
+ "new_filename": new_filename
|
|
|
})
|
|
|
- batch_file_paths.append(full_path)
|
|
|
|
|
|
- # 3. 达到 Batch Size,执行推理和插入
|
|
|
+ # --- E. 达到 Batch Size,执行处理 ---
|
|
|
if len(batch_images) >= settings.BATCH_SIZE:
|
|
|
- _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
|
|
|
- # 清空
|
|
|
+ count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
|
|
|
+ total_inserted += count
|
|
|
+ print(f"Processed batch, inserted: {count}")
|
|
|
+
|
|
|
+ # 清空队列
|
|
|
batch_images = []
|
|
|
batch_metadata = []
|
|
|
- batch_file_paths = []
|
|
|
+ batch_temp_paths = []
|
|
|
|
|
|
- # 处理剩余的
|
|
|
+ # 处理剩余未满一个 batch 的数据
|
|
|
if batch_images:
|
|
|
- _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
|
|
|
+ count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
|
|
|
+ total_inserted += count
|
|
|
|
|
|
- print("Batch import finished.")
|
|
|
+ # 强制刷新 Milvus 确保数据可见
|
|
|
+ milvus_collection.flush()
|
|
|
+ print(f"Import finished. Inserted: {total_inserted}, Skipped: {total_skipped}")
|
|
|
+ return total_inserted, total_skipped
|
|
|
|
|
|
except Exception as e:
|
|
|
- print(f"Error during import: {e}")
|
|
|
+ print(f"Critical Error during import: {e}")
|
|
|
+ raise e
|
|
|
finally:
|
|
|
- # 清理临时文件
|
|
|
+ # 清理解压的临时文件夹
|
|
|
if os.path.exists(extract_root):
|
|
|
shutil.rmtree(extract_root)
|
|
|
+ # 清理上传的 zip 文件
|
|
|
if os.path.exists(zip_path):
|
|
|
os.remove(zip_path)
|
|
|
|
|
|
|
|
|
-def _execute_batch_insert(images, metadata_list, original_paths):
|
|
|
@router.post("/batch")
def upload_zip(file: UploadFile = File(...)):
    """
    Upload a ZIP archive and synchronously process it.

    Declared with plain `def` (not `async def`) so FastAPI runs it in the
    threadpool: the event loop stays responsive, but this HTTP request
    blocks until the whole import finishes.

    Raises:
        HTTPException 400: the upload is not a .zip file.
        HTTPException 500: saving or processing the archive failed.
    """
    # Reject anything that is not a zip. filename can be None for raw bodies,
    # and the extension check is case-insensitive so ".ZIP" is accepted too.
    if not file.filename or not file.filename.lower().endswith(".zip"):
        raise HTTPException(status_code=400, detail="Only zip files are allowed.")

    # basename() strips any path components a malicious client could embed
    # in the filename (e.g. "../../evil.zip"), preventing path traversal
    # outside TEMP_UPLOAD_DIR; the uuid prefix guarantees uniqueness.
    safe_name = os.path.basename(file.filename)
    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{safe_name}")

    try:
        with open(file_location, "wb+") as file_object:
            # Stream the upload to disk without buffering it all in memory.
            shutil.copyfileobj(file.file, file_object)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save zip file: {e}")

    try:
        # Synchronous processing; cleanup of both the zip and the extract
        # directory happens inside process_zip_file's finally block.
        inserted, skipped = process_zip_file(file_location)

        return {
            "message": "Batch import completed successfully.",
            "data": {
                "inserted_count": inserted,
                "skipped_count": skipped,
                "total_processed": inserted + skipped
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing zip: {str(e)}")
|