import io
import shutil
from typing import Optional, List

import cv2
import numpy as np
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from PIL import Image

from app.core.models_loader import global_models
from app.db.milvus_client import milvus_collection
from app.core.config import settings

router = APIRouter()


def _to_image_url(full_path: Optional[str]) -> str:
    """Map a stored image path to the relative URL served to the frontend.

    Only the final ``/``-separated component of the path is kept and placed
    under ``/static/images/``. A missing or empty path yields ``""``.
    """
    if not full_path:
        return ""
    return f"/static/images/{full_path.split('/')[-1]}"


def _escape_expr_string(value: str) -> str:
    """Escape a value for use inside a double-quoted filter-expression literal.

    Backslashes and double quotes are backslash-escaped so request input
    cannot terminate the string literal early and alter the expression.
    """
    return value.replace("\\", "\\\\").replace('"', '\\"')


@router.post("/image", summary="相似宝可梦图像搜索")
async def search_image(file: UploadFile = File(...), top_k: int = 5):
    """Search for cards that look similar to the uploaded image.

    Pipeline: decode the upload, crop the detected object with YOLO (falling
    back to the full image), embed the crop with ViT, then run a vector
    search in Milvus.

    Args:
        file: The uploaded image.
        top_k: Maximum number of hits to return; must be at least 1.

    Returns:
        ``{"results": [...]}`` where each entry holds ``id``, ``score``,
        ``card_name``, ``card_num``, ``lang`` and ``img_url``.

    Raises:
        HTTPException: 400 if ``top_k`` is below 1 or the upload cannot be
            decoded as an image.
    """
    if top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be >= 1")

    # 1. Read and decode the image.
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except OSError as exc:
        # Covers empty, truncated and unrecognised files
        # (PIL's UnidentifiedImageError is an OSError subclass).
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc
    image_np = np.array(image)

    # 2. YOLO crop (single image). MyBatchOnnxYolo expects a list input.
    # NOTE(review): the prediction result appears to be kept on the model
    # object and read back by get_max_img, so these two calls should stay
    # together with no `await` in between -- confirm against the model class.
    global_models.yolo.predict_batch([image_np])
    # Take the crop for image 0 (cls_id=0).
    cropped_img = global_models.yolo.get_max_img(0, cls_id=0)
    if cropped_img is None:
        # Nothing detected: fall back to the original image.
        cropped_img = image_np

    # 3. ViT feature extraction.
    # MyViTFeatureExtractor accepts List[Union[str, np.ndarray, Image]].
    vectors = global_models.vit.run([cropped_img], normalize=True)
    query_vector = vectors[0].tolist()

    # 4. Milvus vector search.
    search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
    results = milvus_collection.search(
        data=[query_vector],
        anns_field="vector",
        param=search_params,
        limit=top_k,
        output_fields=["id", "img_path", "card_name", "card_num", "lang", "source_id"],
    )

    # 5. Format the hits; local paths become relative URLs for the frontend.
    formatted_results = [
        {
            "id": hit.id,
            "score": hit.score,
            "card_name": hit.entity.get("card_name"),
            "card_num": hit.entity.get("card_num"),
            "lang": hit.entity.get("lang"),
            "img_url": _to_image_url(hit.entity.get("img_path")),
        }
        for hits in results
        for hit in hits
    ]

    return {"results": formatted_results}


@router.post("/filter")
async def filter_database(
    card_name: Optional[str] = None,
    source_id: Optional[str] = None,
    limit: int = 20
):
    """Filter database records by metadata.

    Args:
        card_name: Optional prefix to match against ``card_name``.
        source_id: Optional exact ``source_id`` to match.
        limit: Maximum number of rows to return; must be at least 1.

    Returns:
        ``{"count": <number of rows>, "data": [...]}`` where each row has its
        ``img_path`` replaced by an ``img_url``.

    Raises:
        HTTPException: 400 if ``limit`` is below 1.
    """
    if limit < 1:
        raise HTTPException(status_code=400, detail="limit must be >= 1")

    expr_list = []
    if card_name:
        # Prefix (fuzzy) match. The value is escaped so it cannot break out
        # of the quoted literal.
        expr_list.append(f'card_name like "{_escape_expr_string(card_name)}%"')
    if source_id:
        expr_list.append(f'source_id == "{_escape_expr_string(source_id)}"')

    # With no conditions, match every row (use with care; still capped by
    # `limit`).
    expr = " && ".join(expr_list) if expr_list else "id > 0"

    res = milvus_collection.query(
        expr=expr,
        output_fields=["id", "img_path", "card_name", "card_num", "lang"],
        limit=limit,
    )

    # Replace the stored path with the frontend-facing URL.
    for item in res:
        item["img_url"] = _to_image_url(item.pop("img_path", ""))

    return {"count": len(res), "data": res}