card_inference.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Path, Response
  2. from fastapi.responses import FileResponse, JSONResponse
  3. from fastapi.concurrency import run_in_threadpool
  4. from enum import Enum
  5. from ..core.config import settings
  6. from app.services.card_rectify_and_center import CardRectifyAndCenter
  7. from app.services.card_service import CardInferenceService, card_service
  8. from app.services.defect_service import DefectInferenceService
  9. from app.core.logger import get_logger
  10. import cv2
  11. import numpy as np
  12. import json
  13. logger = get_logger(__name__)
  14. router = APIRouter()
  15. model_names = list(settings.CARD_MODELS_CONFIG.keys())
  16. defect_names = list(settings.DEFECT_TYPE.keys())
  17. InferenceType = Enum("InferenceType", {name: name for name in model_names})
  18. DefectType = Enum("InferenceType", {name: name for name in defect_names})
  19. @router.post("/model_inference", description="内外框类型输入大图, 其他输入小图")
  20. async def card_model_inference(
  21. inference_type: InferenceType,
  22. service: CardInferenceService = Depends(lambda: card_service),
  23. file: UploadFile = File(...)
  24. ):
  25. """
  26. 接收一张卡片图片,使用指定类型的模型进行推理,并返回JSON结果。
  27. - **inference_type**: 要使用的模型类型(从下拉列表中选择)。
  28. - **file**: 要上传的图片文件。
  29. """
  30. image_bytes = await file.read()
  31. try:
  32. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  33. json_result = await run_in_threadpool(
  34. service.predict,
  35. inference_type=inference_type.value, # 使用 .value
  36. image_bytes=image_bytes
  37. )
  38. return json_result
  39. except ValueError as e:
  40. raise HTTPException(status_code=400, detail=str(e))
  41. except Exception as e:
  42. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  43. @router.post("/defect_inference",
  44. description="环形光居中计算, 环形光正反边角缺陷, 同轴光正反表面缺陷")
  45. async def card_model_inference(
  46. defect_type: DefectType,
  47. file: UploadFile = File(...),
  48. is_draw_image: bool = False,
  49. ):
  50. service = DefectInferenceService()
  51. image_bytes = await file.read()
  52. # 将字节数据转换为numpy数组
  53. np_arr = np.frombuffer(image_bytes, np.uint8)
  54. # 从numpy数组中解码图像
  55. img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
  56. if img_bgr is None:
  57. raise ValueError("无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)")
  58. try:
  59. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  60. json_result = await run_in_threadpool(
  61. service.defect_inference,
  62. inference_type=defect_type.value,
  63. img_bgr=img_bgr,
  64. is_draw_image=is_draw_image
  65. )
  66. return json_result
  67. except ValueError as e:
  68. logger.error(e)
  69. raise HTTPException(status_code=400, detail=str(e))
  70. except Exception as e:
  71. logger.error(e)
  72. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  73. @router.post("/card_rectify_and_center",
  74. description="对卡片图像进行转正和居中处理")
  75. async def card_rectify_and_center(
  76. file: UploadFile = File(...)
  77. ):
  78. service = CardRectifyAndCenter()
  79. image_bytes = await file.read()
  80. # 将字节数据转换为numpy数组
  81. np_arr = np.frombuffer(image_bytes, np.uint8)
  82. # 从numpy数组中解码图像
  83. img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
  84. if img_bgr is None:
  85. raise ValueError("无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)")
  86. try:
  87. # 3. 传递参数时,使用 .value 获取 Enum 的字符串值
  88. img_result = await run_in_threadpool(
  89. service.rectify_and_center,
  90. img_bgr=img_bgr
  91. )
  92. is_success, buffer = cv2.imencode(".jpg", img_result)
  93. jpeg_bytes = buffer.tobytes()
  94. return Response(content=jpeg_bytes, media_type="image/jpeg")
  95. except ValueError as e:
  96. raise HTTPException(status_code=400, detail=str(e))
  97. except Exception as e:
  98. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")