score_inference.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException
  2. from fastapi.responses import FileResponse, JSONResponse
  3. from fastapi.concurrency import run_in_threadpool
  4. from enum import Enum
  5. from typing import Optional, Dict, Any
  6. from ..core.config import settings
  7. from app.services.score_service import ScoreService
  8. import numpy as np
  9. import cv2
  10. import json
  11. from app.core.logger import get_logger
  12. logger = get_logger(__name__)
  13. router = APIRouter()
  14. score_names = settings.SCORE_TYPE
  15. ScoreType = Enum("InferenceType", {name: name for name in score_names})
  16. @router.post("/score_inference", summary="输入卡片类型(正反面, 缺陷类型), 是否为反射卡")
  17. async def card_model_inference(
  18. score_type: ScoreType,
  19. is_reflect_card: bool = False,
  20. file: UploadFile = File(...)
  21. ):
  22. """
  23. 接收一张卡片图片,使用指定类型的模型进行推理,并返回JSON结果。
  24. - **inference_type**: 要使用的模型类型(从下拉列表中选择)。
  25. - **file**: 要上传的图片文件。
  26. """
  27. service = ScoreService()
  28. image_bytes = await file.read()
  29. # 将字节数据转换为numpy数组
  30. np_arr = np.frombuffer(image_bytes, np.uint8)
  31. # 从numpy数组中解码图像
  32. img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
  33. if img_bgr is None:
  34. raise ValueError("无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)")
  35. try:
  36. json_result = await run_in_threadpool(
  37. service.score_inference,
  38. score_type=score_type.value,
  39. is_reflect_card=is_reflect_card,
  40. img_bgr=img_bgr
  41. )
  42. return json_result
  43. except ValueError as e:
  44. raise HTTPException(status_code=400, detail=str(e))
  45. except Exception as e:
  46. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  47. @router.post("/score_recalculate", summary="输入卡片类型(正反面, 缺陷类型)",
  48. description="输入的json数据结构为 "
  49. "{'result': {'center_result':..., 'defect_result':...}}")
  50. async def score_recalculate(score_type: ScoreType, json_data: Dict[str, Any]):
  51. """
  52. 接收分数推理后的结果, 然后重新根据json数据计算居中和缺陷等分数
  53. """
  54. service = ScoreService()
  55. try:
  56. json_result = await run_in_threadpool(
  57. service.recalculate_defect_score,
  58. score_type=score_type.value,
  59. json_data=json_data
  60. )
  61. return json_result
  62. except ValueError as e:
  63. raise HTTPException(status_code=400, detail=str(e))
  64. except Exception as e:
  65. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")