MyBatchOnnxYolo.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import cv2
  2. import numpy as np
  3. from ultralytics import YOLO
  4. from typing import List, Union, Optional
  5. import PIL.Image
  6. ImageType = Union[str, np.ndarray, PIL.Image.Image]
  7. class MyBatchOnnxYolo:
  8. """
  9. 使用 YOLO 模型进行批处理目标检测/分割,并提供提取最大目标区域的功能。
  10. cls_id {card:0} - 根据你的模型调整
  11. """
  12. def __init__(self, model_path: str, task: str = 'segment', verbose: bool = False):
  13. # 加载yolo model
  14. self.model = YOLO(model_path, task=task, verbose=verbose)
  15. self.results: Optional[List] = None # 将存储批处理的结果列表
  16. self.batch_size: int = 0
  17. def predict_batch(self, image_list: List[ImageType], imgsz: int = 640, **kwargs):
  18. """
  19. 对一批图像进行预测。
  20. Args:
  21. image_list (List[ImageType]): 包含图像路径、PIL Image 或 NumPy 数组的列表。
  22. imgsz (int): 推理的图像尺寸。
  23. **kwargs: 其他传递给 model.predict 的参数 (例如 conf, iou)。
  24. """
  25. if not image_list:
  26. print("Warning: Input image list is empty.")
  27. self.results = []
  28. self.batch_size = 0
  29. return
  30. # 使用 YOLO 的批处理能力
  31. self.results = self.model.predict(image_list, verbose=False, imgsz=imgsz, **kwargs)
  32. self.batch_size = len(self.results)
  33. def get_batch_size(self) -> int:
  34. return self.batch_size
  35. def _get_result_at_index(self, index: int):
  36. """内部辅助方法,获取指定索引的结果,并进行边界检查。"""
  37. if self.results is None:
  38. raise ValueError("Must call predict_batch() before accessing results.")
  39. if not (0 <= index < self.batch_size):
  40. raise IndexError(f"Index {index} is out of bounds for batch size {self.batch_size}.")
  41. return self.results[index]
  42. def check(self, index: int, cls_id: int) -> bool:
  43. """
  44. 检查指定索引的图像结果中是否存在特定的类别ID。
  45. Args:
  46. index (int): 图像在批处理中的索引 (从0开始)。
  47. cls_id (int): 要检查的类别ID。
  48. Returns:
  49. bool: 如果存在该类别ID,则返回 True,否则返回 False。
  50. """
  51. result = self._get_result_at_index(index)
  52. if result.boxes is None or len(result.boxes) == 0:
  53. return False
  54. # .cls 可能为空 Tensor,需要检查
  55. return result.boxes.cls is not None and cls_id in result.boxes.cls.cpu().tolist()
  56. def get_max_img(self, index: int, cls_id: int = 0) -> Optional[np.ndarray]:
  57. """
  58. 从指定索引的图像结果中,提取指定类别ID的最大边界框对应的图像区域。
  59. Args:
  60. index (int): 图像在批处理中的索引 (从0开始)。
  61. cls_id (int): 要提取的目标类别ID 默认0
  62. Returns:
  63. Optional[np.ndarray]: 裁剪出的最大目标的图像区域 (RGB NumPy 数组),
  64. 如果未找到该类别或无检测结果,则返回原始图像。
  65. """
  66. result = self._get_result_at_index(index)
  67. orig_img = result.orig_img # 通常是 BGR NumPy 数组
  68. boxes = result.boxes
  69. # 检查是否有检测框以及是否有对应的类别
  70. if boxes is None or len(boxes) == 0 or boxes.cls is None or cls_id not in boxes.cls.cpu():
  71. print(
  72. f"Warning: No detections or cls_id {cls_id} not found for image at index {index}. Returning original image.")
  73. # 返回原始图像的 RGB 版本
  74. return cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB) if orig_img is not None else None
  75. max_area = 0.0
  76. max_box = None
  77. xyxy_boxes = boxes.xyxy.cpu().numpy()
  78. cls_list = boxes.cls.cpu().numpy()
  79. # 选出最大的目标框
  80. for i, box in enumerate(xyxy_boxes):
  81. if cls_list[i] != cls_id:
  82. continue
  83. temp_x1, temp_y1, temp_x2, temp_y2 = box
  84. area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
  85. if area > max_area:
  86. max_area = area
  87. max_box = box
  88. # 如果没有找到对应 cls_id 的框 (理论上前面已检查,但多一层保险)
  89. if max_box is None:
  90. print(
  91. f"Warning: cls_id {cls_id} found in cls_list but failed to find max box for image at index {index}. Returning original image.")
  92. return cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB) if orig_img is not None else None
  93. x1, y1, x2, y2 = map(int, max_box) # 转换为整数坐标
  94. # 边界处理,防止裁剪坐标超出图像范围
  95. h, w = orig_img.shape[:2]
  96. x1 = max(0, x1)
  97. y1 = max(0, y1)
  98. x2 = min(w, x2)
  99. y2 = min(h, y2)
  100. # 检查裁剪区域是否有效
  101. if x1 >= x2 or y1 >= y2:
  102. print(
  103. f"Warning: Invalid crop dimensions [{y1}:{y2}, {x1}:{x2}] for image at index {index}. Returning original image.")
  104. return cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB) if orig_img is not None else None
  105. # 裁剪图像 (orig_img 通常是 BGR)
  106. max_img_crop = orig_img[y1:y2, x1:x2]
  107. # 将裁剪结果转换为 RGB (与 matplotlib 和 PIL 更兼容)
  108. max_img_rgb = cv2.cvtColor(max_img_crop, cv2.COLOR_BGR2RGB)
  109. return max_img_rgb
  110. def get_max_img_list(self, cls_id: int = 0) -> List[Optional[np.ndarray]]:
  111. """
  112. 对批处理中的每张图片,提取指定类别ID的最大边界框对应的图像区域。
  113. Args:
  114. cls_id (int): 要提取的目标类别ID 默认0
  115. Returns:
  116. List[Optional[np.ndarray]]: 包含处理后图像 (RGB NumPy 数组) 的列表。
  117. 对于成功裁剪的图片,列表元素是裁剪后的图像。
  118. 如果某张图片未找到指定类别或裁剪失败,列表元素是该图片的原始图像(RGB)。
  119. 如果原始图像无效,则列表元素为 None。
  120. """
  121. if self.results is None:
  122. raise ValueError("Must call predict_batch() before calling get_max_img_list().")
  123. processed_images: List[Optional[np.ndarray]] = []
  124. for i in range(self.batch_size):
  125. # 调用 get_max_img 获取单张图片的处理结果
  126. processed_img = self.get_max_img(index=i, cls_id=cls_id)
  127. processed_images.append(processed_img)
  128. return processed_images
  129. def get_results(self) -> Optional[List]:
  130. """获取完整的批处理结果列表。"""
  131. return self.results