Browse Source

数据格式变动

AnlaAnla 6 days ago
parent
commit
5df6c45ebf
5 changed files with 80 additions and 47 deletions
  1. 2 1
      app/db/milvus_client.py
  2. 1 1
      app/routers/search.py
  3. 50 39
      app/routers/upload.py
  4. 5 4
      app/utils/parser.py
  5. 22 2
      run_PokemonCardSearch.py

+ 2 - 1
app/db/milvus_client.py

@@ -3,6 +3,7 @@ from app.core.config import settings
 
 
 def init_milvus():
+    print(f"🔌 连接 Milvus: {settings.MILVUS_HOST}:{settings.MILVUS_PORT}")
     connections.connect("default", host=settings.MILVUS_HOST, port=settings.MILVUS_PORT)
 
     if utility.has_collection(settings.COLLECTION_NAME):
@@ -16,7 +17,7 @@ def init_milvus():
         FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=64),
         FieldSchema(name="lang", dtype=DataType.VARCHAR, max_length=10),
         FieldSchema(name="card_name", dtype=DataType.VARCHAR, max_length=256),
-        FieldSchema(name="card_num", dtype=DataType.VARCHAR, max_length=64),
+        FieldSchema(name="card_num", dtype=DataType.INT64),
         FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=settings.VECTOR_DIM),
     ]
 

+ 1 - 1
app/routers/search.py

@@ -13,7 +13,7 @@ from app.core.config import settings
 router = APIRouter(prefix="/search", tags=["Search"])
 
 
-@router.post("/image")
+@router.post("/image", summary="相似宝可梦图像搜索")
 async def search_image(file: UploadFile = File(...), top_k: int = 5):
     # 1. 读取图片
     content = await file.read()

+ 50 - 39
app/routers/upload.py

@@ -25,25 +25,17 @@ def calculate_md5(file_path):
 
 
 def _execute_batch_insert(images, metadata_list, temp_file_paths):
-    """
-    执行 YOLO -> ViT -> Milvus 插入 -> 图片转存
-    Args:
-        images: RGB numpy 图片列表
-        metadata_list: 包含 meta 信息, md5, 和 new_filename
-        temp_file_paths: 解压目录下的临时文件路径
-    """
     if not images: return 0
 
-    # 1. YOLO 批量预测
+    # 1. YOLO 预测
     try:
         global_models.yolo.predict_batch(images)
-        # 获取裁剪后的图片 (如果没有检测到,内部逻辑会返回原图)
         cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)
     except Exception as e:
         print(f"Error in YOLO prediction: {e}")
         return 0
 
-    # 2. ViT 批量提取特征
+    # 2. ViT 提取特征
     try:
         vectors = global_models.vit.run(cropped_imgs, normalize=True)
     except Exception as e:
@@ -51,14 +43,15 @@ def _execute_batch_insert(images, metadata_list, temp_file_paths):
         return 0
 
     # 3. 准备 Milvus 数据列
+    # Schema: img_md5, img_path, source_id, lang, card_name, card_num, vector
     entities = [
-        [],  # vector
-        [],  # img_md5
-        [],  # img_path (存储相对路径或文件名,方便前端展示)
-        [],  # source_id
-        [],  # lang
-        [],  # card_name
-        [],  # card_num
+        [],  # 0: img_md5 (VARCHAR)
+        [],  # 1: img_path (VARCHAR)
+        [],  # 2: source_id (VARCHAR)
+        [],  # 3: lang (VARCHAR)
+        [],  # 4: card_name (VARCHAR)
+        [],  # 5: card_num (INT64)
+        [],  # 6: vector (FLOAT_VECTOR)
     ]
 
     for i, meta_data in enumerate(metadata_list):
@@ -66,32 +59,53 @@ def _execute_batch_insert(images, metadata_list, temp_file_paths):
         info = meta_data["meta"]
         new_filename = meta_data["new_filename"]
 
-        # --- 图片转存逻辑 (使用 MD5 命名) ---
-        # 源文件位置
+        # 图片转存
         src_path = temp_file_paths[i]
-        # 目标文件位置: static/images/{md5}.ext
         dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
-
-        # 移动文件 (比 copy 快,因为源文件是临时解压的,后面会删除)
-        # 如果目标文件已存在(理论上MD5去重后不会,但防万一),覆盖它
         shutil.move(src_path, dest_path)
-
-        # 存入数据库的是相对路径,或者直接存文件名,看你前端怎么拼接
-        # 这里存相对路径 images/xxxx.png
         db_img_path = f"images/{new_filename}"
 
