|
@@ -0,0 +1,244 @@
|
|
|
|
|
+# app/api/image_data.py
|
|
|
|
|
+
|
|
|
|
|
+import os
|
|
|
|
|
+import uuid
|
|
|
|
|
+import json # <-- 确保导入了json库
|
|
|
|
|
+from typing import Optional, Dict, Any, List
|
|
|
|
|
+from fastapi import APIRouter, File, UploadFile, Depends, HTTPException, Form, Query
|
|
|
|
|
+from fastapi.responses import JSONResponse
|
|
|
|
|
+from fastapi.concurrency import run_in_threadpool
|
|
|
|
|
+from pydantic import BaseModel, field_validator
|
|
|
|
|
+import mysql.connector
|
|
|
|
|
+
|
|
|
|
|
+from app.core.config import settings
|
|
|
|
|
+from app.core.logger import get_logger
|
|
|
|
|
+from app.core.database_loader import get_db_connection
|
|
|
|
|
+from app.services.score_service import ScoreService
|
|
|
|
|
+from app.api.score_inference import ScoreType
|
|
|
|
|
+
|
|
|
|
|
+logger = get_logger(__name__)
|
|
|
|
|
+router = APIRouter()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --- Pydantic 模型定义 ---
|
|
|
|
|
+class ImageDataUpdate(BaseModel):
|
|
|
|
|
+ img_result_json: Dict[str, Any]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ImageDataResponse(BaseModel):
|
|
|
|
|
+ img_id: int
|
|
|
|
|
+ img_name: Optional[str] = None
|
|
|
|
|
+ img_path: str
|
|
|
|
|
+ img_result_json: Dict[str, Any]
|
|
|
|
|
+
|
|
|
|
|
+ # 添加一个验证器来自动处理从数据库来的字符串
|
|
|
|
|
+ @field_validator('img_result_json', mode='before')
|
|
|
|
|
+ @classmethod
|
|
|
|
|
+ def parse_json_string(cls, value):
|
|
|
|
|
+ if isinstance(value, str):
|
|
|
|
|
+ try:
|
|
|
|
|
+ return json.loads(value)
|
|
|
|
|
+ except json.JSONDecodeError:
|
|
|
|
|
+ raise ValueError("Invalid JSON string provided")
|
|
|
|
|
+ return value
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# --- API 端点实现 ---
|
|
|
|
|
+
|
|
|
|
|
+@router.post("/", response_model=ImageDataResponse, summary="创建新的图片记录")
|
|
|
|
|
+async def create_image_data(
|
|
|
|
|
+ score_type: ScoreType = Form(...),
|
|
|
|
|
+ is_reflect_card: bool = Form(False),
|
|
|
|
|
+ img_name: Optional[str] = Form(None),
|
|
|
|
|
+ file: UploadFile = File(...),
|
|
|
|
|
+ db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
|
|
|
|
|
+):
|
|
|
|
|
+ # 1. 保存图片到本地
|
|
|
|
|
+ file_extension = os.path.splitext(file.filename)[1]
|
|
|
|
|
+ unique_filename = f"{uuid.uuid4()}{file_extension}"
|
|
|
|
|
+ img_path = settings.DATA_DIR / unique_filename
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ image_bytes = await file.read()
|
|
|
|
|
+ with open(img_path, "wb") as f:
|
|
|
|
|
+ f.write(image_bytes)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"保存图片失败: {e}")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="无法保存图片文件")
|
|
|
|
|
+
|
|
|
|
|
+ # 2. 调用 ScoreService 生成 JSON 数据
|
|
|
|
|
+ try:
|
|
|
|
|
+ service = ScoreService()
|
|
|
|
|
+ json_result = await run_in_threadpool(
|
|
|
|
|
+ service.score_inference,
|
|
|
|
|
+ score_type=score_type.value,
|
|
|
|
|
+ is_reflect_card=is_reflect_card,
|
|
|
|
|
+ image_bytes=image_bytes
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ # 如果推理失败,删除已保存的图片
|
|
|
|
|
+ os.remove(img_path)
|
|
|
|
|
+ logger.error(f"分数推理失败: {e}")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"分数推理时发生错误: {e}")
|
|
|
|
|
+
|
|
|
|
|
+ # 3. 存入数据库
|
|
|
|
|
+ cursor = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor = db_conn.cursor(dictionary=True)
|
|
|
|
|
+ insert_query = (
|
|
|
|
|
+ f"INSERT INTO {settings.DB_TABLE_NAME} (img_name, img_path, img_result_json) "
|
|
|
|
|
+ "VALUES (%s, %s, %s)"
|
|
|
|
|
+ )
|
|
|
|
|
+ json_string = json.dumps(json_result, ensure_ascii=False)
|
|
|
|
|
+ cursor.execute(insert_query, (img_name, str(img_path), json_string))
|
|
|
|
|
+ new_id = cursor.lastrowid
|
|
|
|
|
+ db_conn.commit()
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"成功创建记录, ID: {new_id}")
|
|
|
|
|
+ return {
|
|
|
|
|
+ "img_id": new_id,
|
|
|
|
|
+ "img_name": img_name,
|
|
|
|
|
+ "img_path": str(img_path),
|
|
|
|
|
+ "img_result_json": json_result
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ except mysql.connector.Error as err:
|
|
|
|
|
+ db_conn.rollback()
|
|
|
|
|
+ os.remove(img_path)
|
|
|
|
|
+ logger.error(f"数据库插入失败: {err}")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="数据库操作失败")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if cursor:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.get("/{img_id}", response_model=ImageDataResponse, summary="根据ID查询记录")
|
|
|
|
|
+def get_image_data_by_id(
|
|
|
|
|
+ img_id: int,
|
|
|
|
|
+ db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
|
|
|
|
|
+):
|
|
|
|
|
+ cursor = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor = db_conn.cursor(dictionary=True)
|
|
|
|
|
+ query = f"SELECT * FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
|
|
|
|
|
+ cursor.execute(query, (img_id,))
|
|
|
|
|
+ record = cursor.fetchone()
|
|
|
|
|
+
|
|
|
|
|
+ if record is None:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="记录未找到")
|
|
|
|
|
+
|
|
|
|
|
+ # Pydantic 验证器会自动处理转换,这里不再需要手动转换
|
|
|
|
|
+ return record
|
|
|
|
|
+
|
|
|
|
|
+ except mysql.connector.Error as err:
|
|
|
|
|
+ logger.error(f"数据库查询失败: {err}")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="数据库查询失败")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if cursor:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.get("/", response_model=List[ImageDataResponse], summary="根据名称查询记录")
|
|
|
|
|
+def get_image_data_by_name(
|
|
|
|
|
+ img_name: str,
|
|
|
|
|
+ db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
|
|
|
|
|
+):
|
|
|
|
|
+ cursor = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor = db_conn.cursor(dictionary=True)
|
|
|
|
|
+ query = f"SELECT * FROM {settings.DB_TABLE_NAME} WHERE img_name = %s"
|
|
|
|
|
+ cursor.execute(query, (img_name,))
|
|
|
|
|
+ records = cursor.fetchall()
|
|
|
|
|
+
|
|
|
|
|
+ # Pydantic 验证器会自动处理转换,这里不再需要手动转换
|
|
|
|
|
+ return records
|
|
|
|
|
+
|
|
|
|
|
+ except mysql.connector.Error as err:
|
|
|
|
|
+ logger.error(f"数据库查询失败: {err}")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="数据库查询失败")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if cursor:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.put("/{img_id}", response_model=ImageDataResponse, summary="更新记录的JSON数据")
|
|
|
|
|
+def update_image_data_json(
|
|
|
|
|
+ img_id: int,
|
|
|
|
|
+ data: ImageDataUpdate,
|
|
|
|
|
+ db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
|
|
|
|
|
+):
|
|
|
|
|
+ # ... (这个函数逻辑保持不变)
|
|
|
|
|
+ cursor = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor = db_conn.cursor(dictionary=True)
|
|
|
|
|
+ check_query = f"SELECT img_path FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
|
|
|
|
|
+ cursor.execute(check_query, (img_id,))
|
|
|
|
|
+ record = cursor.fetchone()
|
|
|
|
|
+ if not record:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="记录未找到")
|
|
|
|
|
+
|
|
|
|
|
+ update_query = (
|
|
|
|
|
+ f"UPDATE {settings.DB_TABLE_NAME} SET img_result_json = %s WHERE img_id = %s"
|
|
|
|
|
+ )
|
|
|
|
|
+ json_string = json.dumps(data.img_result_json, ensure_ascii=False)
|
|
|
|
|
+ cursor.execute(update_query, (json_string, img_id))
|
|
|
|
|
+ db_conn.commit()
|
|
|
|
|
+
|
|
|
|
|
+ if cursor.rowcount == 0:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="记录未找到或数据未改变")
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"成功更新记录, ID: {img_id}")
|
|
|
|
|
+ return get_image_data_by_id(img_id, db_conn)
|
|
|
|
|
+
|
|
|
|
|
+ except mysql.connector.Error as err:
|
|
|
|
|
+ db_conn.rollback()
|
|
|
|
|
+ logger.error(f"数据库更新失败: {err}")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="数据库更新失败")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if cursor:
|
|
|
|
|
+ cursor.close()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.delete("/{img_id}", summary="删除记录和对应的图片文件")
|
|
|
|
|
+def delete_image_data(
|
|
|
|
|
+ img_id: int,
|
|
|
|
|
+ db_conn: mysql.connector.connection.MySQLConnection = Depends(get_db_connection)
|
|
|
|
|
+):
|
|
|
|
|
+ # ... (这个函数无需修改)
|
|
|
|
|
+ cursor = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ cursor = db_conn.cursor(dictionary=True)
|
|
|
|
|
+
|
|
|
|
|
+ query = f"SELECT img_path FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
|
|
|
|
|
+ cursor.execute(query, (img_id,))
|
|
|
|
|
+ record = cursor.fetchone()
|
|
|
|
|
+
|
|
|
|
|
+ if not record:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="记录未找到")
|
|
|
|
|
+
|
|
|
|
|
+ img_path = record['img_path']
|
|
|
|
|
+
|
|
|
|
|
+ delete_query = f"DELETE FROM {settings.DB_TABLE_NAME} WHERE img_id = %s"
|
|
|
|
|
+ cursor.execute(delete_query, (img_id,))
|
|
|
|
|
+ db_conn.commit()
|
|
|
|
|
+
|
|
|
|
|
+ if cursor.rowcount == 0:
|
|
|
|
|
+ raise HTTPException(status_code=404, detail="记录删除失败,可能已被删除")
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ if os.path.exists(img_path):
|
|
|
|
|
+ os.remove(img_path)
|
|
|
|
|
+ logger.info(f"成功删除图片文件: {img_path}")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"删除图片文件失败: {img_path}, 错误: {e}")
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"成功删除记录, ID: {img_id}")
|
|
|
|
|
+ return JSONResponse(content={"message": f"记录 {img_id} 已成功删除"}, status_code=200)
|
|
|
|
|
+
|
|
|
|
|
+ except mysql.connector.Error as err:
|
|
|
|
|
+ db_conn.rollback()
|
|
|
|
|
+ logger.error(f"数据库删除失败: {err}")
|
|
|
|
|
+ raise HTTPException(status_code=500, detail="数据库删除失败")
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if cursor:
|
|
|
|
|
+ cursor.close()
|