| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- import json
- import requests
- from typing import Optional, Dict, Any, List
- from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Body
- from fastapi.concurrency import run_in_threadpool
- from mysql.connector.pooling import PooledMySQLConnection
- from app.core.config import settings
- from app.core.logger import get_logger
- from app.core.database_loader import get_db_connection
- from app.crud import crud_card
- from app.utils.scheme import IMAGE_TYPE_TO_SCORE_TYPE
- from app.utils.labelme_process import convert_internal_to_labelme, convert_labelme_to_internal
- logger = get_logger(__name__)
- router = APIRouter()
- db_dependency = Depends(get_db_connection)
- @router.get("/export/{image_id}", summary="获取指定图片的 LabelMe 格式 JSON")
- def export_labelme_json(image_id: int, db_conn: PooledMySQLConnection = db_dependency):
- """
- 获取图片的 JSON 数据并转换为 LabelMe 格式。
- 优先读取 modified_json,如果没有则读取 detection_json。
- """
- cursor = None
- try:
- cursor = db_conn.cursor(dictionary=True)
- # 查询图片路径和JSON
- query = "SELECT image_path, detection_json, modified_json FROM card_images WHERE id = %s"
- cursor.execute(query, (image_id,))
- row = cursor.fetchone()
- if not row:
- raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到")
- # 优先使用已修改的数据
- source_json_str = row['modified_json'] if row['modified_json'] else row['detection_json']
- if isinstance(source_json_str, str):
- source_json = json.loads(source_json_str)
- else:
- source_json = source_json_str # 已经是 dict (如果是从 Pydantic 模型来的话,但在原生 cursor 里通常是 str 或 dict)
- if source_json is None: source_json = {}
- image_path = row['image_path']
- # 转换
- labelme_data = convert_internal_to_labelme(image_path, source_json)
- return labelme_data
- except Exception as e:
- logger.error(f"导出 LabelMe JSON 失败 (id={image_id}): {e}")
- raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
- finally:
- if cursor:
- cursor.close()
- @router.put("/import/{image_id}", summary="接收 LabelMe JSON,重计算分数并保存")
- async def import_labelme_json(
- image_id: int,
- labelme_data: Dict[str, Any] = Body(..., description="LabelMe 格式的 JSON 数据"),
- db_conn: PooledMySQLConnection = db_dependency
- ):
- """
- 1. 接收 LabelMe JSON。
- 2. 转换为系统内部 JSON 格式。
- 3. 调用项目1的 API 进行重计算 (re-inference & score)。
- 4. 更新数据库 (modified_json)。
- """
- cursor = None
- try:
- cursor = db_conn.cursor(dictionary=True)
- # 1. 获取图片基础信息 (用于获取 image_type 和原始 JSON 结构参考)
- query = "SELECT card_id, image_type, detection_json FROM card_images WHERE id = %s"
- cursor.execute(query, (image_id,))
- row = cursor.fetchone()
- if not row:
- raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到")
- card_id = row['card_id']
- image_type = row['image_type']
- # 解析原始 JSON 用于辅助转换 (例如获取宽高作为兜底)
- detection_json = row['detection_json']
- if isinstance(detection_json, str):
- detection_json = json.loads(detection_json)
- # 2. 转换格式: LabelMe -> Internal
- # 注意:这里生成的 JSON 只有 points 和 labels,面积和分数需要服务端计算
- internal_payload = convert_labelme_to_internal(detection_json, labelme_data)
- # 确定 score_type
- score_type = IMAGE_TYPE_TO_SCORE_TYPE.get(image_type)
- if not score_type:
- raise HTTPException(status_code=400, detail=f"不支持的图片类型: {image_type}")
- logger.info(f"正在调用计算服务: {settings.SCORE_RECALCULATE_ENDPOINT}, type={score_type}")
- # 3. 调用项目1的分数重计算接口
- try:
- response = await run_in_threadpool(
- lambda: requests.post(
- settings.SCORE_RECALCULATE_ENDPOINT,
- params={"score_type": score_type},
- json=internal_payload,
- timeout=30 # 稍微增加超时时间,因为可能涉及推理
- )
- )
- except Exception as e:
- logger.error(f"连接分数计算服务失败: {e}")
- raise HTTPException(status_code=502, detail="无法连接到分数计算服务")
- if response.status_code != 200:
- logger.error(f"分数计算服务返回错误: {response.text}")
- raise HTTPException(status_code=response.status_code, detail=f"分数计算失败: {response.text}")
- recalculated_json = response.json()
- logger.info("分数重计算完成")
- # 4. 更新数据库
- recalculated_json_str = json.dumps(recalculated_json, ensure_ascii=False)
- update_query = (
- "UPDATE card_images "
- "SET modified_json = %s, is_edited = TRUE "
- "WHERE id = %s"
- )
- cursor.execute(update_query, (recalculated_json_str, image_id))
- db_conn.commit()
- # 5. 更新 Card 维度的总分状态
- try:
- crud_card.update_card_scores_and_status(db_conn, card_id)
- except Exception as e:
- logger.error(f"更新卡牌总分失败: {e}")
- # 不阻断主流程
- return {
- "success": True,
- "message": "LabelMe 数据导入并重计算成功",
- "image_id": image_id
- }
- except HTTPException as he:
- raise he
- except Exception as e:
- db_conn.rollback()
- logger.error(f"导入 LabelMe 数据失败: {e}")
- raise HTTPException(status_code=500, detail=f"系统内部错误: {str(e)}")
- finally:
- if cursor:
- cursor.close()
|