| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- from fastapi import APIRouter, UploadFile, File, Form, HTTPException
- from typing import Optional, List
- import cv2
- import numpy as np
- from PIL import Image
- import io
- import shutil
- from app.core.models_loader import global_models
- from app.db.milvus_client import milvus_collection
- from app.core.config import settings
# Router that groups the search endpoints under the "/search" URL prefix.
router = APIRouter(
    prefix="/search",
    tags=["Search"],
)
@router.post("/image", summary="相似宝可梦图像搜索")
async def search_image(file: UploadFile = File(...), top_k: int = 5):
    """Search the Milvus collection for cards similar to an uploaded image.

    Pipeline: decode the upload, crop it with the YOLO model, embed the crop
    with the ViT extractor, then run a cosine vector search in Milvus.

    Args:
        file: Uploaded image file.
        top_k: Maximum number of hits to return; must be at least 1.

    Returns:
        ``{"results": [...]}`` where each entry carries ``id``, ``score``,
        ``card_name``, ``card_num``, ``lang`` and ``img_url``.

    Raises:
        HTTPException: 400 when ``top_k`` is below 1 or the upload cannot be
            decoded as an image.
    """
    # Reject a useless limit up front instead of letting the search call fail.
    if top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be >= 1")

    # 1. Read and decode the upload. Empty or corrupt payloads are a client
    #    error, so report them as 400 rather than an unhandled exception.
    content = await file.read()
    try:
        image = Image.open(io.BytesIO(content)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc
    image_np = np.array(image)

    # 2. YOLO crop (single image). The batch predictor expects a list input.
    #    NOTE(review): the return value of predict_batch is unused and the
    #    crop is fetched afterwards by image index, so the detector presumably
    #    keeps the last batch's results internally - keep these two calls
    #    back-to-back with no await in between; confirm against the model
    #    wrapper before moving this work to a thread pool.
    global_models.yolo.predict_batch([image_np])
    cropped_img = global_models.yolo.get_max_img(0, cls_id=0)
    if cropped_img is None:
        # Nothing detected: fall back to the full image.
        cropped_img = image_np

    # 3. ViT feature extraction; the extractor takes a list of images.
    vectors = global_models.vit.run([cropped_img], normalize=True)
    query_vector = vectors[0].tolist()

    # 4. Milvus vector search.
    search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
    results = milvus_collection.search(
        data=[query_vector],
        anns_field="vector",
        param=search_params,
        limit=top_k,
        output_fields=["id", "img_path", "card_name", "card_num", "lang", "source_id"],
    )

    # 5. Flatten the hits into plain dicts for the response.
    formatted_results = [
        _format_search_hit(hit)
        for hits in results
        for hit in hits
    ]
    return {"results": formatted_results}


def _format_search_hit(hit):
    """Convert one Milvus search hit into the response dict for the client."""
    # Keep only the file name of the stored path so the frontend can load the
    # image from the static route.
    full_path = hit.entity.get("img_path")
    web_path = f"/static/images/{full_path.split('/')[-1]}" if full_path else ""
    return {
        "id": hit.id,
        "score": hit.score,
        "card_name": hit.entity.get("card_name"),
        "card_num": hit.entity.get("card_num"),
        "lang": hit.entity.get("lang"),
        "img_url": web_path,
    }
def _escape_expr_string(value: str) -> str:
    """Escape *value* for use inside a double-quoted Milvus expression literal.

    Backslashes and double quotes are backslash-escaped so caller-supplied
    text cannot terminate the literal and append its own conditions.
    """
    return value.replace("\\", "\\\\").replace('"', '\\"')


@router.post("/filter")
async def filter_database(
    card_name: Optional[str] = None,
    source_id: Optional[str] = None,
    limit: int = 20
):
    """Filter the Milvus collection by card metadata.

    Args:
        card_name: When given, matches records whose ``card_name`` starts
            with this text (``like "<card_name>%"``).
        source_id: When given, matches records whose ``source_id`` equals
            this text exactly.
        limit: Maximum number of records to return; must be at least 1.

    Returns:
        ``{"count": <number of records>, "data": [...]}`` where each record's
        ``img_path`` has been replaced by an ``img_url`` under
        ``/static/images/``.

    Raises:
        HTTPException: 400 when ``limit`` is below 1.
    """
    if limit < 1:
        raise HTTPException(status_code=400, detail="limit must be >= 1")

    expr_list = []
    if card_name:
        # Prefix match. The value is escaped because it comes straight from
        # the request and is spliced into the filter expression.
        expr_list.append(f'card_name like "{_escape_expr_string(card_name)}%"')
    if source_id:
        expr_list.append(f'source_id == "{_escape_expr_string(source_id)}"')

    # With no conditions, fall back to an expression intended to match every
    # record (use with care; the result is still capped by ``limit``).
    expr = " && ".join(expr_list) if expr_list else "id > 0"

    res = milvus_collection.query(
        expr=expr,
        output_fields=["id", "img_path", "card_name", "card_num", "lang"],
        limit=limit
    )

    # Replace the stored path with a URL built from its file name.
    for item in res:
        full_path = item.pop("img_path", "")
        item["img_url"] = f"/static/images/{full_path.split('/')[-1]}" if full_path else ""
    return {"count": len(res), "data": res}
|