|
@@ -20,6 +20,18 @@ IMPORT_REQUEST_TIMEOUT = aiohttp.ClientTimeout(
|
|
|
|
|
|
|
|
# 推理服务 stitch 接口需要的表单字段顺序,按 ring + gray + stripe1..4 组装
|
|
# 推理服务 stitch 接口需要的表单字段顺序,按 ring + gray + stripe1..4 组装
|
|
|
STITCH_FORM_FIELDS = ["ring", "gray", "stripe1", "stripe2", "stripe3", "stripe4"]
|
|
STITCH_FORM_FIELDS = ["ring", "gray", "stripe1", "stripe2", "stripe3", "stripe4"]
|
|
|
|
|
+STITCH_DEFECT_KEEP_LABELS = [
|
|
|
|
|
+ "slight_scratch",
|
|
|
|
|
+ "scratch",
|
|
|
|
|
+ "serious_scratch",
|
|
|
|
|
+ "damaged",
|
|
|
|
|
+ "impact",
|
|
|
|
|
+ "pit",
|
|
|
|
|
+ "stain",
|
|
|
|
|
+ "wear",
|
|
|
|
|
+]
|
|
|
|
|
+STITCH_DEFECT_PER_LABEL_LIMIT = 10
|
|
|
|
|
+STITCH_DEFECT_TOTAL_LIMIT = 80
|
|
|
|
|
|
|
|
# 一面(front/back)对应的所有 image_type
|
|
# 一面(front/back)对应的所有 image_type
|
|
|
SIDE_IMAGE_TYPES: Dict[str, Dict[str, str]] = {
|
|
SIDE_IMAGE_TYPES: Dict[str, Dict[str, str]] = {
|
|
@@ -64,6 +76,97 @@ def _flat_image_types() -> List[str]:
|
|
|
return flat
|
|
return flat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _to_float_prob(value: Any) -> float:
|
|
|
|
|
+ try:
|
|
|
|
|
+ return float(value)
|
|
|
|
|
+ except (TypeError, ValueError):
|
|
|
|
|
+ return 0.0
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _trim_stitch_defects(score_json: Dict[str, Any], side: str) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ stitch 返回缺陷裁剪:
|
|
|
|
|
+ - 仅保留指定 8 个 label
|
|
|
|
|
+ - 每个 label 按 prob 降序最多 10 条
|
|
|
|
|
+ - 总数最多 80 条
|
|
|
|
|
+ """
|
|
|
|
|
+ if not isinstance(score_json, dict):
|
|
|
|
|
+ return score_json
|
|
|
|
|
+
|
|
|
|
|
+ defects = (
|
|
|
|
|
+ score_json.get("result", {})
|
|
|
|
|
+ .get("defect_result", {})
|
|
|
|
|
+ .get("defects", [])
|
|
|
|
|
+ )
|
|
|
|
|
+ if not isinstance(defects, list):
|
|
|
|
|
+ return score_json
|
|
|
|
|
+
|
|
|
|
|
+ by_label: Dict[str, List[Dict[str, Any]]] = {label: [] for label in STITCH_DEFECT_KEEP_LABELS}
|
|
|
|
|
+ for defect in defects:
|
|
|
|
|
+ if not isinstance(defect, dict):
|
|
|
|
|
+ continue
|
|
|
|
|
+ label = defect.get("label")
|
|
|
|
|
+ if label not in by_label:
|
|
|
|
|
+ continue
|
|
|
|
|
+ by_label[label].append(defect)
|
|
|
|
|
+
|
|
|
|
|
+ trimmed: List[Dict[str, Any]] = []
|
|
|
|
|
+ for label in STITCH_DEFECT_KEEP_LABELS:
|
|
|
|
|
+ top_items = sorted(
|
|
|
|
|
+ by_label[label],
|
|
|
|
|
+ key=lambda item: _to_float_prob(item.get("prob")),
|
|
|
|
|
+ reverse=True,
|
|
|
|
|
+ )[:STITCH_DEFECT_PER_LABEL_LIMIT]
|
|
|
|
|
+ trimmed.extend(top_items)
|
|
|
|
|
+
|
|
|
|
|
+ # 总量兜底到 80,按 prob 全局降序截断
|
|
|
|
|
+ trimmed = sorted(
|
|
|
|
|
+ trimmed,
|
|
|
|
|
+ key=lambda item: _to_float_prob(item.get("prob")),
|
|
|
|
|
+ reverse=True,
|
|
|
|
|
+ )[:STITCH_DEFECT_TOTAL_LIMIT]
|
|
|
|
|
+
|
|
|
|
|
+ score_json.setdefault("result", {}).setdefault("defect_result", {})["defects"] = trimmed
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ "stitch 缺陷裁剪: side=%s before=%s after=%s labels=%s per_label_limit=%s total_limit=%s",
|
|
|
|
|
+ side,
|
|
|
|
|
+ len(defects),
|
|
|
|
|
+ len(trimmed),
|
|
|
|
|
+ ",".join(STITCH_DEFECT_KEEP_LABELS),
|
|
|
|
|
+ STITCH_DEFECT_PER_LABEL_LIMIT,
|
|
|
|
|
+ STITCH_DEFECT_TOTAL_LIMIT,
|
|
|
|
|
+ )
|
|
|
|
|
+ return score_json
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def _pregenerate_defect_images(card_id: int) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ 导入完成后预生成缺陷裁图(写入 MinIO),让查询接口直接命中缓存。
|
|
|
|
|
+ 裁图为同步阻塞操作,放到线程池执行;任何异常都不影响导入主流程。
|
|
|
|
|
+ """
|
|
|
|
|
+ def _run() -> None:
|
|
|
|
|
+ from app.core import database_loader
|
|
|
|
|
+ from app.api.formate_xy import pregenerate_defect_images_for_card
|
|
|
|
|
+
|
|
|
|
|
+ if database_loader.db_connection_pool is None:
|
|
|
|
|
+ logger.warning("预生成缺陷裁图跳过:数据库连接池未初始化 card_id=%s", card_id)
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ db_conn = None
|
|
|
|
|
+ try:
|
|
|
|
|
+ db_conn = database_loader.db_connection_pool.get_connection()
|
|
|
|
|
+ pregenerate_defect_images_for_card(db_conn, card_id)
|
|
|
|
|
+ finally:
|
|
|
|
|
+ if db_conn and db_conn.is_connected():
|
|
|
|
|
+ db_conn.close()
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ from fastapi.concurrency import run_in_threadpool
|
|
|
|
|
+ await run_in_threadpool(_run)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error("预生成缺陷裁图失败 card_id=%s error=%s", card_id, e, exc_info=True)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def _resolve_internal_base_url(request: Request) -> str:
|
|
def _resolve_internal_base_url(request: Request) -> str:
|
|
|
"""
|
|
"""
|
|
|
导入流程内部调用本服务时优先走集群内地址,避免经 ingress 再次鉴权导致 401。
|
|
导入流程内部调用本服务时优先走集群内地址,避免经 ingress 再次鉴权导致 401。
|
|
@@ -371,12 +474,13 @@ async def _run_import_flow(
|
|
|
side=side,
|
|
side=side,
|
|
|
side_bytes=side_bytes,
|
|
side_bytes=side_bytes,
|
|
|
)
|
|
)
|
|
|
- side_score_json[side] = await call_stitch_inference(
|
|
|
|
|
|
|
+ stitch_result = await call_stitch_inference(
|
|
|
session=session,
|
|
session=session,
|
|
|
side=side,
|
|
side=side,
|
|
|
card_name=card_name,
|
|
card_name=card_name,
|
|
|
side_bytes=rectified_side_bytes,
|
|
side_bytes=rectified_side_bytes,
|
|
|
)
|
|
)
|
|
|
|
|
+ side_score_json[side] = _trim_stitch_defects(stitch_result, side)
|
|
|
|
|
|
|
|
# 2. 创建卡牌
|
|
# 2. 创建卡牌
|
|
|
card_id = await create_card_record(
|
|
card_id = await create_card_record(
|
|
@@ -440,6 +544,11 @@ async def _run_import_flow(
|
|
|
"yes" if side_score_json["front"] is not None else "no",
|
|
"yes" if side_score_json["front"] is not None else "no",
|
|
|
"yes" if side_score_json["back"] is not None else "no",
|
|
"yes" if side_score_json["back"] is not None else "no",
|
|
|
)
|
|
)
|
|
|
|
|
+
|
|
|
|
|
+ # 导入阶段预生成缺陷裁图,保证后续查询无需实时裁图。
|
|
|
|
|
+ # 失败不影响导入主流程,查询时会自动回退到实时裁图。
|
|
|
|
|
+ await _pregenerate_defect_images(card_id)
|
|
|
|
|
+
|
|
|
return {
|
|
return {
|
|
|
"message": "导入成功",
|
|
"message": "导入成功",
|
|
|
"card_id": card_id,
|
|
"card_id": card_id,
|