Переглянути джерело

外框检测 + 图像矫正居中

袁威 1 тиждень тому
батько
коміт
c82b55e0e8
2 змінених файлів з 239 додано та 17 видалено
  1. 139 17
      app/api/auto_import.py
  2. 100 0
      app/crud/crud_card.py

+ 139 - 17
app/api/auto_import.py

@@ -20,6 +20,14 @@ IMPORT_REQUEST_TIMEOUT = aiohttp.ClientTimeout(
 
 # 推理服务 stitch 接口需要的表单字段顺序,按 ring + gray + stripe1..4 组装
 STITCH_FORM_FIELDS = ["ring", "gray", "stripe1", "stripe2", "stripe3", "stripe4"]
+STORAGE_RECTIFY_FIELDS = ["fusion"] + STITCH_FORM_FIELDS
+# 每面选一张彩色图做 center 框检测,结果与同面其它图共享
+CENTER_DETECT_COLOR_FIELDS = ["ring", "stripe1", "stripe2", "stripe3", "stripe4"]
+# score_inference 仅接受 front_ring / back_ring / front_coaxial / back_coaxial
+SIDE_CENTER_SCORE_TYPE = {
+    "front": ImageType.front_ring.value,
+    "back": ImageType.back_ring.value,
+}
 STITCH_DEFECT_KEEP_LABELS = [
     "slight_scratch",
     "scratch",
@@ -294,6 +302,86 @@ async def call_stitch_inference(
         raise HTTPException(status_code=500, detail=f"{side} 面推理结果解析失败: {e}")
 
 
+def _has_center_box_shapes(center_result: Any) -> bool:
+    if not isinstance(center_result, dict):
+        return False
+    box_result = center_result.get("box_result", {})
+    if not isinstance(box_result, dict):
+        return False
+    inner_shapes = box_result.get("inner_box", {}).get("shapes", [])
+    outer_shapes = box_result.get("outer_box", {}).get("shapes", [])
+    return bool(inner_shapes or outer_shapes)
+
+
+def _pick_center_detect_image(
+        side: str,
+        rectified_by_type: Dict[str, Tuple[bytes, str]],
+) -> Optional[Tuple[str, bytes, str]]:
+    """从一面已矫正的彩色图中选第一张用于 center 框检测。"""
+    side_map = SIDE_IMAGE_TYPES[side]
+    for field in CENTER_DETECT_COLOR_FIELDS:
+        image_type = side_map[field]
+        if image_type in rectified_by_type:
+            file_bytes, filename = rectified_by_type[image_type]
+            return image_type, file_bytes, filename
+    return None
+
+
+def _merge_center_result(score_json: Dict[str, Any], center_result: Dict[str, Any]) -> Dict[str, Any]:
+    """将 center 检测结果合并进 stitch JSON(stitch 本身通常不含 box)。"""
+    if not isinstance(score_json, dict) or not isinstance(center_result, dict):
+        return score_json
+    existing = score_json.get("result", {}).get("center_result", {})
+    if _has_center_box_shapes(existing):
+        return score_json
+    if not _has_center_box_shapes(center_result):
+        return score_json
+    score_json.setdefault("result", {})["center_result"] = center_result
+    return score_json
+
+
+async def call_center_inference(
+        session: aiohttp.ClientSession,
+        side: str,
+        image_type: str,
+        file_bytes: bytes,
+        filename: str,
+        is_reflect_card: bool,
+) -> Dict[str, Any]:
+    """对单张矫正后的彩色图调用 score_inference,仅取 center_result。"""
+    inference_base_url = settings.SCORE_UPDATE_SERVER_URL
+    url = f"{inference_base_url}/api/card_score/score_inference"
+    params = {
+        "score_type": SIDE_CENTER_SCORE_TYPE[side],
+        "is_reflect_card": str(is_reflect_card).lower(),
+    }
+    logger.info(
+        "调用 score_inference(center): side=%s image_type=%s score_type=%s",
+        side, image_type, SIDE_CENTER_SCORE_TYPE[side],
+    )
+    status, body = await _post_form(
+        session,
+        url=url,
+        files=[("file", file_bytes, filename or f"{image_type}.jpg")],
+        params=params,
+    )
+    if status >= 300:
+        logger.warning(
+            "%s 面 center 框检测失败: image_type=%s status=%s",
+            side, image_type, status,
+        )
+        return {}
+
+    try:
+        payload = json.loads(body)
+    except json.JSONDecodeError as e:
+        logger.warning("%s 面 center 框检测结果解析失败: %s", side, e)
+        return {}
+
+    center_result = payload.get("result", {}).get("center_result", {})
+    return center_result if isinstance(center_result, dict) else {}
+
+
 async def call_rectify_and_center(
         session: aiohttp.ClientSession,
         image_type: str,
@@ -320,27 +408,33 @@ async def call_rectify_and_center(
     return body, f"{name_root}_rectified.jpg"
 
 
-async def rectify_side_images(
+async def rectify_side_storage_images(
         session: aiohttp.ClientSession,
         side: str,
-        side_bytes: Dict[str, Tuple[bytes, str]],
+        bytes_map: Dict[str, Tuple[bytes, str]],
 ) -> Dict[str, Tuple[bytes, str]]:
-    """将 stitch 需要的一面 6 张图全部转正居中。"""
+    """将一面需入库的 fusion + ring/gray/stripe 图转正居中,返回 image_type -> (bytes, filename)。"""
     side_map = SIDE_IMAGE_TYPES[side]
+    fields_to_rectify = [
+        field for field in STORAGE_RECTIFY_FIELDS
+        if side_map[field] in bytes_map
+    ]
+    if not fields_to_rectify:
+        return {}
 
     async def rectify_one(field: str) -> Tuple[str, Tuple[bytes, str]]:
-        file_bytes, filename = side_bytes[field]
         image_type = side_map[field]
+        file_bytes, filename = bytes_map[image_type]
         rectified_bytes, rectified_filename = await call_rectify_and_center(
             session=session,
             image_type=image_type,
             file_bytes=file_bytes,
             filename=filename,
         )
-        return field, (rectified_bytes, rectified_filename)
+        return image_type, (rectified_bytes, rectified_filename)
 
-    logger.info("%s 面开始转正居中: fields=%s", side, ",".join(STITCH_FORM_FIELDS))
-    rectified_pairs = await asyncio.gather(*[rectify_one(field) for field in STITCH_FORM_FIELDS])
+    logger.info("%s 面开始转正居中: fields=%s", side, ",".join(fields_to_rectify))
+    rectified_pairs = await asyncio.gather(*[rectify_one(field) for field in fields_to_rectify])
     logger.info("%s 面转正居中完成", side)
     return dict(rectified_pairs)
 
@@ -455,6 +549,7 @@ async def _run_import_flow(
         card_type: CardType,
         strict_mode: bool,
         bytes_map: Dict[str, Tuple[bytes, str]],
+        is_reflect_card: bool = True,
         non_gray_to_main: bool = False,
         forward_headers: Optional[Dict[str, str]] = None,
 ) -> Dict[str, Any]:
@@ -490,25 +585,50 @@ async def _run_import_flow(
 
     connector = aiohttp.TCPConnector(limit=20, force_close=True)
     async with aiohttp.ClientSession(timeout=IMPORT_REQUEST_TIMEOUT, connector=connector) as session:
-        # 1. 正反面分别调用 stitch 推理(要求该面 6 张推理图齐全)
+        # 1. 正反面转正居中;stitch 推理与入库共用同一套矫正图
+        upload_bytes_map = dict(bytes_map)
         side_score_json: Dict[str, Optional[Dict[str, Any]]] = {"front": None, "back": None}
         for side in ("front", "back"):
-            side_bytes = _collect_side_bytes(side, bytes_map)
-            if side_bytes is None:
-                logger.warning("%s 面推理跳过,缺少 ring/gray/stripe1..4 中的某些图", side)
-                continue
-            rectified_side_bytes = await rectify_side_images(
+            rectified_by_type = await rectify_side_storage_images(
                 session=session,
                 side=side,
-                side_bytes=side_bytes,
+                bytes_map=bytes_map,
             )
+            upload_bytes_map.update(rectified_by_type)
+
+            side_stitch_bytes = _collect_side_bytes(side, bytes_map)
+            if side_stitch_bytes is None:
+                logger.warning("%s 面推理跳过,缺少 ring/gray/stripe1..4 中的某些图", side)
+                continue
+            side_map = SIDE_IMAGE_TYPES[side]
+            rectified_stitch_bytes = {
+                field: rectified_by_type[side_map[field]]
+                for field in STITCH_FORM_FIELDS
+            }
             stitch_result = await call_stitch_inference(
                 session=session,
                 side=side,
                 card_name=card_name,
-                side_bytes=rectified_side_bytes,
+                side_bytes=rectified_stitch_bytes,
             )
-            side_score_json[side] = _trim_stitch_defects(stitch_result, side)
+            score_json = _trim_stitch_defects(stitch_result, side)
+
+            center_pick = _pick_center_detect_image(side, rectified_by_type)
+            if center_pick:
+                center_image_type, center_bytes, center_filename = center_pick
+                center_result = await call_center_inference(
+                    session=session,
+                    side=side,
+                    image_type=center_image_type,
+                    file_bytes=center_bytes,
+                    filename=center_filename,
+                    is_reflect_card=is_reflect_card,
+                )
+                score_json = _merge_center_result(score_json, center_result)
+            else:
+                logger.warning("%s 面缺少彩色图,跳过 center 框检测", side)
+
+            side_score_json[side] = score_json
 
         # 2. 创建卡牌
         card_id = await create_card_record(
@@ -530,7 +650,7 @@ async def _run_import_flow(
             # 按 fusion -> ring -> gray -> stripe1..4 顺序处理该面所有图
             for field in ["fusion"] + STITCH_FORM_FIELDS:
                 image_type = side_map[field]
-                data = bytes_map.get(image_type)
+                data = upload_bytes_map.get(image_type)
                 if data is None:
                     continue
                 f_bytes, f_name = data
@@ -677,6 +797,7 @@ async def auto_import_script_api(
             card_type=card_type,
             strict_mode=strict_mode,
             bytes_map=bytes_map,
+            is_reflect_card=is_reflect_card,
             non_gray_to_main=True,
             forward_headers=forward_headers or None,
         )
@@ -785,6 +906,7 @@ async def auto_import_url_script_api(
             card_type=card_type,
             strict_mode=strict_mode,
             bytes_map=bytes_map,
+            is_reflect_card=is_reflect_card,
             non_gray_to_main=True,
             forward_headers=forward_headers or None,
         )

+ 100 - 0
app/crud/crud_card.py

@@ -58,6 +58,96 @@ def update_card_scores_and_status(db_conn: PooledMySQLConnection, card_id: int):
         db_conn.commit()
 
 
+def _parse_json_value(value: Any) -> Optional[Dict[str, Any]]:
+    if value is None:
+        return None
+    if isinstance(value, str):
+        try:
+            return json.loads(value)
+        except json.JSONDecodeError:
+            return None
+    return value if isinstance(value, dict) else None
+
+
+def _has_center_box_shapes(center_result: Any) -> bool:
+    if not isinstance(center_result, dict):
+        return False
+    box_result = center_result.get("box_result", {})
+    if not isinstance(box_result, dict):
+        return False
+    inner_shapes = box_result.get("inner_box", {}).get("shapes", [])
+    outer_shapes = box_result.get("outer_box", {}).get("shapes", [])
+    return bool(inner_shapes or outer_shapes)
+
+
+def _extract_center_result(json_data: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    if not json_data:
+        return None
+    center_result = json_data.get("result", {}).get("center_result")
+    if _has_center_box_shapes(center_result):
+        return copy.deepcopy(center_result)
+    return None
+
+
+def _apply_center_result(json_data: Optional[Dict[str, Any]], center_result: Dict[str, Any]) -> Dict[str, Any]:
+    payload = copy.deepcopy(json_data) if isinstance(json_data, dict) else copy.deepcopy(EMPTY_DETECTION_JSON)
+    payload.setdefault("result", {})["center_result"] = copy.deepcopy(center_result)
+    return payload
+
+
+def _share_side_center_results(images: List[CardImageResponse]) -> None:
+    """同面各图共享 center_result(优先 fusion/ring,其次 stripe)。"""
+    side_priority_types = {
+        "front": [
+            ImageType.front_fusion.value,
+            ImageType.front_ring.value,
+            ImageType.front_coaxial.value,
+            ImageType.front_stripe1.value,
+            ImageType.front_stripe2.value,
+            ImageType.front_stripe3.value,
+            ImageType.front_stripe4.value,
+        ],
+        "back": [
+            ImageType.back_fusion.value,
+            ImageType.back_ring.value,
+            ImageType.back_coaxial.value,
+            ImageType.back_stripe1.value,
+            ImageType.back_stripe2.value,
+            ImageType.back_stripe3.value,
+            ImageType.back_stripe4.value,
+        ],
+    }
+
+    for side, priority_types in side_priority_types.items():
+        side_center = None
+        for image_type in priority_types:
+            for img in images:
+                if img.image_type != image_type:
+                    continue
+                for json_field in ("modified_json", "detection_json"):
+                    source = _extract_center_result(getattr(img, json_field, None))
+                    if source:
+                        side_center = source
+                        break
+                if side_center:
+                    break
+            if side_center:
+                break
+
+        if not side_center:
+            continue
+
+        prefix = f"{side}_"
+        for img in images:
+            if not img.image_type.startswith(prefix):
+                continue
+            if img.image_type.endswith("_gray"):
+                continue
+            img.detection_json = _apply_center_result(img.detection_json, side_center)
+            if img.modified_json:
+                img.modified_json = _apply_center_result(img.modified_json, side_center)
+
+
 def _construct_gray_image_json(gray_type: str, ring_image_data: Optional[Dict[str, Any]]) -> Dict[str, Any]:
     """
     内部辅助:构建辅助图(灰度图/融合图)的 modified_json
@@ -100,6 +190,10 @@ def _construct_gray_image_json(gray_type: str, ring_image_data: Optional[Dict[st
     gray_modified_json = copy.deepcopy(EMPTY_DETECTION_JSON)
     gray_modified_json["result"]["defect_result"]["defects"] = filtered_defects
 
+    center_result = source_json.get("result", {}).get("center_result")
+    if _has_center_box_shapes(center_result):
+        gray_modified_json["result"]["center_result"] = copy.deepcopy(center_result)
+
     # 还可以把 Ring 图的宽高带过来,防止前端报错
     gray_modified_json["result"]["imageHeight"] = source_json.get("result", {}).get("imageHeight", 0)
     gray_modified_json["result"]["imageWidth"] = source_json.get("result", {}).get("imageWidth", 0)
@@ -161,6 +255,10 @@ def get_card_with_details(db_conn: PooledMySQLConnection, card_id: int) -> Optio
                 virtual_detection_json = copy.deepcopy(EMPTY_DETECTION_JSON)
                 ring_data = main_images_map.get(target_ring_type)
                 virtual_modified_json = _construct_gray_image_json(g_type, ring_data)
+                ring_detection = _parse_json_value(ring_data.get("detection_json") if ring_data else None)
+                center_result = _extract_center_result(ring_detection)
+                if center_result:
+                    virtual_detection_json = _apply_center_result(virtual_detection_json, center_result)
             elif fusion_type and fusion_type in main_images_map:
                 # ring / stripe 等在辅助表时,JSON 与融合图共用
                 fusion_row = main_images_map[fusion_type]
@@ -190,6 +288,8 @@ def get_card_with_details(db_conn: PooledMySQLConnection, card_id: int) -> Optio
             }
             final_images_list.append(CardImageResponse.model_validate(gray_image_dict))
 
+        _share_side_center_results(final_images_list)
+
         # 5. 获取分数详情:fusion/ring/coaxial 参与算分(每面仅用一份 JSON,不重复扣分)
         main_images_objs = [
             img for img in final_images_list