|
@@ -25,25 +25,17 @@ def calculate_md5(file_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
def _execute_batch_insert(images, metadata_list, temp_file_paths):
|
|
def _execute_batch_insert(images, metadata_list, temp_file_paths):
|
|
|
- """
|
|
|
|
|
- 执行 YOLO -> ViT -> Milvus 插入 -> 图片转存
|
|
|
|
|
- Args:
|
|
|
|
|
- images: RGB numpy 图片列表
|
|
|
|
|
- metadata_list: 包含 meta 信息, md5, 和 new_filename
|
|
|
|
|
- temp_file_paths: 解压目录下的临时文件路径
|
|
|
|
|
- """
|
|
|
|
|
if not images: return 0
|
|
if not images: return 0
|
|
|
|
|
|
|
|
- # 1. YOLO 批量预测
|
|
|
|
|
|
|
+ # 1. YOLO 预测
|
|
|
try:
|
|
try:
|
|
|
global_models.yolo.predict_batch(images)
|
|
global_models.yolo.predict_batch(images)
|
|
|
- # 获取裁剪后的图片 (如果没有检测到,内部逻辑会返回原图)
|
|
|
|
|
cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
|
|
cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
print(f"Error in YOLO prediction: {e}")
|
|
print(f"Error in YOLO prediction: {e}")
|
|
|
return 0
|
|
return 0
|
|
|
|
|
|
|
|
- # 2. ViT 批量提取特征
|
|
|
|
|
|
|
+ # 2. ViT 提取特征
|
|
|
try:
|
|
try:
|
|
|
vectors = global_models.vit.run(cropped_imgs, normalize=True)
|
|
vectors = global_models.vit.run(cropped_imgs, normalize=True)
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -51,14 +43,15 @@ def _execute_batch_insert(images, metadata_list, temp_file_paths):
|
|
|
return 0
|
|
return 0
|
|
|
|
|
|
|
|
# 3. 准备 Milvus 数据列
|
|
# 3. 准备 Milvus 数据列
|
|
|
|
|
+ # Schema: img_md5, img_path, source_id, lang, card_name, card_num, vector
|
|
|
entities = [
|
|
entities = [
|
|
|
- [], # vector
|
|
|
|
|
- [], # img_md5
|
|
|
|
|
- [], # img_path (存储相对路径或文件名,方便前端展示)
|
|
|
|
|
- [], # source_id
|
|
|
|
|
- [], # lang
|
|
|
|
|
- [], # card_name
|
|
|
|
|
- [], # card_num
|
|
|
|
|
|
|
+ [], # 0: img_md5 (VARCHAR)
|
|
|
|
|
+ [], # 1: img_path (VARCHAR)
|
|
|
|
|
+ [], # 2: source_id (VARCHAR)
|
|
|
|
|
+ [], # 3: lang (VARCHAR)
|
|
|
|
|
+ [], # 4: card_name (VARCHAR)
|
|
|
|
|
+ [], # 5: card_num (INT64)
|
|
|
|
|
+ [], # 6: vector (FLOAT_VECTOR)
|
|
|
]
|
|
]
|
|
|
|
|
|
|
|
for i, meta_data in enumerate(metadata_list):
|
|
for i, meta_data in enumerate(metadata_list):
|
|
@@ -66,32 +59,53 @@ def _execute_batch_insert(images, metadata_list, temp_file_paths):
|
|
|
info = meta_data["meta"]
|
|
info = meta_data["meta"]
|
|
|
new_filename = meta_data["new_filename"]
|
|
new_filename = meta_data["new_filename"]
|
|
|
|
|
|
|
|
- # --- 图片转存逻辑 (使用 MD5 命名) ---
|
|
|
|
|
- # 源文件位置
|
|
|
|
|
|
|
+ # 图片转存
|
|
|
src_path = temp_file_paths[i]
|
|
src_path = temp_file_paths[i]
|
|
|
- # 目标文件位置: static/images/{md5}.ext
|
|
|
|
|
dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
|
|
dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
|
|
|
-
|
|
|
|
|
- # 移动文件 (比 copy 快,因为源文件是临时解压的,后面会删除)
|
|
|
|
|
- # 如果目标文件已存在(理论上MD5去重后不会,但防万一),覆盖它
|
|
|
|
|
shutil.move(src_path, dest_path)
|
|
shutil.move(src_path, dest_path)
|
|
|
-
|
|
|
|
|
- # 存入数据库的是相对路径,或者直接存文件名,看你前端怎么拼接
|
|
|
|
|
- # 这里存相对路径 images/xxxx.png
|
|
|
|
|
db_img_path = f"images/{new_filename}"
|
|
db_img_path = f"images/{new_filename}"
|
|
|
|
|
|
|
|
- entities[0].append(vec)
|
|
|
|
|
- entities[1].append(meta_data["md5"])
|
|
|
|
|
- entities[2].append(db_img_path)
|
|
|
|
|
- entities[3].append(info["source_id"])
|
|
|
|
|
- entities[4].append(info["lang"])
|
|
|
|
|
- entities[5].append(info["card_name"])
|
|
|
|
|
- entities[6].append(info["card_num"])
|
|
|
|
|
|
|
+ # --- 处理 card_num ---
|
|
|
|
|
+ raw_num = info.get("card_num", "")
|
|
|
|
|
+ final_num = -1 # 默认值,代表 None
|
|
|
|
|
+
|
|
|
|
|
+ if raw_num:
|
|
|
|
|
+ # 去掉可能存在的空格
|
|
|
|
|
+ raw_num_str = str(raw_num).strip()
|
|
|
|
|
+ # 如果是 "None" 字符串 或者为空
|
|
|
|
|
+ if raw_num_str.lower() == "none" or raw_num_str == "":
|
|
|
|
|
+ final_num = -1
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 尝试提取数字 (比如 "004" -> 4)
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 过滤掉非数字字符 (防止有 'No.004' 这种写法)
|
|
|
|
|
+ import re
|
|
|
|
|
+ # 只提取数字部分
|
|
|
|
|
+ digits = re.findall(r'\d+', raw_num_str)
|
|
|
|
|
+ if digits:
|
|
|
|
|
+ final_num = int(digits[0])
|
|
|
|
|
+ else:
|
|
|
|
|
+ final_num = -1
|
|
|
|
|
+ except:
|
|
|
|
|
+ final_num = -1
|
|
|
|
|
+ # ----------------------------
|
|
|
|
|
+
|
|
|
|
|
+ entities[0].append(meta_data["md5"]) # img_md5
|
|
|
|
|
+ entities[1].append(db_img_path) # img_path
|
|
|
|
|
+ entities[2].append(info["source_id"]) # source_id
|
|
|
|
|
+ entities[3].append(info["lang"]) # lang
|
|
|
|
|
+ entities[4].append(info["card_name"]) # card_name
|
|
|
|
|
+ entities[5].append(final_num) # card_num (INT)
|
|
|
|
|
+ entities[6].append(vec) # vector (最后)
|
|
|
|
|
|
|
|
# 4. 插入 Milvus
|
|
# 4. 插入 Milvus
|
|
|
if entities[0]:
|
|
if entities[0]:
|
|
|
- milvus_collection.insert(entities)
|
|
|
|
|
- return len(entities[0])
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ milvus_collection.insert(entities)
|
|
|
|
|
+ return len(entities[0])
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"Milvus Insert Error: {e}")
|
|
|
|
|
+ raise e
|
|
|
return 0
|
|
return 0
|
|
|
|
|
|
|
|
|
|
|
|
@@ -202,10 +216,7 @@ def process_zip_file(zip_path: str):
|
|
|
@router.post("/batch")
|
|
@router.post("/batch")
|
|
|
def upload_zip(file: UploadFile = File(...)):
|
|
def upload_zip(file: UploadFile = File(...)):
|
|
|
"""
|
|
"""
|
|
|
- 上传 ZIP 并同步等待处理完成。
|
|
|
|
|
- 注意:这里使用 def 而不是 async def,
|
|
|
|
|
- FastAPI 会自动将其放入线程池运行,不会阻塞主事件循环,
|
|
|
|
|
- 但当前的 HTTP 请求会一直挂起直到处理完成。
|
|
|
|
|
|
|
+ 上传 ZIP 存入向量库
|
|
|
"""
|
|
"""
|
|
|
if not file.filename.endswith(".zip"):
|
|
if not file.filename.endswith(".zip"):
|
|
|
raise HTTPException(status_code=400, detail="Only zip files are allowed.")
|
|
raise HTTPException(status_code=400, detail="Only zip files are allowed.")
|