import json import requests from typing import Optional, Dict, Any, List from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Body from fastapi.concurrency import run_in_threadpool from mysql.connector.pooling import PooledMySQLConnection from app.core.config import settings from app.core.logger import get_logger from app.core.database_loader import get_db_connection from app.crud import crud_card from app.utils.scheme import IMAGE_TYPE_TO_SCORE_TYPE from app.utils.labelme_process import convert_internal_to_labelme, convert_labelme_to_internal logger = get_logger(__name__) router = APIRouter() db_dependency = Depends(get_db_connection) @router.get("/export/{image_id}", summary="获取指定图片的 LabelMe 格式 JSON") def export_labelme_json(image_id: int, db_conn: PooledMySQLConnection = db_dependency): """ 获取图片的 JSON 数据并转换为 LabelMe 格式。 优先读取 modified_json,如果没有则读取 detection_json。 """ cursor = None try: cursor = db_conn.cursor(dictionary=True) # 查询图片路径和JSON query = "SELECT image_path, detection_json, modified_json FROM card_images WHERE id = %s" cursor.execute(query, (image_id,)) row = cursor.fetchone() if not row: raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到") # 优先使用已修改的数据 source_json_str = row['modified_json'] if row['modified_json'] else row['detection_json'] if isinstance(source_json_str, str): source_json = json.loads(source_json_str) else: source_json = source_json_str # 已经是 dict (如果是从 Pydantic 模型来的话,但在原生 cursor 里通常是 str 或 dict) if source_json is None: source_json = {} image_path = row['image_path'] # 转换 labelme_data = convert_internal_to_labelme(image_path, source_json) return labelme_data except Exception as e: logger.error(f"导出 LabelMe JSON 失败 (id={image_id}): {e}") raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}") finally: if cursor: cursor.close() @router.put("/import/{image_id}", summary="接收 LabelMe JSON,重计算分数并保存") async def import_labelme_json( image_id: int, labelme_data: Dict[str, Any] = Body(..., description="LabelMe 格式的 JSON 数据"), db_conn: PooledMySQLConnection = db_dependency ): """ 1. 接收 LabelMe JSON。 2. 转换为系统内部 JSON 格式。 3. 调用项目1的 API 进行重计算 (re-inference & score)。 4. 更新数据库 (modified_json)。 """ cursor = None try: cursor = db_conn.cursor(dictionary=True) # 1. 获取图片基础信息 (用于获取 image_type 和原始 JSON 结构参考) query = "SELECT card_id, image_type, detection_json FROM card_images WHERE id = %s" cursor.execute(query, (image_id,)) row = cursor.fetchone() if not row: raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到") card_id = row['card_id'] image_type = row['image_type'] # 解析原始 JSON 用于辅助转换 (例如获取宽高作为兜底) detection_json = row['detection_json'] if isinstance(detection_json, str): detection_json = json.loads(detection_json) # 2. 转换格式: LabelMe -> Internal # 注意:这里生成的 JSON 只有 points 和 labels,面积和分数需要服务端计算 internal_payload = convert_labelme_to_internal(detection_json, labelme_data) # 确定 score_type score_type = IMAGE_TYPE_TO_SCORE_TYPE.get(image_type) if not score_type: raise HTTPException(status_code=400, detail=f"不支持的图片类型: {image_type}") logger.info(f"正在调用计算服务: {settings.SCORE_RECALCULATE_ENDPOINT}, type={score_type}") # 3. 调用项目1的分数重计算接口 try: response = await run_in_threadpool( lambda: requests.post( settings.SCORE_RECALCULATE_ENDPOINT, params={"score_type": score_type}, json=internal_payload, timeout=30 # 稍微增加超时时间,因为可能涉及推理 ) ) except Exception as e: logger.error(f"连接分数计算服务失败: {e}") raise HTTPException(status_code=502, detail="无法连接到分数计算服务") if response.status_code != 200: logger.error(f"分数计算服务返回错误: {response.text}") raise HTTPException(status_code=response.status_code, detail=f"分数计算失败: {response.text}") recalculated_json = response.json() logger.info("分数重计算完成") # 4. 更新数据库 recalculated_json_str = json.dumps(recalculated_json, ensure_ascii=False) update_query = ( "UPDATE card_images " "SET modified_json = %s, is_edited = TRUE " "WHERE id = %s" ) cursor.execute(update_query, (recalculated_json_str, image_id)) db_conn.commit() # 5. 更新 Card 维度的总分状态 try: crud_card.update_card_scores_and_status(db_conn, card_id) except Exception as e: logger.error(f"更新卡牌总分失败: {e}") # 不阻断主流程 return { "success": True, "message": "LabelMe 数据导入并重计算成功", "image_id": image_id } except HTTPException as he: raise he except Exception as e: db_conn.rollback() logger.error(f"导入 LabelMe 数据失败: {e}") raise HTTPException(status_code=500, detail=f"系统内部错误: {str(e)}") finally: if cursor: cursor.close()