import json
from typing import Tuple

import cv2
import numpy as np

from ..core.model_loader import get_predictor
from app.utils.CardDefectAggregator import CardDefectAggregator
from app.utils.arean_anylize_draw import DefectProcessor, DrawingParams
from app.core.config import settings
from app.core.logger import logger


class DefectInferenceService:
    """Run tiled defect inference on a card image and compute defect areas."""

    # The only inference type this service currently handles; any other
    # value yields an empty result.
    FACE_NO_REFLECT_DEFECT = "pokemon_front_face_no_reflect_defect"
    # Tiling parameters forwarded to CardDefectAggregator.
    TILE_SIZE = 512
    OVERLAP_RATIO = 0.1  # 10% overlap between neighbouring tiles

    _DECODE_ERROR_MSG = "无法解码图像,请确保上传的是有效的图片格式 (JPG, PNG, etc.)"

    def defect_inference(self, inference_type: str, image_bytes: bytes,
                         is_draw_image: bool = False) -> dict:
        """Run defect detection on an encoded image and compute defect areas.

        The image is decoded, split into overlapping tiles and run through
        the predictor by ``CardDefectAggregator``. The merged detections are
        then analysed by ``DefectProcessor``.

        Args:
            inference_type: Model type. Only
                ``"pokemon_front_face_no_reflect_defect"`` is supported.
            image_bytes: Raw encoded image bytes (e.g. from an API upload).
            is_draw_image: If True, also render the analysis (with minimum
                bounding rectangles) and save it as
                ``settings.TEMP_WORK_DIR / 'temp_area_result.jpg'``.
                Otherwise the analysis result is saved as
                ``settings.TEMP_WORK_DIR / 'area.json'``.

        Returns:
            The analysis result produced by ``DefectProcessor``, or ``{}``
            for an unsupported ``inference_type``. The result's schema is
            defined by ``DefectProcessor`` (not visible here).

        Raises:
            ValueError: If ``image_bytes`` is empty or cannot be decoded as
                an image.
        """
        if inference_type != self.FACE_NO_REFLECT_DEFECT:
            logger.warning(
                f"Unsupported inference_type {inference_type!r}; returning empty result"
            )
            return {}

        # 1. Fetch the predictor instance for this model type.
        predictor = get_predictor(inference_type)

        # 2. Decode the byte stream into an OpenCV BGR image.
        img_bgr = self._decode_image(image_bytes)

        # 3. Tile the image, run the predictor on each tile and merge results.
        aggregator = CardDefectAggregator(
            predictor=predictor,
            tile_size=self.TILE_SIZE,
            overlap_ratio=self.OVERLAP_RATIO,
        )
        json_data = aggregator.process_image(image=img_bgr, mode='face')

        # 4. Compute defect areas from the merged detections.
        processor = DefectProcessor(pixel_resolution=settings.pixel_resolution)
        if is_draw_image:
            return self._analyze_and_save_drawing(processor, img_bgr, json_data)
        return self._analyze_and_save_json(processor, json_data)

    @classmethod
    def _decode_image(cls, image_bytes: bytes) -> np.ndarray:
        """Decode raw image bytes to a BGR image, raising ValueError on failure."""
        # cv2.imdecode raises cv2.error (rather than returning None) on an
        # empty buffer, and np.frombuffer rejects None. Reject both up front
        # so callers always see the documented ValueError.
        if not image_bytes:
            raise ValueError(cls._DECODE_ERROR_MSG)
        np_arr = np.frombuffer(image_bytes, np.uint8)
        img_bgr = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        if img_bgr is None:
            raise ValueError(cls._DECODE_ERROR_MSG)
        return img_bgr

    @staticmethod
    def _analyze_and_save_drawing(processor, img_bgr, json_data):
        """Analyse and draw defects, save the annotated image, return the result."""
        drawing_params = DrawingParams(draw_min_rect=True)
        drawn_image, result = processor.analyze_and_draw(img_bgr, json_data, drawing_params)
        temp_img_path = settings.TEMP_WORK_DIR / 'temp_area_result.jpg'
        # Older OpenCV builds only accept str paths. imwrite reports failure
        # through its return value instead of raising, so check it.
        if not cv2.imwrite(str(temp_img_path), drawn_image):
            logger.warning(f"Failed to write annotated image to {temp_img_path}")
        return result

    @staticmethod
    def _analyze_and_save_json(processor, json_data):
        """Analyse defects, dump the result to area.json, return the result."""
        result = processor.analyze_from_json(json_data)
        area_json_path = settings.TEMP_WORK_DIR / 'area.json'
        with open(area_json_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=4)
        logger.info("面积计算结束")
        return result


# Module-level singleton service instance.
defect_service = DefectInferenceService()