import io
import logging
import shutil
from typing import List, Optional

import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from PIL import Image

from app.core.config import settings
from app.core.models_loader import global_models
from app.db.milvus_client import milvus_collection

logger = logging.getLogger(__name__)

router = APIRouter()

# Largest search/query ``limit`` Milvus accepts by default ([1, 16384]); values
# outside that range are rejected up front (422) instead of failing in Milvus.
_MILVUS_MAX_LIMIT = 16384


def _to_img_url(full_path: Optional[str]) -> str:
    """Map a stored image path to the URL the frontend uses to fetch it.

    Only the last ``/``-separated component (the file name) is kept and joined
    to ``/static/images/``. A missing or empty path yields ``""``.
    """
    if not full_path:
        return ""
    return f"/static/images/{full_path.split('/')[-1]}"


def _escape_expr_string(value: str) -> str:
    """Escape *value* for use inside a double-quoted Milvus expression literal.

    Backslashes and double quotes are escaped so that user input cannot
    terminate the literal and append extra clauses to the filter expression.
    """
    return value.replace("\\", "\\\\").replace('"', '\\"')


@router.post("/image", summary="相似宝可梦图像搜索")
async def search_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=_MILVUS_MAX_LIMIT),
):
    """Return the stored cards whose images are most similar to the upload.
    \f
    Pipeline: decode the upload as RGB -> YOLO crop (class id 0, falling back
    to the whole image when nothing is returned) -> ViT embedding with
    ``normalize=True`` -> Milvus search on the ``vector`` field with the
    COSINE metric.

    Args:
        file: The uploaded image file.
        top_k: Number of hits to return; must be in [1, 16384] (422 otherwise).

    Returns:
        ``{"results": [...]}`` with one dict per hit containing ``id``,
        ``score``, ``card_name``, ``card_num``, ``lang`` and ``img_url``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.

    NOTE(review): the model and Milvus calls below are synchronous and are
    invoked directly inside this ``async`` handler, so they block the event
    loop while they run.
    """
    # 1. Decode the upload. A non-image or truncated file raises OSError
    #    (PIL.UnidentifiedImageError is a subclass); report it as a client error.
    content = await file.read()
    try:
        with Image.open(io.BytesIO(content)) as raw_image:
            image_np = np.array(raw_image.convert("RGB"))
    except OSError as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from exc

    # 2. YOLO crop. MyBatchOnnxYolo expects a list of images, so wrap the
    #    single image in a one-element batch.
    #    predict_batch's return value is discarded and get_max_img is queried
    #    afterwards, so the detector appears to keep the last batch's results
    #    as internal state. NOTE(review): that makes this pair of calls unsafe
    #    to run concurrently (e.g. from a thread pool) without a lock.
    global_models.yolo.predict_batch([image_np])
    # Crop for image 0 of the batch, class id 0; None means no crop was found.
    cropped_img = global_models.yolo.get_max_img(0, cls_id=0)
    if cropped_img is None:
        # No detection: embed the whole image instead.
        cropped_img = image_np

    # 3. ViT feature extraction. MyViTFeatureExtractor accepts a list of
    #    str / np.ndarray / PIL.Image inputs and returns one vector per input.
    vectors = global_models.vit.run([cropped_img], normalize=True)
    query_vector = vectors[0].tolist()

    # 4. Nearest-neighbour search in Milvus.
    results = milvus_collection.search(
        data=[query_vector],
        anns_field="vector",
        param={"metric_type": "COSINE", "params": {"nprobe": 10}},
        limit=top_k,
        output_fields=["id", "img_path", "card_name", "card_num", "lang", "source_id"],
    )

    # 5. Flatten the per-query hit lists. Local file paths are turned into
    #    static URLs the frontend can load.
    formatted_results = [
        {
            "id": hit.id,
            "score": hit.score,
            "card_name": hit.entity.get("card_name"),
            "card_num": hit.entity.get("card_num"),
            "lang": hit.entity.get("lang"),
            "img_url": _to_img_url(hit.entity.get("img_path")),
        }
        for hits in results
        for hit in hits
    ]
    return {"results": formatted_results}


@router.post("/state", summary="统计")
async def filter_database(
    card_name: Optional[str] = None,
    limit: int = Query(20, ge=1, le=_MILVUS_MAX_LIMIT),
):
    """Filter card metadata (optionally by card-name prefix) and count matches.
    \f
    Args:
        card_name: Optional prefix; rows whose ``card_name`` starts with it
            match. Empty or missing means no name filter.
        limit: Maximum number of rows returned in ``data``; must be in
            [1, 16384] (422 otherwise).

    Returns:
        ``{"total", "limit", "returned_count", "data"}`` on success. Each row
        in ``data`` has ``id``, ``card_name``, ``card_num``, ``lang`` and
        ``img_url`` (``img_path`` is replaced by ``img_url``).
        On a Milvus failure, ``{"error": <message>}`` is returned instead. This
        shape is kept for existing clients.
    """
    expr_list = []
    if card_name:
        # Prefix ("fuzzy") match. The value is escaped so that quotes or
        # backslashes in user input cannot break out of the string literal.
        expr_list.append(f'card_name like "{_escape_expr_string(card_name)}%"')
    # With no conditions, fall back to a predicate meant to match every row.
    # NOTE(review): rows with id <= 0, if any exist, would be excluded.
    expr = " && ".join(expr_list) or "id > 0"

    try:
        # Total number of matching rows; the result looks like [{'count(*)': 105}].
        count_res = milvus_collection.query(expr=expr, output_fields=["count(*)"])
        total_count = count_res[0]["count(*)"]

        # The actual rows, capped at `limit`.
        res = milvus_collection.query(
            expr=expr,
            output_fields=["id", "img_path", "card_name", "card_num", "lang"],
            limit=limit,
        )
    except Exception as e:
        # Request boundary: log with traceback and keep the original error shape.
        logger.exception("Query Error: %s", e)
        return {"error": str(e)}

    # Replace stored file paths with frontend-reachable static URLs.
    for item in res:
        item["img_url"] = _to_img_url(item.pop("img_path", ""))

    return {
        "total": total_count,
        "limit": limit,
        "returned_count": len(res),
        "data": res,
    }