search.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from fastapi import APIRouter, UploadFile, File, Form, HTTPException
  2. from typing import Optional, List
  3. import cv2
  4. import numpy as np
  5. from PIL import Image
  6. import io
  7. import shutil
  8. from app.core.models_loader import global_models
  9. from app.db.milvus_client import milvus_collection
  10. from app.core.config import settings
  11. router = APIRouter()
  12. @router.post("/image", summary="相似宝可梦图像搜索")
  13. async def search_image(file: UploadFile = File(...), top_k: int = 5):
  14. # 1. 读取图片
  15. content = await file.read()
  16. image = Image.open(io.BytesIO(content)).convert("RGB")
  17. image_np = np.array(image)
  18. # 2. YOLO 截取 (单张处理)
  19. # 注意:MyBatchOnnxYolo 需要 List 输入
  20. global_models.yolo.predict_batch([image_np])
  21. # 提取第0张图的 crop (默认 cls_id=0)
  22. cropped_img = global_models.yolo.get_max_img(0, cls_id=0)
  23. if cropped_img is None:
  24. # 如果没检测到,使用原图
  25. cropped_img = image_np
  26. # 3. ViT 提取特征
  27. # MyViTFeatureExtractor 接收 List[Union[str, np.ndarray, Image]]
  28. vectors = global_models.vit.run([cropped_img], normalize=True)
  29. query_vector = vectors[0].tolist()
  30. # 4. Milvus 搜索
  31. search_params = {"metric_type": "COSINE", "params": {"nprobe": 10}}
  32. results = milvus_collection.search(
  33. data=[query_vector],
  34. anns_field="vector",
  35. param=search_params,
  36. limit=top_k,
  37. output_fields=["id", "img_path", "card_name", "card_num", "lang", "source_id"]
  38. )
  39. # 5. 格式化结果
  40. formatted_results = []
  41. for hits in results:
  42. for hit in hits:
  43. # 将本地路径转换为相对 URL 以便前端访问
  44. full_path = hit.entity.get("img_path")
  45. web_path = f"/static/images/{full_path.split('/')[-1]}" if full_path else ""
  46. formatted_results.append({
  47. "id": hit.id,
  48. "score": hit.score,
  49. "card_name": hit.entity.get("card_name"),
  50. "card_num": hit.entity.get("card_num"),
  51. "lang": hit.entity.get("lang"),
  52. "img_url": web_path
  53. })
  54. return {"results": formatted_results}
  55. @router.post("/filter")
  56. async def filter_database(
  57. card_name: Optional[str] = None,
  58. source_id: Optional[str] = None,
  59. limit: int = 20
  60. ):
  61. """数据库元数据过滤"""
  62. expr_list = []
  63. if card_name:
  64. # 模糊查询
  65. expr_list.append(f'card_name like "{card_name}%"')
  66. if source_id:
  67. expr_list.append(f'source_id == "{source_id}"')
  68. expr = " && ".join(expr_list) if expr_list else ""
  69. # 如果没有条件,查询所有(慎用,有限制)
  70. if not expr:
  71. expr = "id > 0"
  72. res = milvus_collection.query(
  73. expr=expr,
  74. output_fields=["id", "img_path", "card_name", "card_num", "lang"],
  75. limit=limit
  76. )
  77. # 格式化路径
  78. for item in res:
  79. full_path = item.pop("img_path", "")
  80. item["img_url"] = f"/static/images/{full_path.split('/')[-1]}" if full_path else ""
  81. return {"count": len(res), "data": res}