labelme.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import json
  2. import requests
  3. from typing import Optional, Dict, Any, List
  4. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Body
  5. from fastapi.concurrency import run_in_threadpool
  6. from mysql.connector.pooling import PooledMySQLConnection
  7. from app.core.config import settings
  8. from app.core.logger import get_logger
  9. from app.core.database_loader import get_db_connection
  10. from app.crud import crud_card
  11. from app.utils.scheme import IMAGE_TYPE_TO_SCORE_TYPE
  12. from app.utils.labelme_process import convert_internal_to_labelme, convert_labelme_to_internal
  13. logger = get_logger(__name__)
  14. router = APIRouter()
  15. db_dependency = Depends(get_db_connection)
  16. @router.get("/export/{image_id}", summary="获取指定图片的 LabelMe 格式 JSON")
  17. def export_labelme_json(image_id: int, db_conn: PooledMySQLConnection = db_dependency):
  18. """
  19. 获取图片的 JSON 数据并转换为 LabelMe 格式。
  20. 优先读取 modified_json,如果没有则读取 detection_json。
  21. """
  22. cursor = None
  23. try:
  24. cursor = db_conn.cursor(dictionary=True)
  25. # 查询图片路径和JSON
  26. query = "SELECT image_path, detection_json, modified_json FROM card_images WHERE id = %s"
  27. cursor.execute(query, (image_id,))
  28. row = cursor.fetchone()
  29. if not row:
  30. raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到")
  31. # 优先使用已修改的数据
  32. source_json_str = row['modified_json'] if row['modified_json'] else row['detection_json']
  33. if isinstance(source_json_str, str):
  34. source_json = json.loads(source_json_str)
  35. else:
  36. source_json = source_json_str # 已经是 dict (如果是从 Pydantic 模型来的话,但在原生 cursor 里通常是 str 或 dict)
  37. if source_json is None: source_json = {}
  38. image_path = row['image_path']
  39. # 转换
  40. labelme_data = convert_internal_to_labelme(image_path, source_json)
  41. return labelme_data
  42. except Exception as e:
  43. logger.error(f"导出 LabelMe JSON 失败 (id={image_id}): {e}")
  44. raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
  45. finally:
  46. if cursor:
  47. cursor.close()
  48. @router.put("/import/{image_id}", summary="接收 LabelMe JSON,重计算分数并保存")
  49. async def import_labelme_json(
  50. image_id: int,
  51. labelme_data: Dict[str, Any] = Body(..., description="LabelMe 格式的 JSON 数据"),
  52. db_conn: PooledMySQLConnection = db_dependency
  53. ):
  54. """
  55. 1. 接收 LabelMe JSON。
  56. 2. 转换为系统内部 JSON 格式。
  57. 3. 调用项目1的 API 进行重计算 (re-inference & score)。
  58. 4. 更新数据库 (modified_json)。
  59. """
  60. cursor = None
  61. try:
  62. cursor = db_conn.cursor(dictionary=True)
  63. # 1. 获取图片基础信息 (用于获取 image_type 和原始 JSON 结构参考)
  64. query = "SELECT card_id, image_type, detection_json FROM card_images WHERE id = %s"
  65. cursor.execute(query, (image_id,))
  66. row = cursor.fetchone()
  67. if not row:
  68. raise HTTPException(status_code=404, detail=f"图片 ID {image_id} 未找到")
  69. card_id = row['card_id']
  70. image_type = row['image_type']
  71. # 解析原始 JSON 用于辅助转换 (例如获取宽高作为兜底)
  72. detection_json = row['detection_json']
  73. if isinstance(detection_json, str):
  74. detection_json = json.loads(detection_json)
  75. # 2. 转换格式: LabelMe -> Internal
  76. # 注意:这里生成的 JSON 只有 points 和 labels,面积和分数需要服务端计算
  77. internal_payload = convert_labelme_to_internal(detection_json, labelme_data)
  78. # 确定 score_type
  79. score_type = IMAGE_TYPE_TO_SCORE_TYPE.get(image_type)
  80. if not score_type:
  81. raise HTTPException(status_code=400, detail=f"不支持的图片类型: {image_type}")
  82. logger.info(f"正在调用计算服务: {settings.SCORE_RECALCULATE_ENDPOINT}, type={score_type}")
  83. # 3. 调用项目1的分数重计算接口
  84. try:
  85. response = await run_in_threadpool(
  86. lambda: requests.post(
  87. settings.SCORE_RECALCULATE_ENDPOINT,
  88. params={"score_type": score_type},
  89. json=internal_payload,
  90. timeout=30 # 稍微增加超时时间,因为可能涉及推理
  91. )
  92. )
  93. except Exception as e:
  94. logger.error(f"连接分数计算服务失败: {e}")
  95. raise HTTPException(status_code=502, detail="无法连接到分数计算服务")
  96. if response.status_code != 200:
  97. logger.error(f"分数计算服务返回错误: {response.text}")
  98. raise HTTPException(status_code=response.status_code, detail=f"分数计算失败: {response.text}")
  99. recalculated_json = response.json()
  100. logger.info("分数重计算完成")
  101. # 4. 更新数据库
  102. recalculated_json_str = json.dumps(recalculated_json, ensure_ascii=False)
  103. update_query = (
  104. "UPDATE card_images "
  105. "SET modified_json = %s, is_edited = TRUE "
  106. "WHERE id = %s"
  107. )
  108. cursor.execute(update_query, (recalculated_json_str, image_id))
  109. db_conn.commit()
  110. # 5. 更新 Card 维度的总分状态
  111. try:
  112. crud_card.update_card_scores_and_status(db_conn, card_id)
  113. except Exception as e:
  114. logger.error(f"更新卡牌总分失败: {e}")
  115. # 不阻断主流程
  116. return {
  117. "success": True,
  118. "message": "LabelMe 数据导入并重计算成功",
  119. "image_id": image_id
  120. }
  121. except HTTPException as he:
  122. raise he
  123. except Exception as e:
  124. db_conn.rollback()
  125. logger.error(f"导入 LabelMe 数据失败: {e}")
  126. raise HTTPException(status_code=500, detail=f"系统内部错误: {str(e)}")
  127. finally:
  128. if cursor:
  129. cursor.close()