Procházet zdrojové kódy

增加分数配置文件的接口

AnlaAnla před 4 týdny
rodič
revize
5ad3b43123

+ 118 - 5
Test/test01.py

@@ -1,8 +1,121 @@
+import torch
+import matplotlib.pyplot as plt
 import numpy as np
 
-box = [[1,2,3],[4,5,6],[7,8,9]]
-box = np.array(box).tolist()
-data = {"11": "345"}
+plt.rcParams['font.family'] = 'SimHei'
 
-data['box'] = box
-print(data)
+# --- 1. 准备非线性数据 ---
+# 我们来创建一个更复杂的数据模式,比如一条带有线性趋势的正弦曲线
+# 真实函数: y = sin(x) + 0.5*x + noise
+TRUE_FUNCTION = lambda x: 2*x + 3
+# TRUE_FUNCTION = lambda x: torch.sin(x) + 0.5 * x
+
+# 创建 100 个数据点
+X = torch.linspace(0, 16 * np.pi, 120).unsqueeze(1)
+y = TRUE_FUNCTION(X) + torch.randn(X.size()) * 0.4  # 加入一些噪声
+
+# 可视化我们创建的数据集
+plt.figure(figsize=(10, 5))
+plt.scatter(X.numpy(), y.numpy(), c='blue', label='原始数据点')
+plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
+plt.title("非线性模拟数据")
+plt.xlabel("X")
+plt.ylabel("y")
+plt.legend()
+plt.grid(True)
+plt.show()
+
+# --- 2. 定义多项式模型和参数 ---
+# 我们的模型是 y_pred = a*x^3 + b*x^2 + c*x + d
+# 我们需要学习 4 个参数: a, b, c, d
+# 随机初始化它们,并设置 requires_grad=True
+a = torch.randn(1, requires_grad=True)
+b = torch.randn(1, requires_grad=True)
+c = torch.randn(1, requires_grad=True)
+d = torch.randn(1, requires_grad=True)
+
+print("--- 训练开始前 ---")
+print(f"随机初始化的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}")
+
+# --- 3. 定义损失函数和优化器 ---
+# 对于更复杂的模型,Adam 优化器通常表现更好
+# 学习率可以适当调高一点
+learning_rate = 0.01
+optimizer = torch.optim.Adam([a, b, c, d], lr=learning_rate)
+
+# 损失函数仍然使用均方误差
+loss_fn = torch.nn.MSELoss()
+
+# --- 4. 训练循环 ---
+# 增加训练周期数,因为模型更复杂,需要更多时间学习
+epochs = 18000
+all_losses = []
+
+plt.figure(figsize=(12, 6))
+plt.ion()
+
+for epoch in range(epochs):
+    # 4.1 前向传播 (Forward Pass) - 这是唯一需要大改的地方!
+    # 根据当前的 a, b, c, d 计算预测值
+    y_pred = a * X ** 3 + b * X ** 2 + c * X + d
+
+    # 4.2 计算损失
+    loss = loss_fn(y_pred, y)
+    all_losses.append(loss.item())
+
+    # 4.3 清空过往梯度
+    optimizer.zero_grad()
+
+    # 4.4 反向传播 - 核心步骤,但代码完全不变!
+    # PyTorch 自动处理复杂的求导链条
+    loss.backward()
+
+    # 4.5 更新参数 - 代码也完全不变!
+    optimizer.step()
+
+    # --- 可视化 ---
+    if (epoch + 1) % 100 == 0:
+        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
+
+        plt.clf()
+        plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据')
+        plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
+        # detach() 是为了切断梯度追踪,因为绘图不需要计算梯度
+        plt.plot(X.numpy(), y_pred.detach().numpy(), 'r-', label='当前拟合曲线')
+
+        plt.title(f"训练过程 - Epoch {epoch + 1}")
+        plt.xlabel("X")
+        plt.ylabel("y")
+        plt.ylim(y.min() - 1, y.max() + 1)  # 固定Y轴范围,防止图像跳动
+        plt.legend()
+        plt.grid(True)
+        plt.pause(0.1)
+
+plt.ioff()
+plt.show()
+
+# --- 5. 最终结果可视化 ---
+print("\n--- 训练完成 ---")
+print(f"学习到的参数: a={a.item():.3f}, b={b.item():.3f}, c={c.item():.3f}, d={d.item():.3f}")
+
+plt.figure(figsize=(10, 5))
+plt.scatter(X.numpy(), y.numpy(), c='blue', s=15, label='原始数据')
+plt.plot(X.numpy(), TRUE_FUNCTION(X).numpy(), 'g--', label='真实函数曲线')
+final_pred = a * X ** 3 + b * X ** 2 + c * X + d
+plt.plot(X.numpy(), final_pred.detach().numpy(), 'r-', label='最终拟合曲线')
+plt.title("最终拟合结果")
+plt.xlabel("X")
+plt.ylabel("y")
+plt.legend()
+plt.grid(True)
+plt.show()
+
+# 绘制损失函数下降曲线
+plt.figure(figsize=(10, 5))
+plt.plot(range(epochs), all_losses)
+plt.title("损失函数下降曲线")
+plt.xlabel("周期 (Epoch)")
+plt.ylabel("损失 (Loss)")
+plt.yscale('log')  # 使用对数坐标轴,可以更清晰地看到早期的快速下降和后期的缓慢优化
+plt.grid(True)
+plt.show()

