from fastapi import APIRouter, File, UploadFile, Depends, HTTPException from fastapi.responses import FileResponse, JSONResponse from fastapi.concurrency import run_in_threadpool from enum import Enum from typing import Optional, Dict, Any from ..core.config import settings from app.services.score_service import ScoreService from app.services.stitch_fusion_service import StitchFusionService import numpy as np import cv2 import json from app.core.logger import get_logger logger = get_logger(__name__) router = APIRouter() score_names = settings.SCORE_TYPE ScoreType = Enum("InferenceType", {name: name for name in score_names}) stitch_score_names = settings.STITCH_SCORE_TYPE StitchScoreType = Enum("StitchScoreType", {name: name for name in stitch_score_names}) @router.post("/score_inference", summary="输入卡片类型(正反面, 缺陷类型), 是否为反射卡") async def card_model_inference( score_type: ScoreType, is_reflect_card: bool = False, file: UploadFile = File(...) ): """ 接收一张卡片图片,使用指定类型的模型进行推理,并返回JSON结果。 - **inference_type**: 要使用的模型类型(从下拉列表中选择)。 - **file**: 要上传的图片文件。 """ service = ScoreService() image_bytes = await file.read() # 将字节数据转换为numpy数组 np_arr = np.frombuffer(image_bytes, np.uint8) # 从numpy数组中解码图像 img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) if img_bgr is None: raise ValueError("无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)") try: json_result = await run_in_threadpool( service.score_inference, score_type=score_type.value, is_reflect_card=is_reflect_card, img_bgr=img_bgr ) return json_result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}") @router.post("/score_recalculate", summary="输入卡片类型(正反面, 缺陷类型)", description="输入的json数据结构为 " "{'result': {'center_result':..., 'defect_result':...}}") async def score_recalculate(score_type: ScoreType, json_data: Dict[str, Any]): """ 接收分数推理后的结果, 然后重新根据json数据计算居中和缺陷等分数 """ service = ScoreService() try: json_result = await run_in_threadpool( service.recalculate_defect_score, score_type=score_type.value, json_data=json_data ) return json_result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}") @router.post("/stitch_score_inference", summary="StitchFusion 多模态拼接缺陷推理 (一次提交6张同侧图)", description=""" 一次提交一组 6 张同侧图片, 经 StitchFusion 单 PT 模型大图切片推理后, 返回 labelme 风格 mask 列表。 字段对应 (按正反面分别上传, 一次只提交一侧): - ring -> 环光图 (front_ring / back_ring) - gray -> 灰度图 (front_gray / back_gray) - stripe1 -> 调光1 (front_stripe1 / back_stripe1) - stripe2 -> 调光2 (front_stripe2 / back_stripe2) - stripe3 -> 调光3 (front_stripe3 / back_stripe3) - stripe4 -> 调光4 (front_stripe4 / back_stripe4) """) async def stitch_score_inference( score_type: StitchScoreType, card_name: str = "", ring: UploadFile = File(..., description="环光图"), gray: UploadFile = File(..., description="灰度图"), stripe1: UploadFile = File(..., description="调光1"), stripe2: UploadFile = File(..., description="调光2"), stripe3: UploadFile = File(..., description="调光3"), stripe4: UploadFile = File(..., description="调光4"), ): """接收同一类型(正面或反面) 6 张图, 输出该组的缺陷 JSON。""" variant_files = { "ring": ring, "gray": gray, "stripe1": stripe1, "stripe2": stripe2, "stripe3": stripe3, "stripe4": stripe4, } variant_bytes: Dict[str, bytes] = {} for variant, upload in variant_files.items(): try: variant_bytes[variant] = await upload.read() except Exception as e: raise HTTPException(status_code=400, detail=f"读取上传文件 {variant} 失败: {e}") if not variant_bytes[variant]: raise HTTPException(status_code=400, detail=f"上传文件 {variant} 内容为空") service = StitchFusionService() try: json_result = await run_in_threadpool( service.stitch_score_inference, score_type=score_type.value, card_name=card_name, variant_bytes=variant_bytes, ) return json_result except FileNotFoundError as e: raise HTTPException(status_code=500, detail=str(e)) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.exception("stitch_score_inference 失败") raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")