card_inference.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Path, Response
  2. from fastapi.responses import FileResponse, JSONResponse
  3. from fastapi.concurrency import run_in_threadpool
  4. from enum import Enum
  5. from ..core.config import settings
  6. from app.services.card_rectify_and_center import CardRectifyAndCenter
  7. from app.services.card_service import CardInferenceService, card_service
  8. from app.services.defect_service import DefectInferenceService
  9. from app.core.logger import get_logger
  10. import cv2
  11. import numpy as np
  12. import json
  13. logger = get_logger(__name__)
  14. router = APIRouter()
  15. model_names = list(settings.CARD_MODELS_CONFIG.keys())
  16. defect_names = list(settings.DEFECT_TYPE.keys())
  17. InferenceType = Enum("InferenceType", {name: name for name in model_names})
  18. DefectType = Enum("InferenceType", {name: name for name in defect_names})
  19. @router.post("/model_inference", description="内外框类型输入大图, 其他输入小图")
  20. async def card_model_inference(
  21. inference_type: InferenceType,
  22. service: CardInferenceService = Depends(lambda: card_service),
  23. file: UploadFile = File(...)
  24. ):
  25. """
  26. 接收一张卡片图片,使用指定类型的模型进行推理,并返回JSON结果。
  27. - **inference_type**: 要使用的模型类型(从下拉列表中选择)。
  28. - **file**: 要上传的图片文件。
  29. """
  30. image_bytes = await file.read()
  31. try:
  32. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  33. json_result = await run_in_threadpool(
  34. service.predict,
  35. inference_type=inference_type.value, # 使用 .value
  36. image_bytes=image_bytes
  37. )
  38. return json_result
  39. except ValueError as e:
  40. raise HTTPException(status_code=400, detail=str(e))
  41. except Exception as e:
  42. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  43. @router.post("/defect_inference",
  44. description="环形光居中计算, 环形光正反边角缺陷, 同轴光正反表面缺陷")
  45. async def card_model_inference(
  46. defect_type: DefectType,
  47. # service: DefectInferenceService = Depends(lambda: defect_service),
  48. file: UploadFile = File(...),
  49. is_draw_image: bool = False,
  50. ):
  51. service = DefectInferenceService()
  52. image_bytes = await file.read()
  53. # 将字节数据转换为numpy数组
  54. np_arr = np.frombuffer(image_bytes, np.uint8)
  55. # 从numpy数组中解码图像
  56. img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
  57. if img_bgr is None:
  58. raise ValueError("无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)")
  59. try:
  60. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  61. json_result = await run_in_threadpool(
  62. service.defect_inference,
  63. inference_type=defect_type.value,
  64. img_bgr=img_bgr,
  65. is_draw_image=is_draw_image
  66. )
  67. return json_result
  68. except ValueError as e:
  69. logger.error(e)
  70. raise HTTPException(status_code=400, detail=str(e))
  71. except Exception as e:
  72. logger.error(e)
  73. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  74. @router.post("/card_rectify_and_center",
  75. description="对卡片图像进行转正和居中处理")
  76. async def card_rectify_and_center(
  77. file: UploadFile = File(...)
  78. ):
  79. service = CardRectifyAndCenter()
  80. image_bytes = await file.read()
  81. # 将字节数据转换为numpy数组
  82. np_arr = np.frombuffer(image_bytes, np.uint8)
  83. # 从numpy数组中解码图像
  84. img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
  85. if img_bgr is None:
  86. raise ValueError("无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)")
  87. try:
  88. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  89. img_result = await run_in_threadpool(
  90. service.rectify_and_center,
  91. img_bgr=img_bgr
  92. )
  93. is_success, buffer = cv2.imencode(".jpg", img_result)
  94. jpeg_bytes = buffer.tobytes()
  95. return Response(content=jpeg_bytes, media_type="image/jpeg")
  96. except ValueError as e:
  97. raise HTTPException(status_code=400, detail=str(e))
  98. except Exception as e:
  99. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")