card_inference.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Path
  2. from fastapi.responses import FileResponse, JSONResponse
  3. from fastapi.concurrency import run_in_threadpool
  4. from enum import Enum
  5. from ..core.config import settings
  6. from app.services.card_service import CardInferenceService, card_service
  7. from app.services.defect_service import DefectInferenceService, defect_service
  8. from app.core.logger import logger
  9. import json
  10. router = APIRouter()
  11. model_names = list(settings.CARD_MODELS_CONFIG.keys())
  12. defect_names = list(settings.DEFECT_TYPE.keys())
  13. InferenceType = Enum("InferenceType", {name: name for name in model_names})
  14. DefectType = Enum("InferenceType", {name: name for name in defect_names})
  15. @router.post("/model_inference", description="内外框类型输入大图, 其他输入小图")
  16. async def card_model_inference(
  17. inference_type: InferenceType,
  18. service: CardInferenceService = Depends(lambda: card_service),
  19. file: UploadFile = File(...)
  20. ):
  21. """
  22. 接收一张卡片图片,使用指定类型的模型进行推理,并返回JSON结果。
  23. - **inference_type**: 要使用的模型类型(从下拉列表中选择)。
  24. - **file**: 要上传的图片文件。
  25. """
  26. image_bytes = await file.read()
  27. try:
  28. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  29. json_result = await run_in_threadpool(
  30. service.predict,
  31. inference_type=inference_type.value, # 使用 .value
  32. image_bytes=image_bytes
  33. )
  34. return json_result
  35. except ValueError as e:
  36. raise HTTPException(status_code=400, detail=str(e))
  37. except Exception as e:
  38. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  39. @router.post("/defect_inference",
  40. description="环形光居中计算, 环形光正反边角缺陷, 同轴光正反表面缺陷")
  41. async def card_model_inference(
  42. defect_type: DefectType,
  43. service: DefectInferenceService = Depends(lambda: defect_service),
  44. file: UploadFile = File(...),
  45. is_draw_image: bool = False,
  46. ):
  47. image_bytes = await file.read()
  48. try:
  49. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  50. json_result = await run_in_threadpool(
  51. service.defect_inference,
  52. inference_type=defect_type.value,
  53. image_bytes=image_bytes,
  54. is_draw_image=is_draw_image
  55. )
  56. return json_result
  57. except ValueError as e:
  58. logger.error(e)
  59. raise HTTPException(status_code=400, detail=str(e))
  60. except Exception as e:
  61. logger.error(e)
  62. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  63. @router.post("/mock_query")
  64. async def mock_query(img_id: int):
  65. # json_data = {"img_id": img_id}
  66. with open("_temp_work/mock_result.json", "r") as f:
  67. json_data = json.load(f)
  68. return json_data