image_data.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import os
  2. import uuid
  3. import json
  4. from typing import Optional, Dict, Any, List
  5. from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Form, Query
  6. from fastapi.responses import JSONResponse
  7. from fastapi.concurrency import run_in_threadpool
  8. from pydantic import BaseModel, field_validator
  9. import mysql.connector
  10. from app.core.config import settings
  11. from app.core.logger import get_logger
  12. from app.core.database_loader import get_db_connection
  13. from app.services.score_service import ScoreService
  14. from app.api.score_inference import ScoreType
  # Module-wide logger, obtained from the project's get_logger helper and
  # named after this module.
  logger = get_logger(__name__)

  # Router on which the endpoints in this module register themselves through
  # their @router.<method> decorators.
  router = APIRouter()
  17. # --- Pydantic model definitions ---
  # Request body accepted by the PUT endpoint of this module.
  # (Plain comments are used instead of a docstring because pydantic would
  # copy a docstring into the generated JSON schema description.)
  class ImageDataUpdate(BaseModel):
      # Replacement JSON object for the record's img_result_json column.
      img_result_json: Dict[str, Any]
  # Response schema used by the create, read and update endpoints.
  # (Plain comments are used instead of a docstring because pydantic would
  # copy a docstring into the generated JSON schema description.)
  class ImageDataResponse(BaseModel):
      img_id: int
      img_name: Optional[str] = None
      img_path: str
      img_result_json: Dict[str, Any]

      # Runs before field validation so that a JSON string (the form in which
      # the endpoints of this module store the value in the database) is
      # decoded automatically.
      @field_validator('img_result_json', mode='before')
      @classmethod
      def parse_json_string(cls, value):
          # Anything that is not a string is handed on to normal validation.
          if not isinstance(value, str):
              return value
          try:
              decoded = json.loads(value)
          except json.JSONDecodeError:
              raise ValueError("Invalid JSON string provided")
          return decoded
  35. # --- API endpoint implementations ---
  def _discard_image_file(path: "os.PathLike[str]") -> None:
      """Remove *path* on a best-effort basis.

      Used only on error paths: a failure to clean up is logged instead of
      raised so that it can never hide the error that triggered the cleanup.
      """
      try:
          os.remove(path)
      except FileNotFoundError:
          # Nothing was written (or it is already gone) - nothing to clean up.
          pass
      except OSError as cleanup_err:
          logger.error(f"清理图片文件失败: {path}, 错误: {cleanup_err}")


  @router.post("/", response_model=ImageDataResponse, summary="创建新的图片记录")
  async def create_image_data(
      score_type: ScoreType = Form(...),
      is_reflect_card: bool = Form(False),
      img_name: Optional[str] = Form(None),
      file: UploadFile = File(...),
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      # Create a new image record in three steps: store the uploaded file under
      # settings.DATA_DIR, run ScoreService.score_inference on its bytes, and
      # insert (img_name, img_path, img_result_json) into the configured table.
      #
      # Returns a dict with img_id, img_name, img_path and img_result_json.
      # Raises HTTPException(500) when saving, inference or the insert fails;
      # in every failure case the stored image file is removed again.
      #
      # (Plain comments are used instead of a docstring because FastAPI would
      # publish a docstring as the operation description.)

      # 1. Save the image locally under a collision-free name.
      # The client may omit the filename, in which case no extension is kept.
      file_extension = os.path.splitext(file.filename or "")[1]
      unique_filename = f"{uuid.uuid4()}{file_extension}"
      img_path = settings.DATA_DIR / unique_filename
      try:
          image_bytes = await file.read()
          # NOTE(review): this is a blocking write inside an async handler.
          with open(img_path, "wb") as f:
              f.write(image_bytes)
      except Exception as e:
          # A failed write can leave a truncated file behind.
          _discard_image_file(img_path)
          logger.error(f"保存图片失败: {e}")
          raise HTTPException(status_code=500, detail="无法保存图片文件") from e

      # 2. Call ScoreService to produce the JSON data (in a worker thread).
      try:
          service = ScoreService()
          json_result = await run_in_threadpool(
              service.score_inference,
              score_type=score_type.value,
              is_reflect_card=is_reflect_card,
              image_bytes=image_bytes
          )
      except Exception as e:
          # Inference failed: drop the image that was already saved.
          _discard_image_file(img_path)
          logger.error(f"分数推理失败: {e}")
          # NOTE(review): the detail exposes the internal exception text to
          # the client; kept unchanged for backward compatibility.
          raise HTTPException(status_code=500, detail=f"分数推理时发生错误: {e}") from e

      # 3. Persist the record.
      cursor = None
      try:
          cursor = db_conn.cursor(dictionary=True)
          insert_query = (
              f"INSERT INTO {settings.DB_TABLE_NAME} (img_name, img_path, img_result_json) "
              "VALUES (%s, %s, %s)"
          )
          json_string = json.dumps(json_result, ensure_ascii=False)
          cursor.execute(insert_query, (img_name, str(img_path), json_string))
          new_id = cursor.lastrowid
          db_conn.commit()
      except mysql.connector.Error as err:
          # Remove the file first so it is not leaked if rollback() itself fails.
          _discard_image_file(img_path)
          db_conn.rollback()
          logger.error(f"数据库插入失败: {err}")
          raise HTTPException(status_code=500, detail="数据库操作失败") from err
      except Exception:
          # E.g. json_result is not JSON-serialisable: the error propagates
          # unchanged, but the image must not be left behind without a record.
          _discard_image_file(img_path)
          raise
      finally:
          if cursor:
              cursor.close()

      logger.info(f"成功创建记录, ID: {new_id}")
      return {
          "img_id": new_id,
          "img_name": img_name,
          "img_path": str(img_path),
          "img_result_json": json_result
      }
  @router.get("/{img_id}", response_model=ImageDataResponse, summary="根据ID查询记录")
  def get_image_data_by_id(
      img_id: int,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      # Fetch the single record whose img_id matches and return the row dict.
      # Raises HTTPException(404) when no row matches and HTTPException(500)
      # on a MySQL error.
      #
      # (Plain comments are used instead of a docstring because FastAPI would
      # publish a docstring as the operation description.)
      cur = None
      try:
          cur = db_conn.cursor(dictionary=True)
          cur.execute(
              f"SELECT * FROM {settings.DB_TABLE_NAME} WHERE img_id = %s",
              (img_id,),
          )
          row = cur.fetchone()
          if row is not None:
              # The response model's validator decodes img_result_json, so no
              # manual conversion is needed here.
              return row
          raise HTTPException(status_code=404, detail="记录未找到")
      except mysql.connector.Error as err:
          logger.error(f"数据库查询失败: {err}")
          raise HTTPException(status_code=500, detail="数据库查询失败")
      finally:
          if cur:
              cur.close()
  @router.get("/", response_model=List[ImageDataResponse], summary="根据名称查询记录")
  def get_image_data_by_name(
      img_name: str,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      # Return every record whose img_name equals the given value; the result
      # of fetchall() is returned as-is, so no match is not treated as an error.
      # Raises HTTPException(500) on a MySQL error.
      #
      # (Plain comments are used instead of a docstring because FastAPI would
      # publish a docstring as the operation description.)
      cur = None
      try:
          cur = db_conn.cursor(dictionary=True)
          cur.execute(
              f"SELECT * FROM {settings.DB_TABLE_NAME} WHERE img_name = %s",
              (img_name,),
          )
          # The response model's validator decodes img_result_json for each
          # row, so no manual conversion is needed here.
          rows = cur.fetchall()
          return rows
      except mysql.connector.Error as err:
          logger.error(f"数据库查询失败: {err}")
          raise HTTPException(status_code=500, detail="数据库查询失败")
      finally:
          if cur:
              cur.close()
  @router.put("/{img_id}", response_model=ImageDataResponse, summary="更新记录的JSON数据")
  def update_image_data_json(
      img_id: int,
      data: ImageDataUpdate,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      # Replace the img_result_json of an existing record and return the
      # record as re-read from the database.
      # Raises HTTPException(404) when the record does not exist and
      # HTTPException(500) on a MySQL error.
      #
      # (Plain comments are used instead of a docstring because FastAPI would
      # publish a docstring as the operation description.)
      cursor = None
      try:
          cursor = db_conn.cursor(dictionary=True)

          # Existence check: this is the only reliable "not found" signal.
          check_query = f"SELECT img_path FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
          cursor.execute(check_query, (img_id,))
          if not cursor.fetchone():
              raise HTTPException(status_code=404, detail="记录未找到")

          update_query = (
              f"UPDATE {settings.DB_TABLE_NAME} SET img_result_json = %s WHERE img_id = %s"
          )
          json_string = json.dumps(data.img_result_json, ensure_ascii=False)
          cursor.execute(update_query, (json_string, img_id))
          db_conn.commit()

          # cursor.rowcount is deliberately NOT used to detect a missing row:
          # by default MySQL counts only rows whose values actually changed,
          # so re-sending identical JSON yields 0 for a record that exists.
          # If the row was deleted concurrently, the re-read below raises the
          # 404 instead.
          updated_record = get_image_data_by_id(img_id, db_conn)
          logger.info(f"成功更新记录, ID: {img_id}")
          return updated_record
      except mysql.connector.Error as err:
          db_conn.rollback()
          logger.error(f"数据库更新失败: {err}")
          raise HTTPException(status_code=500, detail="数据库更新失败") from err
      finally:
          if cursor:
              cursor.close()
  @router.delete("/{img_id}", summary="删除记录和对应的图片文件")
  def delete_image_data(
      img_id: int,
      db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
  ):
      # Delete the record with the given img_id and then, on a best-effort
      # basis, the image file its img_path column points to.
      # Raises HTTPException(404) when the record is missing (or the DELETE
      # affected no row) and HTTPException(500) on a MySQL error. A failure
      # to remove the file is only logged.
      #
      # (Plain comments are used instead of a docstring because FastAPI would
      # publish a docstring as the operation description.)
      cur = None
      try:
          cur = db_conn.cursor(dictionary=True)

          # Look the record up first: its img_path is needed after the DELETE.
          cur.execute(
              f"SELECT img_path FROM {settings.DB_TABLE_NAME} WHERE img_id = %s",
              (img_id,),
          )
          row = cur.fetchone()
          if not row:
              raise HTTPException(status_code=404, detail="记录未找到")
          img_path = row['img_path']

          cur.execute(
              f"DELETE FROM {settings.DB_TABLE_NAME} WHERE img_id = %s",
              (img_id,),
          )
          db_conn.commit()
          if cur.rowcount == 0:
              raise HTTPException(status_code=404, detail="记录删除失败,可能已被删除")

          # The database row is gone at this point; file removal must not
          # turn the request into a failure.
          try:
              if os.path.exists(img_path):
                  os.remove(img_path)
                  logger.info(f"成功删除图片文件: {img_path}")
          except Exception as e:
              logger.error(f"删除图片文件失败: {img_path}, 错误: {e}")

          logger.info(f"成功删除记录, ID: {img_id}")
          payload = {"message": f"记录 {img_id} 已成功删除"}
          return JSONResponse(content=payload, status_code=200)
      except mysql.connector.Error as err:
          db_conn.rollback()
          logger.error(f"数据库删除失败: {err}")
          raise HTTPException(status_code=500, detail="数据库删除失败")
      finally:
          if cur:
              cur.close()