-        entities[0].append(vec)
-        entities[1].append(meta_data["md5"])
-        entities[2].append(db_img_path)
-        entities[3].append(info["source_id"])
-        entities[4].append(info["lang"])
-        entities[5].append(info["card_name"])
-        entities[6].append(info["card_num"])
+        # --- 处理 card_num ---
+        raw_num = info.get("card_num", "")
+        final_num = -1  # 默认值,代表 None
+
+        if raw_num:
+            # 去掉可能存在的空格
+            raw_num_str = str(raw_num).strip()
+            # 如果是 "None" 字符串 或者为空
+            if raw_num_str.lower() == "none" or raw_num_str == "":
+                final_num = -1
+            else:
+                # 尝试提取数字 (比如 "004" -> 4)
+                try:
+                    # 过滤掉非数字字符 (防止有 'No.004' 这种写法)
+                    import re
+                    # 只提取数字部分
+                    digits = re.findall(r'\d+', raw_num_str)
+                    if digits:
+                        final_num = int(digits[0])
+                    else:
+                        final_num = -1
+                except:
+                    final_num = -1
+        # ----------------------------
+
+        entities[0].append(meta_data["md5"])  # img_md5
+        entities[1].append(db_img_path)  # img_path
+        entities[2].append(info["source_id"])  # source_id
+        entities[3].append(info["lang"])  # lang
+        entities[4].append(info["card_name"])  # card_name
+        entities[5].append(final_num)  # card_num (INT)
+        entities[6].append(vec)  # vector (最后)
 
     # 4. 插入 Milvus
     if entities[0]:
-        milvus_collection.insert(entities)
-        return len(entities[0])
+        try:
+            milvus_collection.insert(entities)
+            return len(entities[0])
+        except Exception as e:
+            print(f"Milvus Insert Error: {e}")
+            raise e
     return 0
 
 
@@ -202,10 +216,7 @@ def process_zip_file(zip_path: str):
 @router.post("/batch")
 def upload_zip(file: UploadFile = File(...)):
     """
-    上传 ZIP 并同步等待处理完成。
-    注意:这里使用 def 而不是 async def,
-    FastAPI 会自动将其放入线程池运行,不会阻塞主事件循环,
-    但当前的 HTTP 请求会一直挂起直到处理完成。
+    上传 ZIP 存入向量库
     """
     if not file.filename.endswith(".zip"):
         raise HTTPException(status_code=400, detail="Only zip files are allowed.")

+ 5 - 4
app/utils/parser.py

@@ -3,11 +3,11 @@ import re
 
 def parse_folder_name(folder_name: str):
     """
-    解析格式: ('129873', {'us'}, 'Swadloon'), 2
-    返回: (source_id, lang, card_name, card_num)
+    解析格式:
+    1. ('129873', {'us'}, 'Swadloon'), 2
+    2. ('2150297', {'us'}, 'Grimsley’s Move'), None
     """
-    # 这是一个比较宽松的正则,适应你的格式
-    # Group 1: source_id, Group 2: lang, Group 3: name, Group 4: card_num
+    # 这里的正则匹配最后的 , 之后的所有内容作为 card_num
     pattern = r"\('(.+?)', \{'(.+?)'\}, '(.+?)'\),\s*(.+)"
 
     match = re.search(pattern, folder_name)
@@ -16,6 +16,7 @@ def parse_folder_name(folder_name: str):
             "source_id": match.group(1),
             "lang": match.group(2),
             "card_name": match.group(3),
+            # 这里获取到的可能是 "2" 或者 "None" 或者 "004"
             "card_num": match.group(4).strip()
         }
     return None

+ 22 - 2
run_PokemonCardSearch.py

@@ -1,8 +1,28 @@
 import uvicorn
 import socket
 
+
+def get_host_ip():
+    """
+    查询本机ip地址
+    """
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        s.connect(('8.8.8.8', 80))
+        ip = s.getsockname()[0]
+    except Exception:
+        # 如果没有网络,回退到 host name
+        ip = socket.gethostbyname(socket.gethostname())
+    finally:
+        s.close()
+    return ip
+
+
 if __name__ == "__main__":
-    ipv4 = socket.gethostbyname(socket.gethostname())
+    ip = get_host_ip()
     port = 18082
-    print(f"http://{ipv4}:{port}/docs")
+
+    print(f" Server running on: http://{ip}:{port}")
+    print(f" Docs available at: http://{ip}:{port}/docs")
+
     uvicorn.run("app.main:app", host="0.0.0.0", port=port)