image_data.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # app/api/image_data.py
  2. import os
  3. import uuid
  4. import json # <-- 确保导入了json库
  5. from typing import Optional, Dict, Any, List
  6. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Form, Query
  7. from fastapi.responses import JSONResponse
  8. from fastapi.concurrency import run_in_threadpool
  9. from pydantic import BaseModel, field_validator
  10. import mysql.connector
  11. from app.core.config import settings
  12. from app.core.logger import get_logger
  13. from app.core.database_loader import get_db_connection
  14. from app.services.score_service import ScoreService
  15. from app.api.score_inference import ScoreType
  # Module-level router; the endpoint decorators in this file register on it.
  router = APIRouter()
  # Logger named after this module.
  logger = get_logger(__name__)
  # --- Pydantic models ---


  class ImageDataUpdate(BaseModel):
      """Request body for the PUT endpoint: the replacement result JSON."""

      # New value written to the record's ``img_result_json`` column.
      img_result_json: Dict[str, Any]
  class ImageDataResponse(BaseModel):
      """Response shape for a single image record.

      Rows read back from the database may carry ``img_result_json`` as encoded
      JSON rather than a dict; the ``before`` validator below decodes it first.
      """

      img_id: int
      img_name: Optional[str] = None
      img_path: str
      img_result_json: Dict[str, Any]

      @field_validator('img_result_json', mode='before')
      @classmethod
      def parse_json_string(cls, value):
          """Decode ``value`` if it is JSON text or bytes; otherwise return it unchanged.

          ``bytes``/``bytearray`` are accepted in addition to ``str`` because some
          MySQL driver / column-type combinations may hand back raw bytes
          (NOTE(review): confirm for this schema); ``json.loads`` accepts all three.

          Raises:
              ValueError: the value is not valid JSON (pydantic reports it as a
                  validation error).
          """
          if isinstance(value, (str, bytes, bytearray)):
              try:
                  return json.loads(value)
              except (json.JSONDecodeError, UnicodeDecodeError) as err:
                  # Chain the decode error so the root cause stays visible.
                  raise ValueError("Invalid JSON string provided") from err
          return value
  # --- API endpoints ---


  def _write_image_file(path, data: bytes) -> None:
      """Write ``data`` to ``path`` in binary mode (blocking; run it in a thread pool)."""
      with open(path, "wb") as fh:
          fh.write(data)


  def _remove_quietly(path) -> None:
      """Best-effort removal of a file during error cleanup; never raises.

      A missing file is ignored; other OS errors are only logged so that a
      cleanup failure cannot mask the exception that triggered the cleanup.
      """
      try:
          os.remove(path)
      except FileNotFoundError:
          pass
      except OSError as err:
          logger.error(f"Failed to remove image file during cleanup: {path}, error: {err}")


  def _insert_image_record(db_conn, img_name, img_path: str, json_string: str):
      """Insert one row, commit, and return ``cursor.lastrowid`` (blocking)."""
      cursor = db_conn.cursor(dictionary=True)
      try:
          insert_query = (
              f"INSERT INTO {settings.DB_TABLE_NAME} (img_name, img_path, img_result_json) "
              "VALUES (%s, %s, %s)"
          )
          cursor.execute(insert_query, (img_name, img_path, json_string))
          new_id = cursor.lastrowid
      finally:
          cursor.close()
      # Commit last, so nothing that can fail runs after the row is durable.
      db_conn.commit()
      return new_id


  @router.post("/", response_model=ImageDataResponse, summary="创建新的图片记录")
  async def create_image_data(
      score_type: ScoreType = Form(...),
      is_reflect_card: bool = Form(False),
      img_name: Optional[str] = Form(None),
      file: UploadFile = File(...),
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      """Save an uploaded image, score it, and store the result as a new record.

      Steps:
        1. Write the upload to ``settings.DATA_DIR`` under a random UUID name,
           keeping only the client's file extension.
        2. Run ``ScoreService.score_inference`` on the raw bytes.
        3. Insert ``(img_name, img_path, img_result_json)`` and commit.

      Blocking work (file write, inference, database calls) runs via
      ``run_in_threadpool`` so the event loop is not stalled. If a step fails,
      the saved image is removed so no orphan file is left behind.

      Returns:
          A dict matching ``ImageDataResponse`` for the new record.

      Raises:
          HTTPException: 500 if saving, inference, or the database insert fails.
      """
      # ``UploadFile.filename`` may be None; treat that as "no extension".
      file_extension = os.path.splitext(file.filename or "")[1]
      unique_filename = f"{uuid.uuid4()}{file_extension}"
      img_path = settings.DATA_DIR / unique_filename

      # 1. Save the image to local disk.
      try:
          image_bytes = await file.read()
          await run_in_threadpool(_write_image_file, img_path, image_bytes)
      except Exception as e:
          # Drop any partially written file (the name is unique to this request).
          _remove_quietly(img_path)
          logger.error(f"保存图片失败: {e}")
          raise HTTPException(status_code=500, detail="无法保存图片文件") from e

      # 2. Run score inference (a synchronous call) off the event loop.
      try:
          service = ScoreService()
          json_result = await run_in_threadpool(
              service.score_inference,
              score_type=score_type.value,
              is_reflect_card=is_reflect_card,
              image_bytes=image_bytes,
          )
      except Exception as e:
          _remove_quietly(img_path)
          logger.error(f"分数推理失败: {e}")
          raise HTTPException(status_code=500, detail=f"分数推理时发生错误: {e}") from e

      # 3. Persist the record; on any ordinary failure remove the image so it
      #    is never left on disk without a row pointing at it.
      try:
          json_string = json.dumps(json_result, ensure_ascii=False)
          new_id = await run_in_threadpool(
              _insert_image_record, db_conn, img_name, str(img_path), json_string
          )
      except mysql.connector.Error as err:
          _remove_quietly(img_path)
          await run_in_threadpool(db_conn.rollback)
          logger.error(f"数据库插入失败: {err}")
          raise HTTPException(status_code=500, detail="数据库操作失败") from err
      except Exception:
          # e.g. TypeError from json.dumps on a non-serializable inference result.
          _remove_quietly(img_path)
          raise

      logger.info(f"成功创建记录, ID: {new_id}")
      return {
          "img_id": new_id,
          "img_name": img_name,
          "img_path": str(img_path),
          "img_result_json": json_result,
      }
  @router.get("/{img_id}", response_model=ImageDataResponse, summary="根据ID查询记录")
  def get_image_data_by_id(
      img_id: int,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      """Fetch a single record by ``img_id``.

      The row is returned as a plain dict; decoding a string-encoded
      ``img_result_json`` is left to ``ImageDataResponse`` validation.

      Raises:
          HTTPException(404): no row matches ``img_id``.
          HTTPException(500): the MySQL query failed.
      """
      cur = None
      try:
          cur = db_conn.cursor(dictionary=True)
          cur.execute(
              f"SELECT * FROM {settings.DB_TABLE_NAME} WHERE img_id = %s",
              (img_id,),
          )
          row = cur.fetchone()
          if row is not None:
              return row
          raise HTTPException(status_code=404, detail="记录未找到")
      except mysql.connector.Error as err:
          logger.error(f"数据库查询失败: {err}")
          raise HTTPException(status_code=500, detail="数据库查询失败")
      finally:
          if cur:
              cur.close()
  @router.get("/", response_model=List[ImageDataResponse], summary="根据名称查询记录")
  def get_image_data_by_name(
      img_name: str,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      """Return every record whose ``img_name`` equals the query value.

      An empty list is returned when nothing matches.

      Raises:
          HTTPException(500): the MySQL query failed.
      """
      cur = None
      try:
          cur = db_conn.cursor(dictionary=True)
          cur.execute(
              f"SELECT * FROM {settings.DB_TABLE_NAME} WHERE img_name = %s",
              (img_name,),
          )
          # Rows go out as raw dicts; response_model validation decodes the JSON.
          return cur.fetchall()
      except mysql.connector.Error as err:
          logger.error(f"数据库查询失败: {err}")
          raise HTTPException(status_code=500, detail="数据库查询失败")
      finally:
          if cur:
              cur.close()
  @router.put("/{img_id}", response_model=ImageDataResponse, summary="更新记录的JSON数据")
  def update_image_data_json(
      img_id: int,
      data: ImageDataUpdate,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      """Replace a record's ``img_result_json`` and return the updated record.

      Raises:
          HTTPException(404): no record with ``img_id`` exists (checked before
              the update, and again by the re-read afterwards).
          HTTPException(500): a MySQL error occurred; the transaction is rolled back.
      """
      cursor = None
      try:
          cursor = db_conn.cursor(dictionary=True)
          check_query = f"SELECT img_path FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
          cursor.execute(check_query, (img_id,))
          if cursor.fetchone() is None:
              raise HTTPException(status_code=404, detail="记录未找到")

          update_query = (
              f"UPDATE {settings.DB_TABLE_NAME} SET img_result_json = %s WHERE img_id = %s"
          )
          json_string = json.dumps(data.img_result_json, ensure_ascii=False)
          cursor.execute(update_query, (json_string, img_id))
          db_conn.commit()
          # Deliberately no `rowcount == 0` check: MySQL reports 0 affected rows
          # when the new value equals the stored one, which would turn an
          # idempotent PUT into a spurious 404. Existence was verified above, and
          # a concurrent delete still surfaces as 404 from the re-read below.
          logger.info(f"成功更新记录, ID: {img_id}")
          return get_image_data_by_id(img_id, db_conn)
      except mysql.connector.Error as err:
          db_conn.rollback()
          logger.error(f"数据库更新失败: {err}")
          raise HTTPException(status_code=500, detail="数据库更新失败") from err
      finally:
          if cursor:
              cursor.close()
  @router.delete("/{img_id}", summary="删除记录和对应的图片文件")
  def delete_image_data(
      img_id: int,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      """Delete a record, then best-effort delete the image file it points to.

      The row is deleted and committed first; removing the file afterwards
      only logs on failure, so a file-system problem never fails the request.

      Returns:
          JSONResponse (status 200) with a confirmation message.

      Raises:
          HTTPException(404): no record with ``img_id``, or the DELETE
              affected no rows.
          HTTPException(500): a MySQL error occurred; the transaction is
              rolled back.
      """
      cur = None
      try:
          cur = db_conn.cursor(dictionary=True)

          lookup_sql = f"SELECT img_path FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
          cur.execute(lookup_sql, (img_id,))
          found = cur.fetchone()
          if not found:
              raise HTTPException(status_code=404, detail="记录未找到")
          img_path = found['img_path']

          delete_sql = f"DELETE FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
          cur.execute(delete_sql, (img_id,))
          db_conn.commit()
          if cur.rowcount == 0:
              raise HTTPException(status_code=404, detail="记录删除失败,可能已被删除")

          # Best-effort cleanup: the row is already gone, so errors are only logged.
          try:
              if os.path.exists(img_path):
                  os.remove(img_path)
                  logger.info(f"成功删除图片文件: {img_path}")
          except Exception as e:
              logger.error(f"删除图片文件失败: {img_path}, 错误: {e}")

          logger.info(f"成功删除记录, ID: {img_id}")
          return JSONResponse(content={"message": f"记录 {img_id} 已成功删除"}, status_code=200)
      except mysql.connector.Error as err:
          db_conn.rollback()
          logger.error(f"数据库删除失败: {err}")
          raise HTTPException(status_code=500, detail="数据库删除失败")
      finally:
          if cur:
              cur.close()