+ 129 - 0
app/api/config_api.py

@@ -0,0 +1,129 @@
+from fastapi import APIRouter, File, Body, HTTPException, status
+from fastapi.responses import FileResponse, JSONResponse
+from ..core.config import settings
+import json
+from app.core.logger import get_logger
+
+logger = get_logger(__name__)
+
+router = APIRouter(tags=["Config"])
+
+
+def compare_json_structure(template: dict, data: dict, path: str = "") -> (bool, str):
+    """
+    递归比较两个字典的结构(键),忽略值。
+    如果结构不匹配,返回 False 和错误原因。
+    :param template: 模板字典(原始配置)
+    :param data: 新的字典(待验证的配置)
+    :param path: 用于错误报告的当前递归路径
+    :return: 一个元组 (is_match: bool, reason: str)
+    """
+    template_keys = set(template.keys())
+    data_keys = set(data.keys())
+
+    # 检查当前层的键是否完全匹配
+    if template_keys != data_keys:
+        missing_keys = template_keys - data_keys
+        extra_keys = data_keys - template_keys
+        error_msg = f"结构不匹配。在路径 '{path or 'root'}' "
+        if missing_keys:
+            error_msg += f"缺少字段: {missing_keys}。"
+        if extra_keys:
+            error_msg += f"存在多余字段: {extra_keys}。"
+        return False, error_msg
+
+    # 递归检查所有子字典
+    for key in template_keys:
+        if isinstance(template[key], dict) and isinstance(data[key], dict):
+            # 如果两个值都是字典,则递归深入
+            is_match, reason = compare_json_structure(template[key], data[key], path=f"{path}.{key}" if path else key)
+            if not is_match:
+                return False, reason
+        elif isinstance(template[key], dict) or isinstance(data[key], dict):
+            # 如果只有一个是字典,说明结构已改变(例如,一个对象被替换成了字符串)
+            current_path = f"{path}.{key}" if path else key
+            return False, f"结构不匹配。字段 '{current_path}' 的类型已从字典变为非字典(或反之)。"
+
+    return True, "结构匹配"
+
+
+@router.get("/scoring_config", summary="获取评分配置")
+async def get_scoring_config():
+    """
+    读取并返回 scoring_config.json 文件的内容。
+    """
+    if not settings.SCORE_CONFIG_PATH.exists():
+        logger.error(f"评分配置文件未找到: {settings.SCORE_CONFIG_PATH}")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="评分配置文件未找到"
+        )
+
+    try:
+        with open(settings.SCORE_CONFIG_PATH, 'r', encoding='utf-8') as f:
+            config_data = json.load(f)
+        return config_data
+    except json.JSONDecodeError:
+        logger.error(f"评分配置文件格式错误,无法解析: {settings.SCORE_CONFIG_PATH}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="评分配置文件格式错误"
+        )
+    except Exception as e:
+        logger.error(f"读取配置文件时发生未知错误: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"读取配置文件时发生未知错误: {e}"
+        )
+
+
+@router.put("/scoring_config", summary="更新评分配置")
+async def update_scoring_config(new_config: dict = Body(...)):
+    """
+    接收新的JSON配置,验证其结构与现有配置完全一致后,覆盖保存。
+    只允许修改值,不允许增、删、改任何字段名。
+    """
+    # 1. 检查并读取当前的配置文件作为模板
+    if not settings.SCORE_CONFIG_PATH.exists():
+        logger.error(f"尝试更新一个不存在的配置文件: {settings.SCORE_CONFIG_PATH}")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="无法更新,因为原始配置文件未找到"
+        )
+
+    try:
+        with open(settings.SCORE_CONFIG_PATH, 'r', encoding='utf-8') as f:
+            current_config = json.load(f)
+    except Exception as e:
+        logger.error(f"更新前读取原始配置文件失败: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"读取原始配置文件失败: {e}"
+        )
+
+    # 2. 比较新旧配置的结构
+    is_valid, reason = compare_json_structure(current_config, new_config)
+
+    if not is_valid:
+        logger.warning(f"更新评分配置失败,结构校验未通过: {reason}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"配置更新失败。{reason} 请确保只修改数值,不要添加、删除或重命名字段。"
+        )
+
+    # 3. 结构验证通过,写入新文件
+    try:
+        with open(settings.SCORE_CONFIG_PATH, 'w', encoding='utf-8') as f:
+            # 使用 indent=2 格式化输出,方便人工阅读
+            json.dump(new_config, f, indent=2, ensure_ascii=False)
+        logger.info("评分配置文件已成功更新。")
+        return JSONResponse(
+            status_code=status.HTTP_200_OK,
+            content={"message": "配置已成功更新"}
+        )
+    except Exception as e:
+        logger.error(f"写入新配置文件时发生错误: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"保存新配置文件时发生错误: {e}"
+        )

