upload.py
  1. from fastapi import APIRouter, UploadFile, File, HTTPException
  2. import os
  3. import zipfile
  4. import hashlib
  5. import shutil
  6. import cv2
  7. import glob
  8. from uuid import uuid4
  9. import shutil
  10. from app.core.config import settings
  11. from app.core.models_loader import global_models
  12. from app.db.milvus_client import milvus_collection
  13. from app.utils.parser import parse_folder_name
  14. router = APIRouter()
  15. def calculate_md5(file_path):
  16. """计算文件的 MD5"""
  17. hash_md5 = hashlib.md5()
  18. with open(file_path, "rb") as f:
  19. for chunk in iter(lambda: f.read(4096), b""):
  20. hash_md5.update(chunk)
  21. return hash_md5.hexdigest()
  22. def fix_text_encoding(text: str) -> str:
  23. """
  24. 修复 Zip 解压后的中文乱码
  25. 原理:Python zipfile 默认用 cp437 解码,导致中文变乱码。
  26. 我们需要反向 encode('cp437') 拿到原始字节,再用 gbk 或 big5 decode 回来。
  27. """
  28. try:
  29. # 尝试 GBK (兼容简体和大部分繁体,Windows 默认)
  30. return text.encode('cp437').decode('gbk')
  31. except:
  32. try:
  33. # 如果是纯繁体系统生成的 zip,可能是 Big5
  34. return text.encode('cp437').decode('big5')
  35. except:
  36. try:
  37. # 还有一种情况是本身是 utf-8 但被错误识别
  38. return text.encode('cp437').decode('utf-8')
  39. except:
  40. # 实在解不了,返回原字符串
  41. return text
  42. def _execute_batch_insert(images, metadata_list, temp_file_paths):
  43. if not images: return 0
  44. # 1. YOLO 预测
  45. try:
  46. global_models.yolo.predict_batch(images)
  47. cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
  48. except Exception as e:
  49. print(f"Error in YOLO prediction: {e}")
  50. return 0
  51. # 2. ViT 提取特征
  52. try:
  53. vectors = global_models.vit.run(cropped_imgs, normalize=True)
  54. except Exception as e:
  55. print(f"Error in ViT extraction: {e}")
  56. return 0
  57. # 3. 准备 Milvus 数据列
  58. # Schema: img_md5, img_path, source_id, lang, card_name, card_num, vector
  59. entities = [
  60. [], # 0: img_md5 (VARCHAR)
  61. [], # 1: img_path (VARCHAR)
  62. [], # 2: source_id (VARCHAR)
  63. [], # 3: lang (VARCHAR)
  64. [], # 4: card_name (VARCHAR)
  65. [], # 5: card_num (INT64)
  66. [], # 6: vector (FLOAT_VECTOR)
  67. ]
  68. for i, meta_data in enumerate(metadata_list):
  69. vec = vectors[i].tolist()
  70. info = meta_data["meta"]
  71. new_filename = meta_data["new_filename"]
  72. # 图片转存
  73. src_path = temp_file_paths[i]
  74. dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
  75. shutil.move(src_path, dest_path)
  76. db_img_path = f"images/{new_filename}"
  77. # --- 处理 card_num ---
  78. raw_num = info.get("card_num", "")
  79. final_num = -1 # 默认值,代表 None
  80. if raw_num:
  81. # 去掉可能存在的空格
  82. raw_num_str = str(raw_num).strip()
  83. # 如果是 "None" 字符串 或者为空
  84. if raw_num_str.lower() == "none" or raw_num_str == "":
  85. final_num = -1
  86. else:
  87. # 尝试提取数字 (比如 "004" -> 4)
  88. try:
  89. # 过滤掉非数字字符 (防止有 'No.004' 这种写法)
  90. import re
  91. # 只提取数字部分
  92. digits = re.findall(r'\d+', raw_num_str)
  93. if digits:
  94. final_num = int(digits[0])
  95. else:
  96. final_num = -1
  97. except:
  98. final_num = -1
  99. # ----------------------------
  100. entities[0].append(meta_data["md5"]) # img_md5
  101. entities[1].append(db_img_path) # img_path
  102. entities[2].append(info["source_id"]) # source_id
  103. entities[3].append(info["lang"]) # lang
  104. entities[4].append(info["card_name"]) # card_name
  105. entities[5].append(final_num) # card_num (INT)
  106. entities[6].append(vec) # vector (最后)
  107. # 4. 插入 Milvus
  108. if entities[0]:
  109. try:
  110. milvus_collection.insert(entities)
  111. return len(entities[0])
  112. except Exception as e:
  113. print(f"Milvus Insert Error: {e}")
  114. raise e
  115. return 0
  116. def process_zip_file(zip_path: str):
  117. """
  118. 同步处理 Zip 文件逻辑
  119. """
  120. extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
  121. os.makedirs(extract_root, exist_ok=True)
  122. total_inserted = 0
  123. total_skipped = 0
  124. try:
  125. # 1. 解压
  126. print(f"Unzipping {zip_path}...")
  127. with zipfile.ZipFile(zip_path, 'r') as zip_ref:
  128. zip_ref.extractall(extract_root)
  129. # 准备批处理容器
  130. batch_images = []
  131. batch_metadata = []
  132. batch_temp_paths = []
  133. # 2. 遍历文件夹
  134. print("Scanning files...")
  135. for root, dirs, files in os.walk(extract_root):
  136. raw_folder_name = os.path.basename(root)
  137. # 修复乱码文件夹名
  138. folder_name = fix_text_encoding(raw_folder_name)
  139. meta = parse_folder_name(folder_name)
  140. # 只有当文件夹名称符合格式,且里面有文件时才处理
  141. if not meta:
  142. continue
  143. for file in files:
  144. if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp', '.bmp')):
  145. full_temp_path = os.path.join(root, file)
  146. # --- A. MD5 计算与去重 ---
  147. img_md5 = calculate_md5(full_temp_path)
  148. # 查询 Milvus 是否已存在该 MD5
  149. res = milvus_collection.query(
  150. expr=f'img_md5 == "{img_md5}"',
  151. output_fields=["id"]
  152. )
  153. if len(res) > 0:
  154. total_skipped += 1
  155. # 即使跳过,也要把临时文件标记清理(不需要特殊操作,最后删整个文件夹)
  156. continue
  157. # --- B. 构造新文件名 (MD5 + 后缀) ---
  158. # 获取原始后缀 (如 .png)
  159. file_ext = os.path.splitext(file)[1].lower()
  160. # 新文件名: 纯净的MD5 + 后缀
  161. new_filename = f"{img_md5}{file_ext}"
  162. # --- C. 读取图片用于推理 ---
  163. img_cv = cv2.imread(full_temp_path)
  164. if img_cv is None:
  165. print(f"Warning: Failed to read image {full_temp_path}")
  166. continue
  167. img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
  168. # --- D. 加入批处理队列 ---
  169. batch_images.append(img_rgb)
  170. batch_temp_paths.append(full_temp_path)
  171. batch_metadata.append({
  172. "meta": meta,
  173. "md5": img_md5,
  174. "new_filename": new_filename
  175. })
  176. # --- E. 达到 Batch Size,执行处理 ---
  177. if len(batch_images) >= settings.BATCH_SIZE:
  178. count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
  179. total_inserted += count
  180. print(f"Processed batch, inserted: {count}")
  181. # 清空队列
  182. batch_images = []
  183. batch_metadata = []
  184. batch_temp_paths = []
  185. # 处理剩余未满一个 batch 的数据
  186. if batch_images:
  187. count = _execute_batch_insert(batch_images, batch_metadata, batch_temp_paths)
  188. total_inserted += count
  189. # 强制刷新 Milvus 确保数据可见
  190. milvus_collection.flush()
  191. print(f"Import finished. Inserted: {total_inserted}, Skipped: {total_skipped}")
  192. return total_inserted, total_skipped
  193. except Exception as e:
  194. print(f"Critical Error during import: {e}")
  195. raise e
  196. finally:
  197. # 清理解压的临时文件夹
  198. if os.path.exists(extract_root):
  199. shutil.rmtree(extract_root)
  200. # 清理上传的 zip 文件
  201. if os.path.exists(zip_path):
  202. os.remove(zip_path)
  203. @router.post("/batch", summary="上传宝可梦数据压缩包")
  204. def upload_zip(file: UploadFile = File(...)):
  205. """
  206. 上传 ZIP 存入向量库
  207. 大文件夹.zip
  208. 小文件夹必须遵循该格式: ('129873', {'us'}, 'Swadloon'), 2
  209. 图片...
  210. """
  211. if not file.filename.endswith(".zip"):
  212. raise HTTPException(status_code=400, detail="Only zip files are allowed.")
  213. # 保存上传的 zip
  214. file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
  215. try:
  216. with open(file_location, "wb+") as file_object:
  217. shutil.copyfileobj(file.file, file_object)
  218. except Exception as e:
  219. raise HTTPException(status_code=500, detail=f"Failed to save zip file: {e}")
  220. try:
  221. # 直接调用处理函数(同步等待)
  222. inserted, skipped = process_zip_file(file_location)
  223. return {
  224. "message": "Batch import completed successfully.",
  225. "data": {
  226. "inserted_count": inserted,
  227. "skipped_count": skipped,
  228. "total_processed": inserted + skipped
  229. }
  230. }
  231. except Exception as e:
  232. raise HTTPException(status_code=500, detail=f"Error processing zip: {str(e)}")