Przeglądaj źródła

算法模块前置,算法加限制条件

袁威 1 tydzień temu
rodzic
commit
633cec7ef1
2 zmienionych plików z 158 dodań i 1 usunięć
  1. 110 1
      app/api/auto_import.py
  2. 48 0
      app/api/formate_xy.py

+ 110 - 1
app/api/auto_import.py

@@ -20,6 +20,18 @@ IMPORT_REQUEST_TIMEOUT = aiohttp.ClientTimeout(
 
 # 推理服务 stitch 接口需要的表单字段顺序,按 ring + gray + stripe1..4 组装
 STITCH_FORM_FIELDS = ["ring", "gray", "stripe1", "stripe2", "stripe3", "stripe4"]
+# Defect labels kept from the stitch inference response; _trim_stitch_defects
+# drops defects carrying any other label.
+STITCH_DEFECT_KEEP_LABELS = [
+    "slight_scratch",
+    "scratch",
+    "serious_scratch",
+    "damaged",
+    "impact",
+    "pit",
+    "stain",
+    "wear",
+]
+# Maximum number of defects kept per label (highest "prob" first).
+STITCH_DEFECT_PER_LABEL_LIMIT = 10
+# Overall cap on the number of defects kept across all labels.
+STITCH_DEFECT_TOTAL_LIMIT = 80
 
 # 一面(front/back)对应的所有 image_type
 SIDE_IMAGE_TYPES: Dict[str, Dict[str, str]] = {
@@ -64,6 +76,97 @@ def _flat_image_types() -> List[str]:
     return flat
 
 
+def _to_float_prob(value: Any) -> float:
+    """Coerce a defect's ``prob`` value into a float that is safe to sort by.
+
+    Args:
+        value: Raw ``prob`` value read from a defect dict; may be a number,
+            a numeric string, ``None`` or any other object.
+
+    Returns:
+        ``float(value)`` when the conversion succeeds and the result is not
+        NaN, otherwise ``0.0``.
+    """
+    try:
+        prob = float(value)
+    except (TypeError, ValueError, OverflowError):
+        # OverflowError: an int too large to be represented as a float.
+        return 0.0
+    # NaN compares false against everything, which would leave sorted() with
+    # an unpredictable order; treat it like any other unusable value.
+    if prob != prob:
+        return 0.0
+    return prob
+
+
+def _trim_stitch_defects(score_json: Dict[str, Any], side: str) -> Dict[str, Any]:
+    """
+    Trim the defect list of a stitch inference response, in place.
+
+    - Only defects whose label is in STITCH_DEFECT_KEEP_LABELS are kept.
+    - Per label, at most STITCH_DEFECT_PER_LABEL_LIMIT defects are kept,
+      highest ``prob`` first.
+    - Overall, at most STITCH_DEFECT_TOTAL_LIMIT defects are kept, ordered by
+      ``prob`` descending.
+
+    Args:
+        score_json: Stitch response; the defects are read from
+            ``score_json["result"]["defect_result"]["defects"]``.
+        side: Side name, used only in the log line.
+
+    Returns:
+        The same ``score_json`` object. It is returned untouched when it is
+        not a dict, or when ``result`` / ``defect_result`` / ``defects`` is
+        present but has an unexpected type. Missing keys are created and
+        ``defects`` is set to the trimmed list.
+    """
+    if not isinstance(score_json, dict):
+        return score_json
+
+    # Walk the nesting one level at a time: a key that is present but holds
+    # null / a non-dict would raise AttributeError on a chained .get().
+    result = score_json.get("result", {})
+    if not isinstance(result, dict):
+        return score_json
+    defect_result = result.get("defect_result", {})
+    if not isinstance(defect_result, dict):
+        return score_json
+    defects = defect_result.get("defects", [])
+    if not isinstance(defects, list):
+        return score_json
+
+    def _prob_key(item: Dict[str, Any]) -> float:
+        return _to_float_prob(item.get("prob"))
+
+    by_label: Dict[str, List[Dict[str, Any]]] = {label: [] for label in STITCH_DEFECT_KEEP_LABELS}
+    for defect in defects:
+        if not isinstance(defect, dict):
+            continue
+        label = defect.get("label")
+        # The isinstance check also keeps unhashable labels (list/dict) from
+        # raising TypeError in the membership test.
+        if not isinstance(label, str) or label not in by_label:
+            continue
+        by_label[label].append(defect)
+
+    trimmed: List[Dict[str, Any]] = []
+    for label in STITCH_DEFECT_KEEP_LABELS:
+        trimmed.extend(
+            sorted(by_label[label], key=_prob_key, reverse=True)[:STITCH_DEFECT_PER_LABEL_LIMIT]
+        )
+
+    # Safety net for the overall cap: order globally by prob, then truncate.
+    trimmed = sorted(trimmed, key=_prob_key, reverse=True)[:STITCH_DEFECT_TOTAL_LIMIT]
+
+    score_json.setdefault("result", {}).setdefault("defect_result", {})["defects"] = trimmed
+    logger.info(
+        "stitch 缺陷裁剪: side=%s before=%s after=%s labels=%s per_label_limit=%s total_limit=%s",
+        side,
+        len(defects),
+        len(trimmed),
+        ",".join(STITCH_DEFECT_KEEP_LABELS),
+        STITCH_DEFECT_PER_LABEL_LIMIT,
+        STITCH_DEFECT_TOTAL_LIMIT,
+    )
+    return score_json
+
+
+async def _pregenerate_defect_images(card_id: int) -> None:
+    """
+    Pre-generate the defect crop images of a card once its import is done.
+
+    The blocking work is delegated to ``pregenerate_defect_images_for_card``
+    and executed through ``run_in_threadpool``. Every exception is logged and
+    swallowed here, so a failure never propagates into the import flow.
+    """
+    def _crop_in_worker() -> None:
+        from app.core import database_loader
+        from app.api.formate_xy import pregenerate_defect_images_for_card
+
+        pool = database_loader.db_connection_pool
+        if pool is None:
+            logger.warning("预生成缺陷裁图跳过:数据库连接池未初始化 card_id=%s", card_id)
+            return
+
+        conn = pool.get_connection()
+        try:
+            pregenerate_defect_images_for_card(conn, card_id)
+        finally:
+            # NOTE(review): close() is skipped when the connection reports
+            # itself as disconnected — confirm the pool still reclaims it.
+            if conn and conn.is_connected():
+                conn.close()
+
+    try:
+        from fastapi.concurrency import run_in_threadpool
+        await run_in_threadpool(_crop_in_worker)
+    except Exception as exc:
+        logger.error("预生成缺陷裁图失败 card_id=%s error=%s", card_id, exc, exc_info=True)
+
+
 def _resolve_internal_base_url(request: Request) -> str:
     """
     导入流程内部调用本服务时优先走集群内地址,避免经 ingress 再次鉴权导致 401。
@@ -371,12 +474,13 @@ async def _run_import_flow(
                 side=side,
                 side_bytes=side_bytes,
             )
-            side_score_json[side] = await call_stitch_inference(
+            stitch_result = await call_stitch_inference(
                 session=session,
                 side=side,
                 card_name=card_name,
                 side_bytes=rectified_side_bytes,
             )
+            side_score_json[side] = _trim_stitch_defects(stitch_result, side)
 
         # 2. 创建卡牌
         card_id = await create_card_record(
@@ -440,6 +544,11 @@ async def _run_import_flow(
             "yes" if side_score_json["front"] is not None else "no",
             "yes" if side_score_json["back"] is not None else "no",
         )
+
+        # 导入阶段预生成缺陷裁图,保证后续查询无需实时裁图。
+        # 失败不影响导入主流程,查询时会自动回退到实时裁图。
+        await _pregenerate_defect_images(card_id)
+
         return {
             "message": "导入成功",
             "card_id": card_id,

+ 48 - 0
app/api/formate_xy.py

@@ -511,6 +511,54 @@ def _process_images_to_xy_format(
     return card_data
 
 
+def pregenerate_defect_images_for_card(db_conn: PooledMySQLConnection, card_id: int) -> int:
+    """
+    Pre-generate defect crop images for one card during import.
+
+    For each side's fusion image, the ``detection_json`` and ``modified_json``
+    fields are passed to ``_generate_defect_img_urls_for_json`` together with
+    the crop pool built from all of the card's images. Only the side effect of
+    that call matters here (per the original author's note the crops are
+    written to MinIO so the query endpoint can reuse them).
+
+    Args:
+        db_conn: Database connection used to load the card and build the crop pool.
+        card_id: ID of the card to process.
+
+    Returns:
+        The summed size of the results returned by the crop generator; used
+        for logging/statistics only. ``0`` when the card is not found or has
+        no images.
+    """
+    start_time = perf_counter()
+    card_data = crud_card.get_card_with_details(db_conn, card_id)
+    if not card_data:
+        logger.warning("预生成缺陷裁图跳过:card_id=%s 未找到卡牌", card_id)
+        return 0
+
+    all_images = card_data.get("images", [])
+    if not all_images:
+        return 0
+
+    fusion_by_side = _resolve_fusion_images_by_side(all_images)
+    crop_pool_by_type = _build_defect_crop_pool_by_type(card_id, all_images, db_conn=db_conn)
+
+    total_defects = 0
+    for side_key, fusion_img in fusion_by_side.items():
+        if not fusion_img:
+            continue
+        for json_field in ("detection_json", "modified_json"):
+            raw_json = getattr(fusion_img, json_field, None)
+            if isinstance(raw_json, str):
+                try:
+                    raw_json = json.loads(raw_json) if raw_json else None
+                except ValueError as exc:
+                    # json.JSONDecodeError is a ValueError. One malformed field
+                    # must not stop cropping for the other fields and sides.
+                    logger.warning(
+                        "预生成缺陷裁图跳过:JSON 解析失败 card_id=%s image_id=%s field=%s error=%s",
+                        card_id, fusion_img.id, json_field, exc,
+                    )
+                    continue
+            if not raw_json:
+                continue
+            # Work on a deep copy so the Pydantic object is left untouched;
+            # only the cropping side effect is of interest here.
+            cache = _generate_defect_img_urls_for_json(
+                card_id,
+                fusion_img.id,
+                copy.deepcopy(raw_json),
+                side_key,
+                crop_pool_by_type,
+                generate_related_images=True,
+            )
+            total_defects += len(cache)
+
+    logger.info(
+        "预生成缺陷裁图完成: card_id=%s defects=%s elapsed_ms=%.2f",
+        card_id, total_defects, (perf_counter() - start_time) * 1000,
+    )
+    return total_defects
+
+
 @router.get("/query", response_model=CardDetailResponse, summary="获取卡牌详细信息(格式化xy), 支持前后翻页 [用户调用]")
 def get_card_details(
         card_id: int = Query(..., description="基准卡牌ID"),