+ 1 - 1
app/api/score_inference.py

@@ -73,4 +73,4 @@ async def score_recalculate(score_type: ScoreType, json_data: Dict[str, Any]):
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")
+        raise HTTPException(status_code=500, detail=f"服务器内部错误: {e}")

+ 1 - 0
app/core/config.py

@@ -15,6 +15,7 @@ class CardModelConfig:
 class Settings:
     API_Inference_prefix: str = "/api/card_inference"
     API_Score_prefix: str = "/api/card_score"
+    API_Config_prefix: str = "/api/config"
 
     BASE_PATH = Path(__file__).parent.parent.parent.absolute()
 

+ 38 - 38
app/core/scoring_config.json

@@ -1,5 +1,5 @@
 {
-  "base_score": 10.0,
+  "base_score": 10,
   "corner": {
     "rules": {
       "wear_area": [
@@ -17,11 +17,11 @@
         ],
         [
           0.5,
-          -3.0
+          -3
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ],
       "loss_area": [
@@ -39,11 +39,11 @@
         ],
         [
           0.5,
-          -3.0
+          -3
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ]
     },
@@ -77,11 +77,11 @@
         ],
         [
           0.5,
-          -3.0
+          -3
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ],
       "loss_area": [
@@ -99,11 +99,11 @@
         ],
         [
           0.5,
-          -3.0
+          -3
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ]
     },
@@ -137,11 +137,11 @@
         ],
         [
           0.5,
-          -3.0
+          -3
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ],
       "pit_area": [
@@ -159,11 +159,11 @@
         ],
         [
           0.5,
-          -3.0
+          -3
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ],
       "stain_area": [
@@ -181,41 +181,41 @@
         ],
         [
           0.5,
-          -3.0
+          -3
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ],
       "scratch_length": [
         [
-          1.0,
+          1,
           -0.1
         ],
         [
-          2.0,
+          2,
           -0.5
         ],
         [
-          5.0,
-          -1.0
+          5,
+          -1
         ],
         [
-          10.0,
-          -2.0
+          10,
+          -2
         ],
         [
-          20.0,
-          -3.0
+          20,
+          -3
         ],
         [
-          50.0,
-          -4.0
+          50,
+          -4
         ],
         [
           "inf",
-          -5.0
+          -5
         ]
       ]
     },
