upload.py

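"""Batch image import endpoint.

Accepts an uploaded zip archive of card images organized in folders,
deduplicates each image by MD5 against the Milvus collection, crops the main
card region with the shared YOLO model, extracts ViT feature vectors, and
inserts one row per image into Milvus. parse_folder_name() is expected to
return a dict containing "source_id", "lang", "card_name" and, optionally,
"card_num".
"""
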
from fastapi import APIRouter, UploadFile, File, HTTPException
import os
import re
import zipfile
import hashlib
import shutil
import cv2
from uuid import uuid4

from app.core.config import settings
from app.core.models_loader import global_models
from app.db.milvus_client import milvus_collection
from app.utils.parser import parse_folder_name

router = APIRouter(prefix="/upload", tags=["Upload"])


def calculate_md5(file_path):
    """Compute the MD5 hash of a file."""
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def _execute_batch_insert(images, metadata_list, temp_file_paths):
    """Run YOLO + ViT on a batch of images, move the files into permanent
    storage, and insert one row per image into Milvus. Returns the number
    of rows inserted."""
    if not images:
        return 0

    # 1. YOLO prediction
    try:
        global_models.yolo.predict_batch(images)
        cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return 0

    # 2. ViT feature extraction
    try:
        vectors = global_models.vit.run(cropped_imgs, normalize=True)
    except Exception as e:
        print(f"Error in ViT extraction: {e}")
        return 0

    # 3. Prepare the Milvus data columns
    # Schema: img_md5, img_path, source_id, lang, card_name, card_num, vector
    entities = [
        [],  # 0: img_md5 (VARCHAR)
        [],  # 1: img_path (VARCHAR)
        [],  # 2: source_id (VARCHAR)
        [],  # 3: lang (VARCHAR)
        [],  # 4: card_name (VARCHAR)
        [],  # 5: card_num (INT64)
        [],  # 6: vector (FLOAT_VECTOR)
    ]

    for i, meta_data in enumerate(metadata_list):
        vec = vectors[i].tolist()
        info = meta_data["meta"]
        new_filename = meta_data["new_filename"]

        # Move the image into permanent storage
        src_path = temp_file_paths[i]
        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
        shutil.move(src_path, dest_path)
        db_img_path = f"images/{new_filename}"

        # --- Normalize card_num ---
        raw_num = info.get("card_num", "")
        final_num = -1  # default sentinel meaning "no card number"
        if raw_num:
            # Strip any surrounding whitespace
            raw_num_str = str(raw_num).strip()
            # The literal string "None" or an empty value means no number
            if raw_num_str.lower() == "none" or raw_num_str == "":
                final_num = -1
            else:
                # Keep only the digits, so values like "No.004" become 4
                digits = re.findall(r"\d+", raw_num_str)
                if digits:
                    final_num = int(digits[0])
                else:
                    final_num = -1
        # ----------------------------

        entities[0].append(meta_data["md5"])    # img_md5
        entities[1].append(db_img_path)         # img_path
        entities[2].append(info["source_id"])   # source_id
        entities[3].append(info["lang"])        # lang
        entities[4].append(info["card_name"])   # card_name
        entities[5].append(final_num)           # card_num (INT)
        entities[6].append(vec)                 # vector (last column)

    # 4. Insert into Milvus
    if entities[0]:
        try:
            milvus_collection.insert(entities)
            return len(entities[0])
        except Exception as e:
            print(f"Milvus Insert Error: {e}")
            raise e
    return 0


def process_zip_file(zip_path: str):
    """
    Synchronously process an uploaded zip file.
    """
    extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    total_inserted = 0
    total_skipped = 0
    try:
        # 1. Unzip
        print(f"Unzipping {zip_path}...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_root)

        # Containers for the current batch
        batch_images = []
        batch_metadata = []
        batch_temp_paths = []

        # 2. Walk the extracted folders
        print("Scanning files...")
        for root, dirs, files in os.walk(extract_root):
            folder_name = os.path.basename(root)
            meta = parse_folder_name(folder_name)
            # Only process folders whose name matches the expected format
            if not meta:
                continue
            for file in files:
                if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
                    full_temp_path = os.path.join(root, file)

                    # --- A. MD5 and deduplication ---
                    img_md5 = calculate_md5(full_temp_path)
                    # Check whether this MD5 already exists in Milvus
                    res = milvus_collection.query(
                        expr=f'img_md5 == "{img_md5}"',
                        output_fields=["id"],
                    )
                    if len(res) > 0:
                        total_skipped += 1
                        # No extra cleanup needed for skipped files; the whole
                        # temp folder is removed at the end
                        continue

                    # --- B. Build the new filename (MD5 + original extension) ---
                    file_ext = os.path.splitext(file)[1].lower()
                    new_filename = f"{img_md5}{file_ext}"

                    # --- C. Read the image for inference ---
                    img_cv = cv2.imread(full_temp_path)
                    if img_cv is None:
                        print(f"Warning: Failed to read image {full_temp_path}")
                        continue
                    img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

                    # --- D. Queue for batch processing ---
                    batch_images.append(img_rgb)
                    batch_temp_paths.append(full_temp_path)
                    batch_metadata.append({
                        "meta": meta,
                        "md5": img_md5,
                        "new_filename": new_filename,
                    })

                    # --- E. Flush the batch once it reaches BATCH_SIZE ---
                    if len(batch_images) >= settings.BATCH_SIZE:
                        count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
                        total_inserted += count
                        print(f"Processed batch, inserted: {count}")
                        # Reset the batch containers
                        batch_images = []
                        batch_metadata = []
                        batch_temp_paths = []

        # Process any leftover items that did not fill a full batch
        if batch_images:
            count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
            total_inserted += count

        # Force a flush so the inserted data is visible to queries
        milvus_collection.flush()
        print(f"Import finished. Inserted: {total_inserted}, Skipped: {total_skipped}")
        return total_inserted, total_skipped
    except Exception as e:
        print(f"Critical Error during import: {e}")
        raise e
    finally:
        # Clean up the extracted temp folder
        if os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        # Clean up the uploaded zip file
        if os.path.exists(zip_path):
            os.remove(zip_path)


@router.post("/batch")
def upload_zip(file: UploadFile = File(...)):
    """
    Upload a ZIP archive and import its images into the vector store.
    """
    if not file.filename.endswith(".zip"):
        raise HTTPException(status_code=400, detail="Only zip files are allowed.")

    # Save the uploaded zip to the temp directory
    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
    try:
        with open(file_location, "wb+") as file_object:
            shutil.copyfileobj(file.file, file_object)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save zip file: {e}")

    try:
        # Call the processing function directly (synchronous wait)
        inserted, skipped = process_zip_file(file_location)
        return {
            "message": "Batch import completed successfully.",
            "data": {
                "inserted_count": inserted,
                "skipped_count": skipped,
                "total_processed": inserted + skipped
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing zip: {str(e)}")
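

# Example request (assumes the app is served locally on port 8000 with this
# router mounted at the application root; "cards.zip" is a placeholder archive):
#   curl -X POST "http://localhost:8000/upload/batch" -F "file=@cards.zip"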