AnlaAnla 6 days ago
commit
85f0ea960a

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+/uploads/

+ 8 - 0
.idea/.gitignore

@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 8 - 0
.idea/PokemonCardSearch.iml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

+ 35 - 0
.idea/deployment.xml

@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
+    <serverData>
+      <paths name="192.168.31.243">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@192.168.77.66:22 password">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="martin@192.168.77.78">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="显卡服务器@192.168.77.249">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+  </component>
+</project>

+ 59 - 0
.idea/inspectionProfiles/Project_Default.xml

@@ -0,0 +1,59 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <Languages>
+        <language minSize="102" name="Python" />
+      </Languages>
+    </inspection_tool>
+    <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="31">
+            <item index="0" class="java.lang.String" itemvalue="webargs" />
+            <item index="1" class="java.lang.String" itemvalue="transformers" />
+            <item index="2" class="java.lang.String" itemvalue="timm" />
+            <item index="3" class="java.lang.String" itemvalue="fluent-logger" />
+            <item index="4" class="java.lang.String" itemvalue="towhee" />
+            <item index="5" class="java.lang.String" itemvalue="flask_restful" />
+            <item index="6" class="java.lang.String" itemvalue="opencv_python" />
+            <item index="7" class="java.lang.String" itemvalue="fastapi" />
+            <item index="8" class="java.lang.String" itemvalue="seaborn" />
+            <item index="9" class="java.lang.String" itemvalue="matplotlib" />
+            <item index="10" class="java.lang.String" itemvalue="minio" />
+            <item index="11" class="java.lang.String" itemvalue="ipython" />
+            <item index="12" class="java.lang.String" itemvalue="torch" />
+            <item index="13" class="java.lang.String" itemvalue="uvicorn" />
+            <item index="14" class="java.lang.String" itemvalue="python-multipart" />
+            <item index="15" class="java.lang.String" itemvalue="torchvision" />
+            <item index="16" class="java.lang.String" itemvalue="pymilvus" />
+            <item index="17" class="java.lang.String" itemvalue="psutil" />
+            <item index="18" class="java.lang.String" itemvalue="ultralytics" />
+            <item index="19" class="java.lang.String" itemvalue="picamera2" />
+            <item index="20" class="java.lang.String" itemvalue="posix_ipc" />
+            <item index="21" class="java.lang.String" itemvalue="websocket-client" />
+            <item index="22" class="java.lang.String" itemvalue="yolov10" />
+            <item index="23" class="java.lang.String" itemvalue="kornia" />
+            <item index="24" class="java.lang.String" itemvalue="prettytable" />
+            <item index="25" class="java.lang.String" itemvalue="huggingface_hub" />
+            <item index="26" class="java.lang.String" itemvalue="PIL" />
+            <item index="27" class="java.lang.String" itemvalue="sklearn" />
+            <item index="28" class="java.lang.String" itemvalue="faster_whisper" />
+            <item index="29" class="java.lang.String" itemvalue="pyserial" />
+            <item index="30" class="java.lang.String" itemvalue="requests" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+      <option name="ignoredErrors">
+        <list>
+          <option value="N803" />
+          <option value="N802" />
+          <option value="N806" />
+        </list>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>

+ 6 - 0
.idea/inspectionProfiles/profiles_settings.xml

@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

+ 7 - 0
.idea/misc.xml

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="pytorch" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="pytorch" project-jdk-type="Python SDK" />
+</project>

+ 8 - 0
.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/PokemonCardSearch.iml" filepath="$PROJECT_DIR$/.idea/PokemonCardSearch.iml" />
+    </modules>
+  </component>
+</project>

+ 6 - 0
.idea/vcs.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>

+ 0 - 0
app/__init__.py


+ 0 - 0
app/core/__init__.py


+ 31 - 0
app/core/config.py

@@ -0,0 +1,31 @@
+import os
+
+
+class Settings:
+    PROJECT_NAME = "Pokemon Card Search"
+
+    # Path configuration
+    BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+    STATIC_DIR = os.path.join(BASE_DIR, "app", "static")
+    IMAGE_STORAGE_DIR = os.path.join(STATIC_DIR, "images")
+    TEMP_UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
+
+    # Model paths (change these to your actual absolute paths)
+    YOLO_MODEL_PATH = "/path/to/your/yolo.pt"
+    VIT_MODEL_PATH = "/path/to/your/vit/model/folder"
+
+    # Milvus configuration
+    MILVUS_HOST = "127.0.0.1"
+    MILVUS_PORT = "19530"
+    COLLECTION_NAME = "pokemon_cards"
+    VECTOR_DIM = 768  # ViT-Base is typically 768
+
+    # Batch size
+    BATCH_SIZE = 32
+
+    def __init__(self):
+        os.makedirs(self.IMAGE_STORAGE_DIR, exist_ok=True)
+        os.makedirs(self.TEMP_UPLOAD_DIR, exist_ok=True)
+
+
+settings = Settings()

+ 20 - 0
app/core/models_loader.py

@@ -0,0 +1,20 @@
+from app.utils.MyBatchOnnxYolo import MyBatchOnnxYolo
+from app.utils.MyViTFeatureExtractor import MyViTFeatureExtractor
+from app.core.config import settings
+
+
+class GlobalModels:
+    yolo: MyBatchOnnxYolo = None
+    vit: MyViTFeatureExtractor = None
+
+    @classmethod
+    def load_models(cls):
+        print("⏳ Loading YOLO Model...")
+        cls.yolo = MyBatchOnnxYolo(settings.YOLO_MODEL_PATH, task='segment')  # or 'detect'
+
+        print("⏳ Loading ViT Model...")
+        cls.vit = MyViTFeatureExtractor(settings.VIT_MODEL_PATH)
+        print("✅ Models Loaded Successfully.")
+
+
+global_models = GlobalModels()

+ 0 - 0
app/db/__init__.py


+ 42 - 0
app/db/milvus_client.py

@@ -0,0 +1,42 @@
+from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
+from app.core.config import settings
+
+
+def init_milvus():
+    connections.connect("default", host=settings.MILVUS_HOST, port=settings.MILVUS_PORT)
+
+    if utility.has_collection(settings.COLLECTION_NAME):
+        return Collection(settings.COLLECTION_NAME)
+
+    # Define the fields
+    fields = [
+        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
+        FieldSchema(name="img_md5", dtype=DataType.VARCHAR, max_length=64),  # 用于去重
+        FieldSchema(name="img_path", dtype=DataType.VARCHAR, max_length=512),
+        FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=64),
+        FieldSchema(name="lang", dtype=DataType.VARCHAR, max_length=10),
+        FieldSchema(name="card_name", dtype=DataType.VARCHAR, max_length=256),
+        FieldSchema(name="card_num", dtype=DataType.VARCHAR, max_length=64),
+        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=settings.VECTOR_DIM),
+    ]
+
+    schema = CollectionSchema(fields, "Pokemon Card Search Collection")
+    collection = Collection(settings.COLLECTION_NAME, schema)
+
+    # Create the vector index
+    index_params = {
+        "metric_type": "COSINE",
+        "index_type": "HNSW",
+        "params": {"M": 8, "efConstruction": 64}
+    }
+    collection.create_index(field_name="vector", index_params=index_params)
+
+    # Create a scalar index on the MD5 field to speed up deduplication queries
+    collection.create_index(field_name="img_md5", index_name="idx_md5")
+
+    collection.load()
+    return collection
+
+
+# Global instance
+milvus_collection = init_milvus()

+ 42 - 0
app/main.py

@@ -0,0 +1,42 @@
+from fastapi import FastAPI
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
+from contextlib import asynccontextmanager
+
+from app.core.models_loader import global_models
+from app.routers import search, upload, view
+from app.core.config import settings
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Load models on startup
+    global_models.load_models()
+    yield
+    # Clean up resources on shutdown
+    pass
+
+
+app = FastAPI(title="Pokemon Card Search", lifespan=lifespan)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+# Mount static files
+app.mount("/static", StaticFiles(directory=settings.STATIC_DIR), name="static")
+
+# Register routers
+app.include_router(search.router)
+app.include_router(upload.router)
+app.include_router(view.router)
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)

+ 0 - 0
app/routers/__init__.py


+ 99 - 0
app/routers/search.py

@@ -0,0 +1,99 @@
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException
+from typing import Optional, List
+import cv2
+import numpy as np
+from PIL import Image
+import io
+import shutil
+
+from app.core.models_loader import global_models
+from app.db.milvus_client import milvus_collection
+from app.core.config import settings
+
+router = APIRouter(prefix="/search", tags=["Search"])
+
+
+@router.post("/image")
+async def search_image(file: UploadFile = File(...), top_k: int = 5):
+    # 1. Read the image
+    content = await file.read()
+    image = Image.open(io.BytesIO(content)).convert("RGB")
+    image_np = np.array(image)
+
+    # 2. YOLO crop (single-image processing)
+    # Note: MyBatchOnnxYolo expects a list as input
+    global_models.yolo.predict_batch([image_np])
+    # Extract the crop from image 0 (default cls_id=0)
+    cropped_img = global_models.yolo.get_max_img(0, cls_id=0)
+
+    if cropped_img is None:
+        # Fall back to the original image if nothing was detected
+        cropped_img = image_np
+
+    # 3. Extract features with ViT
+    # MyViTFeatureExtractor accepts List[Union[str, np.ndarray, Image]]
+    vectors = global_models.vit.run([cropped_img], normalize=True)
+    query_vector = vectors[0].tolist()
+
+    # 4. Milvus search
+    search_params = {"metric_type": "COSINE", "params": {"ef": 64}}  # HNSW uses "ef" at search time (nprobe is an IVF parameter)
+    results = milvus_collection.search(
+        data=[query_vector],
+        anns_field="vector",
+        param=search_params,
+        limit=top_k,
+        output_fields=["id", "img_path", "card_name", "card_num", "lang", "source_id"]
+    )
+
+    # 5. Format the results
+    formatted_results = []
+    for hits in results:
+        for hit in hits:
+            # Convert the local path to a relative URL the frontend can load
+            full_path = hit.entity.get("img_path")
+            web_path = f"/static/images/{full_path.split('/')[-1]}" if full_path else ""
+
+            formatted_results.append({
+                "id": hit.id,
+                "score": hit.score,
+                "card_name": hit.entity.get("card_name"),
+                "card_num": hit.entity.get("card_num"),
+                "lang": hit.entity.get("lang"),
+                "img_url": web_path
+            })
+
+    return {"results": formatted_results}
+
+
+@router.post("/filter")
+async def filter_database(
+        card_name: Optional[str] = None,
+        source_id: Optional[str] = None,
+        limit: int = 20
+):
+    """数据库元数据过滤"""
+    expr_list = []
+    if card_name:
+        # Prefix (fuzzy) match
+        expr_list.append(f'card_name like "{card_name}%"')
+    if source_id:
+        expr_list.append(f'source_id == "{source_id}"')
+
+    expr = " && ".join(expr_list) if expr_list else ""
+
+    # With no conditions, query everything (use with care; still capped by limit)
+    if not expr:
+        expr = "id > 0"
+
+    res = milvus_collection.query(
+        expr=expr,
+        output_fields=["id", "img_path", "card_name", "card_num", "lang"],
+        limit=limit
+    )
+
+    # Format the paths
+    for item in res:
+        full_path = item.pop("img_path", "")
+        item["img_url"] = f"/static/images/{full_path.split('/')[-1]}" if full_path else ""
+
+    return {"count": len(res), "data": res}

+ 176 - 0
app/routers/upload.py

@@ -0,0 +1,176 @@
+from fastapi import APIRouter, UploadFile, File, BackgroundTasks
+import os
+import zipfile
+import hashlib
+import shutil
+import cv2
+import glob
+from uuid import uuid4
+from PIL import Image
+
+from app.core.config import settings
+from app.core.models_loader import global_models
+from app.db.milvus_client import milvus_collection
+from app.utils.parser import parse_folder_name
+
+router = APIRouter(prefix="/upload", tags=["Upload"])
+
+
+def calculate_md5(file_path):
+    hash_md5 = hashlib.md5()
+    with open(file_path, "rb") as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            hash_md5.update(chunk)
+    return hash_md5.hexdigest()
+
+
+def process_batch_import(zip_path: str):
+    """后台任务:处理解压和导入"""
+    extract_root = os.path.join(settings.TEMP_UPLOAD_DIR, f"extract_{uuid4().hex}")
+    os.makedirs(extract_root, exist_ok=True)
+
+    try:
+        # 1. Extract the archive
+        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extractall(extract_root)
+
+        # Prepare the batch lists
+        batch_images = []  # numpy images for inference
+        batch_metadata = []  # metadata
+        batch_file_paths = []  # source file paths (used when copying into image storage)
+
+        # 2. Walk the extracted folders
+        # Assumed structure after extraction: extract_root/folder_name/image.png
+        # Walk recursively, because the archive may contain an extra top-level directory
+        for root, dirs, files in os.walk(extract_root):
+            folder_name = os.path.basename(root)
+            meta = parse_folder_name(folder_name)
+
+            # Skip directories that are not target data directories
+            if not meta and files:
+                # The data may sit one level deeper (some archives extract with an extra root directory)
+                continue
+
+            if not meta: continue
+
+            for file in files:
+                if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
+                    full_path = os.path.join(root, file)
+
+                    # 2.1 MD5 check
+                    img_md5 = calculate_md5(full_path)
+
+                    # Check whether this MD5 already exists in Milvus
+                    res = milvus_collection.query(
+                        expr=f'img_md5 == "{img_md5}"',
+                        output_fields=["id"]
+                    )
+                    if len(res) > 0:
+                        print(f"Skipping duplicate: {file}")
+                        continue
+
+                    # Read the image
+                    # cv2 loads BGR for YOLO; convert to RGB
+                    img_cv = cv2.imread(full_path)
+                    if img_cv is None: continue
+                    img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
+
+                    batch_images.append(img_rgb)
+                    batch_metadata.append({
+                        "meta": meta,
+                        "md5": img_md5,
+                        "filename": f"{img_md5}_{file}"  # 重命名防止覆盖
+                    })
+                    batch_file_paths.append(full_path)
+
+                    # 3. Once the batch is full, run inference and insert
+                    if len(batch_images) >= settings.BATCH_SIZE:
+                        _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
+                        # Reset the batch
+                        batch_images = []
+                        batch_metadata = []
+                        batch_file_paths = []
+
+        # Process any remaining images
+        if batch_images:
+            _execute_batch_insert(batch_images, batch_metadata, batch_file_paths)
+
+        print("Batch import finished.")
+
+    except Exception as e:
+        print(f"Error during import: {e}")
+    finally:
+        # Clean up temporary files
+        if os.path.exists(extract_root):
+            shutil.rmtree(extract_root)
+        if os.path.exists(zip_path):
+            os.remove(zip_path)
+
+
+def _execute_batch_insert(images, metadata_list, original_paths):
+    """
+    Helper: run inference on a batch and insert the results into the database
+    """
+    if not images: return
+
+    # A. YOLO batch prediction
+    global_models.yolo.predict_batch(images)
+
+    # B. Get the list of cropped images
+    cropped_imgs = global_models.yolo.get_max_img_list(cls_id=0)  # assumes the card class is cls_id=0
+
+    # C. ViT batch feature extraction (ViT accepts List[np.ndarray] or PIL images)
+    # get_max_img_list falls back to the original image when nothing is detected,
+    # so the entries here should already be valid numpy arrays; None is only possible
+    # if an original image itself were invalid.
+    valid_crop_imgs = []  # placeholder, kept in case invalid crops need filtering later
+    vectors = global_models.vit.run(cropped_imgs, normalize=True)
+
+    # D. Prepare the data for insertion.
+    # With the column-based insert format, the column order must match the schema
+    # field order (the auto_id primary key is skipped).
+    entities = [
+        [],  # img_md5
+        [],  # img_path
+        [],  # source_id
+        [],  # lang
+        [],  # card_name
+        [],  # card_num
+        [],  # vector
+    ]
+
+    for i, meta_data in enumerate(metadata_list):
+        vec = vectors[i].tolist()
+        info = meta_data["meta"]
+        new_filename = meta_data["filename"]
+
+        # Save the original image into static/images
+        dest_path = os.path.join(settings.IMAGE_STORAGE_DIR, new_filename)
+        shutil.copy(original_paths[i], dest_path)
+
+        entities[0].append(meta_data["md5"])
+        entities[1].append(dest_path)  # store an absolute or relative path
+        entities[2].append(info["source_id"])
+        entities[3].append(info["lang"])
+        entities[4].append(info["card_name"])
+        entities[5].append(info["card_num"])
+        entities[6].append(vec)
+
+    # E. Insert into Milvus
+    # pymilvus column-based insert format: [ [md5_1, md5_2], [path_1, path_2], ..., [vec1, vec2] ]
+    milvus_collection.insert(entities)
+    milvus_collection.flush()  # In production, flush periodically rather than after every batch
+    print(f"Inserted {len(images)} records.")
+
+
+@router.post("/batch")
+async def upload_zip(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
+    if not file.filename.endswith(".zip"):
+        return {"error": "Only zip files are allowed."}
+
+    file_location = os.path.join(settings.TEMP_UPLOAD_DIR, f"{uuid4().hex}_{file.filename}")
+    with open(file_location, "wb+") as file_object:
+        file_object.write(await file.read())
+
+    # Run in the background so the API call is not blocked
+    background_tasks.add_task(process_batch_import, file_location)
+
+    return {"message": "File uploaded. Processing started in background."}

+ 15 - 0
app/routers/view.py

@@ -0,0 +1,15 @@
+from fastapi import APIRouter
+from fastapi.responses import HTMLResponse
+import os
+from app.core.config import settings
+
+router = APIRouter()
+
+@router.get("/", response_class=HTMLResponse)
+async def read_root():
+    # Read static/index.html and return its contents
+    index_path = os.path.join(settings.STATIC_DIR, "index.html")
+    if os.path.exists(index_path):
+        with open(index_path, 'r', encoding='utf-8') as f:
+            return f.read()
+    return "<h1>Index.html not found</h1>"

+ 0 - 0
app/static/__init__.py


+ 96 - 0
app/static/index.html

@@ -0,0 +1,96 @@
+<!DOCTYPE html>
+<html lang="zh">
+<head>
+    <meta charset="UTF-8">
+    <title>Pokemon Card Search</title>
+    <style>
+        body { font-family: sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
+        .container { display: flex; gap: 20px; }
+        .upload-section { flex: 1; border-right: 1px solid #ccc; padding-right: 20px; }
+        .results-section { flex: 2; }
+        .card { border: 1px solid #ddd; padding: 10px; margin-bottom: 10px; display: flex; align-items: center; border-radius: 8px; }
+        .card img { max-width: 100px; max-height: 140px; margin-right: 20px; }
+        .card-info { font-size: 14px; }
+        .similarity { color: green; font-weight: bold; }
+    </style>
+</head>
+<body>
+    <h1>🔍 Pokemon Card Search</h1>
+
+    <div class="container">
+        <div class="upload-section">
+            <h3>Image Search</h3>
+            <input type="file" id="imgInput" accept="image/*">
+            <br><br>
+            <label>Top N: <input type="number" id="topN" value="5" style="width: 50px;"></label>
+            <br><br>
+            <button onclick="searchImage()">搜索</button>
+            <br><hr><br>
+            <h3>Text Filter</h3>
+            <input type="text" id="nameFilter" placeholder="Card name (e.g. Swadloon)">
+            <button onclick="filterData()">Query</button>
+        </div>
+
+        <div class="results-section">
+            <h3>Results</h3>
+            <div id="results"></div>
+        </div>
+    </div>
+
+    <script>
+        async function searchImage() {
+            const input = document.getElementById('imgInput');
+            const topN = document.getElementById('topN').value;
+            if (!input.files[0]) return alert("Please select an image");
+
+            const formData = new FormData();
+            formData.append('file', input.files[0]);
+
+            const res = await fetch(`/search/image?top_k=${topN}`, {
+                method: 'POST',
+                body: formData
+            });
+            const data = await res.json();
+            renderResults(data.results);
+        }
+
+        async function filterData() {
+            const name = document.getElementById('nameFilter').value;
+            const res = await fetch(`/search/filter?card_name=${name}&limit=10`, {
+                method: 'POST'
+            });
+            const data = await res.json();
+            renderResults(data.data, false);
+        }
+
+        function renderResults(items, isSearch=true) {
+            const container = document.getElementById('results');
+            container.innerHTML = '';
+
+            if (!items || items.length === 0) {
+                container.innerHTML = '<p>No results</p>';
+                return;
+            }
+
+            items.forEach(item => {
+                const div = document.createElement('div');
+                div.className = 'card';
+
+                // Build the extra display content
+                let scoreHtml = isSearch ? `<p class="similarity">Similarity: ${(item.score).toFixed(4)}</p>` : '';
+
+                div.innerHTML = `
+                    <img src="${item.img_url}" alt="${item.card_name}" onerror="this.src=''">
+                    <div class="card-info">
+                        <h3>${item.card_name} #${item.card_num}</h3>
+                        <p>Language: ${item.lang} | Source ID: ${item.source_id}</p>
+                        ${scoreHtml}
+                        <p style="color:gray; font-size:12px;">ID: ${item.id}</p>
+                    </div>
+                `;
+                container.appendChild(div);
+            });
+        }
+    </script>
+</body>
+</html>

+ 164 - 0
app/utils/MyBatchOnnxYolo.py

@@ -0,0 +1,164 @@
+import cv2
+import numpy as np
+from ultralytics import YOLO
+from typing import List, Union, Optional
+import PIL.Image
+
+ImageType = Union[str, np.ndarray, PIL.Image.Image]
+
+
+class MyBatchOnnxYolo:
+    """
+    Batch object detection/segmentation with a YOLO model, plus a helper for
+    extracting the largest detected target region from each image.
+    cls_id {card: 0} - adjust to match your model
+    """
+
+    def __init__(self, model_path: str, task: str = 'segment', verbose: bool = False):
+        # Load the YOLO model
+        self.model = YOLO(model_path, task=task, verbose=verbose)
+        self.results: Optional[List] = None  # will hold the list of batched results
+        self.batch_size: int = 0
+
+    def predict_batch(self, image_list: List[ImageType], imgsz: int = 640, **kwargs):
+        """
+        Run prediction on a batch of images.
+
+        Args:
+            image_list (List[ImageType]): List of image paths, PIL Images, or NumPy arrays.
+            imgsz (int): Inference image size.
+            **kwargs: Extra arguments passed to model.predict (e.g. conf, iou).
+        """
+        if not image_list:
+            print("Warning: Input image list is empty.")
+            self.results = []
+            self.batch_size = 0
+            return
+
+        # Use YOLO's built-in batch inference
+        self.results = self.model.predict(image_list, verbose=False, imgsz=imgsz, **kwargs)
+        self.batch_size = len(self.results)
+
+    def get_batch_size(self) -> int:
+        return self.batch_size
+
+    def _get_result_at_index(self, index: int):
+        """内部辅助方法,获取指定索引的结果,并进行边界检查。"""
+        if self.results is None:
+            raise ValueError("Must call predict_batch() before accessing results.")
+        if not (0 <= index < self.batch_size):
+            raise IndexError(f"Index {index} is out of bounds for batch size {self.batch_size}.")
+        return self.results[index]
+
+    def check(self, index: int, cls_id: int) -> bool:
+        """
+        Check whether a given class ID appears in the result for the image at the given index.
+
+        Args:
+            index (int): Index of the image within the batch (0-based).
+            cls_id (int): Class ID to check for.
+
+        Returns:
+            bool: True if the class ID is present, otherwise False.
+        """
+        result = self._get_result_at_index(index)
+        if result.boxes is None or len(result.boxes) == 0:
+            return False
+        # .cls may be an empty tensor, so check it
+        return result.boxes.cls is not None and cls_id in result.boxes.cls.cpu().tolist()
+
+    def get_max_img(self, index: int, cls_id: int = 0) -> Optional[np.ndarray]:
+        """
+        Crop the region of the largest bounding box of the given class ID from the image at the given index.
+
+        Args:
+            index (int): Index of the image within the batch (0-based).
+            cls_id (int): Class ID of the target to extract (default 0).
+
+        Returns:
+            Optional[np.ndarray]: The cropped region of the largest target (RGB NumPy array),
+                                   or the original image if the class is not found or there are no detections.
+        """
+        result = self._get_result_at_index(index)
+        orig_img = result.orig_img  # usually a BGR NumPy array
+        boxes = result.boxes
+
+        # Check that there are boxes and that the requested class is present
+        if boxes is None or len(boxes) == 0 or boxes.cls is None or cls_id not in boxes.cls.cpu():
+            print(
+                f"Warning: No detections or cls_id {cls_id} not found for image at index {index}. Returning original image.")
+            # Return an RGB version of the original image
+            return cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB) if orig_img is not None else None
+
+        max_area = 0.0
+        max_box = None
+
+        xyxy_boxes = boxes.xyxy.cpu().numpy()
+        cls_list = boxes.cls.cpu().numpy()
+
+        # Select the largest target box
+        for i, box in enumerate(xyxy_boxes):
+            if cls_list[i] != cls_id:
+                continue
+
+            temp_x1, temp_y1, temp_x2, temp_y2 = box
+            area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
+            if area > max_area:
+                max_area = area
+                max_box = box
+
+        # No box with this cls_id was found (already checked above, kept as a safeguard)
+        if max_box is None:
+            print(
+                f"Warning: cls_id {cls_id} found in cls_list but failed to find max box for image at index {index}. Returning original image.")
+            return cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB) if orig_img is not None else None
+
+        x1, y1, x2, y2 = map(int, max_box)  # convert to integer coordinates
+
+        # Clamp coordinates so the crop stays inside the image
+        h, w = orig_img.shape[:2]
+        x1 = max(0, x1)
+        y1 = max(0, y1)
+        x2 = min(w, x2)
+        y2 = min(h, y2)
+
+        # Check that the crop region is valid
+        if x1 >= x2 or y1 >= y2:
+            print(
+                f"Warning: Invalid crop dimensions [{y1}:{y2}, {x1}:{x2}] for image at index {index}. Returning original image.")
+            return cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB) if orig_img is not None else None
+
+        # Crop the image (orig_img is usually BGR)
+        max_img_crop = orig_img[y1:y2, x1:x2]
+
+        # Convert the crop to RGB (more compatible with matplotlib and PIL)
+        max_img_rgb = cv2.cvtColor(max_img_crop, cv2.COLOR_BGR2RGB)
+
+        return max_img_rgb
+
+    def get_max_img_list(self, cls_id: int = 0) -> List[Optional[np.ndarray]]:
+        """
+        For every image in the batch, crop the region of the largest bounding box of the given class ID.
+
+        Args:
+            cls_id (int): Class ID of the target to extract (default 0).
+
+        Returns:
+            List[Optional[np.ndarray]]: A list of processed images (RGB NumPy arrays).
+                                        For a successful crop, the element is the cropped region.
+                                        If the class was not found or cropping failed, the element is the original image (RGB).
+                                        If the original image is invalid, the element is None.
+        """
+        if self.results is None:
+            raise ValueError("Must call predict_batch() before calling get_max_img_list().")
+
+        processed_images: List[Optional[np.ndarray]] = []
+        for i in range(self.batch_size):
+            # Delegate to get_max_img for each image in the batch
+            processed_img = self.get_max_img(index=i, cls_id=cls_id)
+            processed_images.append(processed_img)
+
+        return processed_images
+
+    def get_results(self) -> Optional[List]:
+        """获取完整的批处理结果列表。"""
+        return self.results
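
A minimal usage sketch of this wrapper (the model path and image paths are placeholders, not values from this repo):

from app.utils.MyBatchOnnxYolo import MyBatchOnnxYolo

yolo = MyBatchOnnxYolo("/path/to/your/yolo.pt", task="segment")
yolo.predict_batch(["card_photo_1.jpg", "card_photo_2.jpg"], imgsz=640, conf=0.25)
crops = yolo.get_max_img_list(cls_id=0)  # one RGB crop (or the original image) per input
print(len(crops), None if crops[0] is None else crops[0].shape)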

+ 184 - 0
app/utils/MyViTFeatureExtractor.py

@@ -0,0 +1,184 @@
+import torch
+from transformers import ViTModel, ViTConfig
+import torchvision.transforms as transforms
+from PIL import Image
+import numpy as np
+from typing import List, Union
+import logging
+import os
+import torch.nn.functional as F
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+
+class MyViTFeatureExtractor:
+    def __init__(self, local_model_path: str) -> None:
+        """
+        Initialize the feature extractor.
+
+        Compatible with backbone models trained and saved by MetricViT.
+        """
+        if not os.path.isdir(local_model_path):
+            raise NotADirectoryError(f"Model path not found: {local_model_path}")
+
+        logging.info(f"Loading model from: {local_model_path}")
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logging.info(f"Using device: {self.device}")
+
+        try:
+            # Load the config and model (a plain ViTModel is loaded here)
+            self.config = ViTConfig.from_pretrained(local_model_path, local_files_only=True)
+            self.model = ViTModel.from_pretrained(local_model_path, config=self.config, local_files_only=True)
+        except Exception as e:
+            logging.error(f"Failed to load model: {e}")
+            raise
+
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Preprocessing (kept identical to the validation transform used during training)
+        self.transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ])
+
+        self.feature_dim = self.model.config.hidden_size
+        logging.info(f"Feature dimension: {self.feature_dim}")
+
+    def run(self, imgs: List[Union[str, np.ndarray, Image.Image]], normalize: bool = True) -> np.ndarray:
+        """
+        Process a batch of images and return their feature vectors.
+
+        Args:
+            imgs: List of image paths, numpy arrays, or PIL Images
+            normalize: Whether to L2-normalize (strongly recommended for Milvus / cosine search)
+        """
+        if not imgs:
+            return np.empty((0, self.feature_dim), dtype=np.float32)
+
+        processed_tensors = []
+        valid_indices = []
+
+        # 1. Preprocess
+        for i, img_input in enumerate(imgs):
+            try:
+                # --- Image loading logic (handles paths, numpy arrays, and PIL Images) ---
+                if isinstance(img_input, str):
+                    img = Image.open(img_input).convert('RGB')
+                elif isinstance(img_input, np.ndarray):
+                    if img_input.dtype != np.uint8:
+                        if img_input.max() <= 1.0:
+                            img_input = (img_input * 255).astype(np.uint8)
+                        else:
+                            img_input = img_input.astype(np.uint8)
+                    img = Image.fromarray(img_input, 'RGB')
+                elif isinstance(img_input, Image.Image):
+                    img = img_input.convert('RGB')
+                else:
+                    continue
+                    # ----------------------------------------
+
+                img_tensor = self.transform(img)
+                processed_tensors.append(img_tensor)
+                valid_indices.append(i)
+            except Exception as e:
+                logging.error(f"Error processing image index {i}: {e}")
+
+        if not processed_tensors:
+            return np.empty((0, self.feature_dim), dtype=np.float32)
+
+        # 2. Inference
+        batch_tensor = torch.stack(processed_tensors, dim=0).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(batch_tensor)
+            # Key step: take the [CLS] token (index 0) of last_hidden_state,
+            # exactly matching what MetricViT does during training
+            features = outputs.last_hidden_state[:, 0, :]
+
+        # 3. Post-process
+        if normalize:
+            # PyTorch's normalize is more precise; a numpy implementation would also work
+            features = F.normalize(features, p=2, dim=1)
+            output_np = features.cpu().numpy()
+        else:
+            output_np = features.cpu().numpy()
+
+        # 4. Pad the output so its length matches the input list
+        if len(output_np) != len(imgs):
+            final_output = np.full((len(imgs), self.feature_dim), np.nan, dtype=np.float32)
+            for idx, vec in zip(valid_indices, output_np):
+                final_output[idx] = vec
+            return final_output
+
+        return output_np
+
+
+def compare_images(extractor, img_path_A, img_path_B):
+    """
+    Compute the similarity between two images
+    """
+    if not os.path.exists(img_path_A) or not os.path.exists(img_path_B):
+        print(f"❌ 错误: 找不到图片路径。\nA: {img_path_A}\nB: {img_path_B}")
+        return
+
+    print(f"🔍 正在对比:")
+    print(f"  图 A: {os.path.basename(img_path_A)}")
+    print(f"  图 B: {os.path.basename(img_path_B)}")
+
+    # 1. Extract features (passing both images at once is more efficient)
+    # run() returns an already L2-normalized numpy array
+    vectors = extractor.run([img_path_A, img_path_B], normalize=True)
+
+    vec_a = vectors[0]
+    vec_b = vectors[1]
+
+    # 2. Cosine similarity
+    # Since vec_a and vec_b both have unit norm, their dot product is the cosine similarity
+    similarity = np.dot(vec_a, vec_b)
+
+    # 3. Euclidean (L2) distance, as a secondary reference
+    # The smaller the distance, the more similar
+    distance = np.linalg.norm(vec_a - vec_b)
+
+    # 4. Print the results
+    print("-" * 30)
+    print("📊 Similarity results:")
+    print(f"  ★ Cosine similarity: {similarity:.4f}  (closer to 1.0 means more similar)")
+    print(f"  ☆ L2 distance:       {distance:.4f}    (closer to 0.0 means more similar)")
+    print("-" * 30)
+
+    # 5. A simple verdict
+    threshold = 0.85  # adjust this threshold to your data
+    if similarity > threshold:
+        print("✅ Conclusion: these are very likely the same card (or different-language printings of the same Pokemon)")
+    else:
+        print("❌ Conclusion: these look like different cards")
+    print("\n")
+
+
+if __name__ == "__main__":
+    # ================= Configuration =================
+    # Path where your model is saved
+    MODEL_PATH = "/home/martin/ML/Model/pokemon_cls/vit-base-patch16-224-Pokemon02"
+
+    # Put the absolute paths of the two images you want to test here
+    # Suggested tests:
+    # 1. A Chinese-language card vs. the same card in English
+    # 2. A card vs. a completely different card
+    IMG_1 = r"/home/martin/ML/RemoteProject/untitled10/uploads/伊布us1.png"
+    IMG_2 = r"/home/martin/ML/RemoteProject/untitled10/uploads/伊布tc1.png"
+
+    # ================= Run =================
+    try:
+        print("正在加载模型,请稍候...")
+        # 初始化提取器
+        extractor = MyViTFeatureExtractor(MODEL_PATH)
+
+        # Run the comparison
+        compare_images(extractor, IMG_1, IMG_2)
+
+    except Exception as e:
+        print(f"运行出错: {e}")

+ 0 - 0
app/utils/__init__.py


+ 21 - 0
app/utils/parser.py

@@ -0,0 +1,21 @@
+import re
+
+
+def parse_folder_name(folder_name: str):
+    """
+    Parses a folder name of the form: ('129873', {'us'}, 'Swadloon'), 2
+    Returns: a dict with source_id, lang, card_name, card_num (or None if the name does not match)
+    """
+    # A fairly permissive regex that matches the format above
+    # Group 1: source_id, Group 2: lang, Group 3: name, Group 4: card_num
+    pattern = r"\('(.+?)', \{'(.+?)'\}, '(.+?)'\),\s*(.+)"
+
+    match = re.search(pattern, folder_name)
+    if match:
+        return {
+            "source_id": match.group(1),
+            "lang": match.group(2),
+            "card_name": match.group(3),
+            "card_num": match.group(4).strip()
+        }
+    return None
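
A quick illustration of the expected input and output, using the example from the docstring above:

from app.utils.parser import parse_folder_name

meta = parse_folder_name("('129873', {'us'}, 'Swadloon'), 2")
# meta == {'source_id': '129873', 'lang': 'us', 'card_name': 'Swadloon', 'card_num': '2'}
print(meta)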

+ 0 - 0
run_PokemonCardSearch.py