|
@@ -1,156 +0,0 @@
|
|
|
-import json
|
|
|
|
|
-import requests
|
|
|
|
|
-from typing import Optional, Dict, Any, List
|
|
|
|
|
-from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Body
|
|
|
|
|
-from fastapi.concurrency import run_in_threadpool
|
|
|
|
|
-from mysql.connector.pooling import PooledMySQLConnection
|
|
|
|
|
-
|
|
|
|
|
-from app.core.config import settings
|
|
|
|
|
-from app.core.logger import get_logger
|
|
|
|
|
-from app.core.database_loader import get_db_connection
|
|
|
|
|
-from app.crud import crud_card
|
|
|
|
|
-from app.utils.scheme import IMAGE_TYPE_TO_SCORE_TYPE
|
|
|
|
|
-from app.utils.labelme_process import convert_internal_to_labelme, convert_labelme_to_internal
|
|
|
|
|
-
|
|
|
|
|
-logger = get_logger(__name__)
|
|
|
|
|
-router = APIRouter()
|
|
|
|
|
-db_dependency = Depends(get_db_connection)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-@router.get("/export/{image_id}", summary="获取指定图片的 LabelMe 格式 JSON")
|
|
|
|
|
-def export_labelme_json(image_id: int, db_conn: PooledMySQLConnection = db_dependency):
|
|
|
|
|
- """
|
|
|
|
|
- 获取图片的 JSON 数据并转换为 LabelMe 格式。
|
|
|
|
|
- 优先读取 modified_json,如果没有则读取 detection_json。
|
|
|
|
|
- """
|
|
|
|
|
- cursor = None
|
|
|
|
|
- try:
|
|
|
|
|
- cursor = db_conn.cursor(dictionary=True)
|
|
|
|
|
- # 查询图片路径和JSON
|
|
|
|
|
- query = "SELECT image_path, detection_json, modified_json FROM card_images WHERE id = %s"
|
|
|
|
|
- cursor.execute(query, (image_id,))
|
|
|
|
|
- row = cursor.fetchone()
|
|
|
|
|
-
|
|
|
|
|
- if not row:
|
|
|
|
|
- raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到")
|
|
|
|
|
-
|
|
|
|
|
- # 优先使用已修改的数据
|
|
|
|
|
- source_json_str = row['modified_json'] if row['modified_json'] else row['detection_json']
|
|
|
|
|
-
|
|
|
|
|
- if isinstance(source_json_str, str):
|
|
|
|
|
- source_json = json.loads(source_json_str)
|
|
|
|
|
- else:
|
|
|
|
|
- source_json = source_json_str # 已经是 dict (如果是从 Pydantic 模型来的话,但在原生 cursor 里通常是 str 或 dict)
|
|
|
|
|
- if source_json is None: source_json = {}
|
|
|
|
|
-
|
|
|
|
|
- image_path = row['image_path']
|
|
|
|
|
-
|
|
|
|
|
- # 转换
|
|
|
|
|
- labelme_data = convert_internal_to_labelme(image_path, source_json)
|
|
|
|
|
-
|
|
|
|
|
- return labelme_data
|
|
|
|
|
-
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"导出 LabelMe JSON 失败 (id={image_id}): {e}")
|
|
|
|
|
- raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
|
|
|
|
- finally:
|
|
|
|
|
- if cursor:
|
|
|
|
|
- cursor.close()
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-@router.put("/import/{image_id}", summary="接收 LabelMe JSON,重计算分数并保存")
|
|
|
|
|
-async def import_labelme_json(
|
|
|
|
|
- image_id: int,
|
|
|
|
|
- labelme_data: Dict[str, Any] = Body(..., description="LabelMe 格式的 JSON 数据"),
|
|
|
|
|
- db_conn: PooledMySQLConnection = db_dependency
|
|
|
|
|
-):
|
|
|
|
|
- """
|
|
|
|
|
- 1. 接收 LabelMe JSON。
|
|
|
|
|
- 2. 转换为系统内部 JSON 格式。
|
|
|
|
|
- 3. 调用项目1的 API 进行重计算 (re-inference & score)。
|
|
|
|
|
- 4. 更新数据库 (modified_json)。
|
|
|
|
|
- """
|
|
|
|
|
- cursor = None
|
|
|
|
|
- try:
|
|
|
|
|
- cursor = db_conn.cursor(dictionary=True)
|
|
|
|
|
-
|
|
|
|
|
- # 1. 获取图片基础信息 (用于获取 image_type 和原始 JSON 结构参考)
|
|
|
|
|
- query = "SELECT card_id, image_type, detection_json FROM card_images WHERE id = %s"
|
|
|
|
|
- cursor.execute(query, (image_id,))
|
|
|
|
|
- row = cursor.fetchone()
|
|
|
|
|
-
|
|
|
|
|
- if not row:
|
|
|
|
|
- raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到")
|
|
|
|
|
-
|
|
|
|
|
- card_id = row['card_id']
|
|
|
|
|
- image_type = row['image_type']
|
|
|
|
|
-
|
|
|
|
|
- # 解析原始 JSON 用于辅助转换 (例如获取宽高作为兜底)
|
|
|
|
|
- detection_json = row['detection_json']
|
|
|
|
|
- if isinstance(detection_json, str):
|
|
|
|
|
- detection_json = json.loads(detection_json)
|
|
|
|
|
-
|
|
|
|
|
- # 2. 转换格式: LabelMe -> Internal
|
|
|
|
|
- # 注意:这里生成的 JSON 只有 points 和 labels,面积和分数需要服务端计算
|
|
|
|
|
- internal_payload = convert_labelme_to_internal(detection_json, labelme_data)
|
|
|
|
|
-
|
|
|
|
|
- # 确定 score_type
|
|
|
|
|
- score_type = IMAGE_TYPE_TO_SCORE_TYPE.get(image_type)
|
|
|
|
|
- if not score_type:
|
|
|
|
|
- raise HTTPException(status_code=400, detail=f"不支持的图片类型: {image_type}")
|
|
|
|
|
-
|
|
|
|
|
- logger.info(f"正在调用计算服务: {settings.SCORE_RECALCULATE_ENDPOINT}, type={score_type}")
|
|
|
|
|
-
|
|
|
|
|
- # 3. 调用项目1的分数重计算接口
|
|
|
|
|
- try:
|
|
|
|
|
- response = await run_in_threadpool(
|
|
|
|
|
- lambda: requests.post(
|
|
|
|
|
- settings.SCORE_RECALCULATE_ENDPOINT,
|
|
|
|
|
- params={"score_type": score_type},
|
|
|
|
|
- json=internal_payload,
|
|
|
|
|
- timeout=30 # 稍微增加超时时间,因为可能涉及推理
|
|
|
|
|
- )
|
|
|
|
|
- )
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"连接分数计算服务失败: {e}")
|
|
|
|
|
- raise HTTPException(status_code=502, detail="无法连接到分数计算服务")
|
|
|
|
|
-
|
|
|
|
|
- if response.status_code != 200:
|
|
|
|
|
- logger.error(f"分数计算服务返回错误: {response.text}")
|
|
|
|
|
- raise HTTPException(status_code=response.status_code, detail=f"分数计算失败: {response.text}")
|
|
|
|
|
-
|
|
|
|
|
- recalculated_json = response.json()
|
|
|
|
|
- logger.info("分数重计算完成")
|
|
|
|
|
-
|
|
|
|
|
- # 4. 更新数据库
|
|
|
|
|
- recalculated_json_str = json.dumps(recalculated_json, ensure_ascii=False)
|
|
|
|
|
- update_query = (
|
|
|
|
|
- "UPDATE card_images "
|
|
|
|
|
- "SET modified_json = %s, is_edited = TRUE "
|
|
|
|
|
- "WHERE id = %s"
|
|
|
|
|
- )
|
|
|
|
|
- cursor.execute(update_query, (recalculated_json_str, image_id))
|
|
|
|
|
- db_conn.commit()
|
|
|
|
|
-
|
|
|
|
|
- # 5. 更新 Card 维度的总分状态
|
|
|
|
|
- try:
|
|
|
|
|
- crud_card.update_card_scores_and_status(db_conn, card_id)
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- logger.error(f"更新卡牌总分失败: {e}")
|
|
|
|
|
- # 不阻断主流程
|
|
|
|
|
-
|
|
|
|
|
- return {
|
|
|
|
|
- "success": True,
|
|
|
|
|
- "message": "LabelMe 数据导入并重计算成功",
|
|
|
|
|
- "image_id": image_id
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- except HTTPException as he:
|
|
|
|
|
- raise he
|
|
|
|
|
- except Exception as e:
|
|
|
|
|
- db_conn.rollback()
|
|
|
|
|
- logger.error(f"导入 LabelMe 数据失败: {e}")
|
|
|
|
|
- raise HTTPException(status_code=500, detail=f"系统内部错误: {str(e)}")
|
|
|
|
|
- finally:
|
|
|
|
|
- if cursor:
|
|
|
|
|
- cursor.close()
|
|
|