# upload.py

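"""
Batch image import endpoint.

Accepts a zip archive of card images grouped in folders whose names encode
metadata (source_id, lang, card_name, card_num). Each image is deduplicated
by MD5, cropped with YOLO, embedded with ViT, inserted into Milvus, and
moved into permanent storage under an MD5-based filename.
"""
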
from fastapi import APIRouter, UploadFile, File, HTTPException
import os
import zipfile
import hashlib
import shutil
import cv2
from uuid import uuid4

from app.core.config import settings
from app.core.models_loader import global_models
from app.db.milvus_client import milvus_collection
from app.utils.parser import parse_folder_name

router = APIRouter(prefix="/upload", tags=["Upload"])


def calculate_md5(file_path):
    """Compute the MD5 of a file, reading it in 4 KB chunks."""
    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def _execute_batch_insert(images, metadata_list, temp_file_paths):
    """
    Run one batch through the pipeline: YOLO -> ViT -> Milvus insert -> move images.

    Args:
        images: list of RGB numpy images
        metadata_list: dicts carrying the parsed meta info, md5, and new_filename
        temp_file_paths: paths of the extracted temp files, aligned with `images`
    """
    if not images:
        return 0

    # 1. YOLO batch prediction
    try:
        global_models.yolo.predict_batch(images)
        # Fetch the cropped images (if nothing is detected, the model
        # falls back to returning the original image).
        cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
    except Exception as e:
        print(f"Error in YOLO prediction: {e}")
        return 0

    # 2. ViT batch feature extraction
    try:
        vectors = global_models.vit.run(cropped_imgs, normalize=True)
    except Exception as e:
        print(f"Error in ViT extraction: {e}")
        return 0

    # 3. Prepare the Milvus data columns
    entities = [
        [],  # vector
        [],  # img_md5
        [],  # img_path (relative path or filename, convenient for the frontend)
        [],  # source_id
        [],  # lang
        [],  # card_name
        [],  # card_num
    ]
    for i, meta_data in enumerate(metadata_list):
        vec = vectors[i].tolist()
        info = meta_data["meta"]
        new_filename = meta_data["new_filename"]

        # --- Move the image into permanent storage (MD5-based name) ---
        # Source: the extracted temp file
        src_path = temp_file_paths[i]
        # Destination: static/images/{md5}.ext
        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
        # Move instead of copy: the source is a temp extraction that gets
        # deleted later anyway. If the destination already exists (it should
        # not after MD5 dedup, but just in case), it is overwritten.
        shutil.move(src_path, dest_path)

        # Store a relative path (or just the filename) in the database,
        # depending on how the frontend joins URLs; here we store the
        # relative path images/xxxx.png.
        db_img_path = f"images/{new_filename}"

        entities[0].append(vec)
        entities[1].append(meta_data["md5"])
        entities[2].append(db_img_path)
        entities[3].append(info["source_id"])
        entities[4].append(info["lang"])
        entities[5].append(info["card_name"])
        entities[6].append(info["card_num"])

    # 4. Insert into Milvus
    if entities[0]:
        milvus_collection.insert(entities)
        return len(entities[0])
    return 0
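
# Note: the seven columns built above must match the collection schema's
# field order (minus the auto-id primary key). A hypothetical pymilvus
# schema consistent with this insert -- the dim and max_length values are
# assumptions, not taken from this repo -- would look like:
#
#   from pymilvus import CollectionSchema, FieldSchema, DataType
#   schema = CollectionSchema([
#       FieldSchema("id", DataType.INT64, is_primary=True, auto_id=True),
#       FieldSchema("vector", DataType.FLOAT_VECTOR, dim=768),
#       FieldSchema("img_md5", DataType.VARCHAR, max_length=32),
#       FieldSchema("img_path", DataType.VARCHAR, max_length=256),
#       FieldSchema("source_id", DataType.VARCHAR, max_length=64),
#       FieldSchema("lang", DataType.VARCHAR, max_length=16),
#       FieldSchema("card_name", DataType.VARCHAR, max_length=256),
#       FieldSchema("card_num", DataType.VARCHAR, max_length=64),
#   ])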


def process_zip_file(zip_path: str):
    """
    Synchronous processing of an uploaded zip file.
    """
    extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
    os.makedirs(extract_root, exist_ok=True)
    total_inserted = 0
    total_skipped = 0
    try:
        # 1. Extract the archive
        print(f"Unzipping {zip_path}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_root)

        # Batch containers
        batch_images = []
        batch_metadata = []
        batch_temp_paths = []

        # 2. Walk the extracted folders
        print("Scanning files...")
        for root, dirs, files in os.walk(extract_root):
            folder_name = os.path.basename(root)
            meta = parse_folder_name(folder_name)
            # Only process folders whose name matches the expected format
            if not meta:
                continue
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp')):
                    full_temp_path = os.path.join(root, file)

                    # --- A. MD5 computation and deduplication ---
                    img_md5 = calculate_md5(full_temp_path)
                    # Check whether this MD5 already exists in Milvus
                    # (safe to interpolate: hexdigest() yields only [0-9a-f])
                    res = milvus_collection.query(
                        expr=f'img_md5 == "{img_md5}"',
                        output_fields=["id"]
                    )
                    if len(res) > 0:
                        total_skipped += 1
                        # No special cleanup for skipped files; the whole
                        # extraction folder is removed at the end.
                        continue

                    # --- B. Build the new filename (MD5 + extension) ---
                    # Keep the original extension (e.g. .png)
                    file_ext = os.path.splitext(file)[1].lower()
                    # New filename: plain MD5 plus extension
                    new_filename = f"{img_md5}{file_ext}"

                    # --- C. Read the image for inference ---
                    img_cv = cv2.imread(full_temp_path)
                    if img_cv is None:
                        print(f"Warning: Failed to read image {full_temp_path}")
                        continue
                    img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

                    # --- D. Queue for batch processing ---
                    batch_images.append(img_rgb)
                    batch_temp_paths.append(full_temp_path)
                    batch_metadata.append({
                        "meta": meta,
                        "md5": img_md5,
                        "new_filename": new_filename
                    })

                    # --- E. Batch is full, process it ---
                    if len(batch_images) >= settings.BATCH_SIZE:
                        count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
                        total_inserted += count
                        print(f"Processed batch, inserted: {count}")
                        # Reset the containers
                        batch_images = []
                        batch_metadata = []
                        batch_temp_paths = []

        # Process the remainder that did not fill a full batch
        if batch_images:
            count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
            total_inserted += count

        # Force a Milvus flush so the new data is visible to searches
        milvus_collection.flush()
        print(f"Import finished. Inserted: {total_inserted}, Skipped: {total_skipped}")
        return total_inserted, total_skipped
    except Exception as e:
        print(f"Critical Error during import: {e}")
        raise
    finally:
        # Remove the extraction folder
        if os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        # Remove the uploaded zip file
        if os.path.exists(zip_path):
            os.remove(zip_path)
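
# For reference, the archive layout the walker above expects (illustrative;
# the exact folder-name format is whatever parse_folder_name accepts -- a
# folder is only processed when the parse yields source_id, lang, card_name
# and card_num):
#
#   cards.zip
#   └── <folder name encoding the metadata>/
#       ├── 001.png
#       └── 002.jpg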


@router.post("/batch")
def upload_zip(file: UploadFile = File(...)):
    """
    Upload a ZIP and wait synchronously for processing to finish.

    Note: this is a plain `def`, not `async def`, on purpose. FastAPI runs
    sync endpoints in a thread pool, so the main event loop is not blocked,
    but this HTTP request stays open until processing completes.
    """
    if not file.filename or not file.filename.lower().endswith(".zip"):
        raise HTTPException(status_code=400, detail="Only zip files are allowed.")

    # Save the uploaded zip to the temp directory
    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
    try:
        with open(file_location, "wb") as file_object:
            shutil.copyfileobj(file.file, file_object)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save zip file: {e}")

    try:
        # Call the processing function directly (synchronous wait)
        inserted, skipped = process_zip_file(file_location)
        return {
            "message": "Batch import completed successfully.",
            "data": {
                "inserted_count": inserted,
                "skipped_count": skipped,
                "total_processed": inserted + skipped
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing zip: {str(e)}")
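
# A minimal client sketch (the base URL is an assumption; adjust to your
# deployment, e.g. an app started with `uvicorn app.main:app`):
#
#   import requests
#   with open("cards.zip", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/upload/batch",
#           files={"file": ("cards.zip", f, "application/zip")},
#       )
#   print(resp.json())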