score_inference.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException
  2. from fastapi.responses import FileResponse, JSONResponse
  3. from fastapi.concurrency import run_in_threadpool
  4. from enum import Enum
  5. from typing import Optional, Dict, Any
  6. from ..core.config import settings
  7. from app.services.score_service import ScoreService
  8. from app.services.stitch_fusion_service import StitchFusionService
  9. import numpy as np
  10. import cv2
  11. import json
  12. from app.core.logger import get_logger
  13. logger = get_logger(__name__)
  14. router = APIRouter()
  15. score_names = settings.SCORE_TYPE
  16. ScoreType = Enum("InferenceType", {name: name for name in score_names})
  17. stitch_score_names = settings.STITCH_SCORE_TYPE
  18. StitchScoreType = Enum("StitchScoreType", {name: name for name in stitch_score_names})
  19. @router.post("/score_inference", summary="输入卡片类型(正反面, 缺陷类型), 是否为反射卡")
  20. async def card_model_inference(
  21. score_type: ScoreType,
  22. is_reflect_card: bool = False,
  23. file: UploadFile = File(...)
  24. ):
  25. """
  26. 接收一张卡片图片,使用指定类型的模型进行推理,并返回JSON结果。
  27. - **inference_type**: 要使用的模型类型(从下拉列表中选择)。
  28. - **file**: 要上传的图片文件。
  29. """
  30. service = ScoreService()
  31. image_bytes = await file.read()
  32. # 将字节数据转换为numpy数组
  33. np_arr = np.frombuffer(image_bytes, np.uint8)
  34. # 从numpy数组中解码图像
  35. img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
  36. if img_bgr is None:
  37. raise ValueError("无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)")
  38. try:
  39. json_result = await run_in_threadpool(
  40. service.score_inference,
  41. score_type=score_type.value,
  42. is_reflect_card=is_reflect_card,
  43. img_bgr=img_bgr
  44. )
  45. return json_result
  46. except ValueError as e:
  47. raise HTTPException(status_code=400, detail=str(e))
  48. except Exception as e:
  49. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  50. @router.post("/score_recalculate", summary="输入卡片类型(正反面, 缺陷类型)",
  51. description="输入的json数据结构为 "
  52. "{'result': {'center_result':..., 'defect_result':...}}")
  53. async def score_recalculate(score_type: ScoreType, json_data: Dict[str, Any]):
  54. """
  55. 接收分数推理后的结果, 然后重新根据json数据计算居中和缺陷等分数
  56. """
  57. service = ScoreService()
  58. try:
  59. json_result = await run_in_threadpool(
  60. service.recalculate_defect_score,
  61. score_type=score_type.value,
  62. json_data=json_data
  63. )
  64. return json_result
  65. except ValueError as e:
  66. raise HTTPException(status_code=400, detail=str(e))
  67. except Exception as e:
  68. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
  69. @router.post("/stitch_score_inference",
  70. summary="StitchFusion 多模态拼接缺陷推理 (一次提交6张同侧图)",
  71. description="""
  72. 一次提交一组 6 张同侧图片, 经 StitchFusion 单 PT 模型大图切片推理后, 返回 labelme 风格 mask 列表。
  73. 字段对应 (按正反面分别上传, 一次只提交一侧):
  74. - ring -> 环光图 (front_ring / back_ring)
  75. - gray -> 灰度图 (front_gray / back_gray)
  76. - stripe1 -> 调光1 (front_stripe1 / back_stripe1)
  77. - stripe2 -> 调光2 (front_stripe2 / back_stripe2)
  78. - stripe3 -> 调光3 (front_stripe3 / back_stripe3)
  79. - stripe4 -> 调光4 (front_stripe4 / back_stripe4)
  80. """)
  81. async def stitch_score_inference(
  82. score_type: StitchScoreType,
  83. card_name: str = "",
  84. ring: UploadFile = File(..., description="环光图"),
  85. gray: UploadFile = File(..., description="灰度图"),
  86. stripe1: UploadFile = File(..., description="调光1"),
  87. stripe2: UploadFile = File(..., description="调光2"),
  88. stripe3: UploadFile = File(..., description="调光3"),
  89. stripe4: UploadFile = File(..., description="调光4"),
  90. ):
  91. """接收同一类型(正面或反面) 6 张图, 输出该组的缺陷 JSON。"""
  92. variant_files = {
  93. "ring": ring,
  94. "gray": gray,
  95. "stripe1": stripe1,
  96. "stripe2": stripe2,
  97. "stripe3": stripe3,
  98. "stripe4": stripe4,
  99. }
  100. variant_bytes: Dict[str, bytes] = {}
  101. for variant, upload in variant_files.items():
  102. try:
  103. variant_bytes[variant] = await upload.read()
  104. except Exception as e:
  105. raise HTTPException(status_code=400, detail=f"读取上传文件 {variant} 失败: {e}")
  106. if not variant_bytes[variant]:
  107. raise HTTPException(status_code=400, detail=f"上传文件 {variant} 内容为空")
  108. service = StitchFusionService()
  109. try:
  110. json_result = await run_in_threadpool(
  111. service.stitch_score_inference,
  112. score_type=score_type.value,
  113. card_name=card_name,
  114. variant_bytes=variant_bytes,
  115. )
  116. return json_result
  117. except FileNotFoundError as e:
  118. raise HTTPException(status_code=500, detail=str(e))
  119. except ValueError as e:
  120. raise HTTPException(status_code=400, detail=str(e))
  121. except Exception as e:
  122. logger.exception("stitch_score_inference 失败")
  123. raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")