@@ -243,7 +243,7 @@
         ],
         [
           60,
-          -1.0
+          -1
         ],
         [
           62.5,
@@ -251,7 +251,7 @@
         ],
         [
           65,
-          -2.0
+          -2
         ],
         [
           67.5,
@@ -259,7 +259,7 @@
         ],
         [
           70,
-          -3.0
+          -3
         ],
         [
           72.5,
@@ -267,7 +267,7 @@
         ],
         [
           75,
-          -4.0
+          -4
         ],
         [
           77.5,
@@ -275,7 +275,7 @@
         ],
         [
           80,
-          -5.0
+          -5
         ],
         [
           82.5,
@@ -283,7 +283,7 @@
         ],
         [
           85,
-          -6.0
+          -6
         ],
         [
           87.5,
@@ -291,7 +291,7 @@
         ],
         [
           90,
-          -7.0
+          -7
         ],
         [
           92.5,
@@ -299,7 +299,7 @@
         ],
         [
           95,
-          -8.0
+          -8
         ],
         [
           97.5,
@@ -307,7 +307,7 @@
         ],
         [
           "inf",
-          -9.0
+          -9
         ]
       ],
       "coefficients": {
@@ -323,7 +323,7 @@
         ],
         [
           70,
-          -1.0
+          -1
         ],
         [
           75,
@@ -331,7 +331,7 @@
         ],
         [
           85,
-          -2.0
+          -2
         ],
         [
           95,
@@ -339,7 +339,7 @@
         ],
         [
           "inf",
-          -3.0
+          -3
         ]
       ],
       "coefficients": {

+ 369 - 0
app/core/基础备份_scoring_config.json

@@ -0,0 +1,369 @@
+{
+  "base_score": 10.0,
+  "corner": {
+    "rules": {
+      "wear_area": [
+        [
+          0.05,
+          -0.1
+        ],
+        [
+          0.1,
+          -0.5
+        ],
+        [
+          0.25,
+          -1.5
+        ],
+        [
+          0.5,
+          -3.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ],
+      "loss_area": [
+        [
+          0.05,
+          -0.1
+        ],
+        [
+          0.1,
+          -0.5
+        ],
+        [
+          0.25,
+          -1.5
+        ],
+        [
+          0.5,
+          -3.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ]
+    },
+    "front_weights": {
+      "wear_area": 0.3,
+      "loss_area": 0.7
+    },
+    "back_weights": {
+      "wear_area": 0.3,
+      "loss_area": 0.7
+    },
+    "final_weights": {
+      "front": 0.7,
+      "back": 0.3
+    }
+  },
+  "edge": {
+    "rules": {
+      "wear_area": [
+        [
+          0.05,
+          -0.1
+        ],
+        [
+          0.1,
+          -0.5
+        ],
+        [
+          0.25,
+          -1.5
+        ],
+        [
+          0.5,
+          -3.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ],
+      "loss_area": [
+        [
+          0.05,
+          -0.1
+        ],
+        [
+          0.1,
+          -0.5
+        ],
+        [
+          0.25,
+          -1.5
+        ],
+        [
+          0.5,
+          -3.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ]
+    },
+    "front_weights": {
+      "wear_area": 0.4,
+      "loss_area": 0.6
+    },
+    "back_weights": {
+      "wear_area": 0.4,
+      "loss_area": 0.6
+    },
+    "final_weights": {
+      "front": 0.7,
+      "back": 0.3
+    }
+  },
+  "face": {
+    "rules": {
+      "wear_area": [
+        [
+          0.05,
+          -0.1
+        ],
+        [
+          0.1,
+          -0.5
+        ],
+        [
+          0.25,
+          -1.5
+        ],
+        [
+          0.5,
+          -3.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ],
+      "pit_area": [
+        [
+          0.05,
+          -0.1
+        ],
+        [
+          0.1,
+          -0.5
+        ],
+        [
+          0.25,
+          -1.5
+        ],
+        [
+          0.5,
+          -3.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ],
+      "stain_area": [
+        [
+          0.05,
+          -0.1
+        ],
+        [
+          0.1,
+          -0.5
+        ],
+        [
+          0.25,
+          -1.5
+        ],
+        [
+          0.5,
+          -3.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ],
+      "scratch_length": [
+        [
+          1.0,
+          -0.1
+        ],
+        [
+          2.0,
+          -0.5
+        ],
+        [
+          5.0,
+          -1.0
+        ],
+        [
+          10.0,
+          -2.0
+        ],
+        [
+          20.0,
+          -3.0
+        ],
+        [
+          50.0,
+          -4.0
+        ],
+        [
+          "inf",
+          -5.0
+        ]
+      ]
+    },
+    "coefficients": {
+      "wear_area": 0.25,
+      "scratch_length": 0.25,
+      "dent_area": 0.25,
+      "stain_area": 0.25
+    },
+    "final_weights": {
+      "front": 0.75,
+      "back": 0.25
+    }
+  },
+  "centering": {
+    "front": {
+      "rules": [
+        [
+          52,
+          0
+        ],
+        [
+          55,
+          -0.5
+        ],
+        [
+          60,
+          -1.0
+        ],
+        [
+          62.5,
+          -1.5
+        ],
+        [
+          65,
+          -2.0
+        ],
+        [
+          67.5,
+          -2.5
+        ],
+        [
+          70,
+          -3.0
+        ],
+        [
+          72.5,
+          -3.5
+        ],
+        [
+          75,
+          -4.0
+        ],
+        [
+          77.5,
+          -4.5
+        ],
+        [
+          80,
+          -5.0
+        ],
+        [
+          82.5,
+          -5.5
+        ],
+        [
+          85,
+          -6.0
+        ],
+        [
+          87.5,
+          -6.5
+        ],
+        [
+          90,
+          -7.0
+        ],
+        [
+          92.5,
+          -7.5
+        ],
+        [
+          95,
+          -8.0
+        ],
+        [
+          97.5,
+          -8.5
+        ],
+        [
+          "inf",
+          -9.0
+        ]
+      ],
+      "coefficients": {
+        "horizontal": 1.2,
+        "vertical": 0.9
+      }
+    },
+    "back": {
+      "rules": [
+        [
+          60,
+          -0.5
+        ],
+        [
+          70,
+          -1.0
+        ],
+        [
+          75,
+          -1.5
+        ],
+        [
+          85,
+          -2.0
+        ],
+        [
+          95,
+          -2.5
+        ],
+        [
+          "inf",
+          -3.0
+        ]
+      ],
+      "coefficients": {
+        "horizontal": 1.2,
+        "vertical": 0.9
+      }
+    },
+    "final_weights": {
+      "front": 0.75,
+      "back": 0.25
+    }
+  },
+  "card": {
+    "PSA": {
+      "face": 0.35,
+      "corner": 0.3,
+      "edge": 0.1,
+      "center": 0.25
+    },
+    "BGS": {
+      "face": 0.3,
+      "corner": 0.25,
+      "edge": 0.1,
+      "center": 0.25
+    }
+  }
+}

+ 2 - 0
app/main.py

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
 from .core.model_loader import load_models, unload_models
 from app.api.card_inference import router as card_inference_router
 from app.api.score_inference import router as score_inference_router
+from app.api.config_api import router as config_api_router
 import os
 
 from .core.logger import setup_logging, get_logger
@@ -32,3 +33,4 @@ app = FastAPI(title="卡片框和缺陷检测服务", lifespan=lifespan)
 
 app.include_router(card_inference_router, prefix=settings.API_Inference_prefix)
 app.include_router(score_inference_router, prefix=settings.API_Score_prefix)
+app.include_router(config_api_router, prefix=settings.API_Config_prefix)