MyOnnxYolo.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from ultralytics import YOLO
  2. '''用法
  3. model = MyOnnxYolo(r"yolo_handcard01.onnx")
  4. img = Image.open(img_path).convert('RGB')
  5. model.set_result(img)
  6. yolo_img = model.get_max_img(cls_id=0)
  7. '''
  8. class MyOnnxYolo:
  9. # cls_id {card:0, person:1, hand:2}
  10. def __init__(self, model_path):
  11. # 加载yolo model
  12. self.model = YOLO(model_path, verbose=False, task='detect')
  13. self.results = None
  14. self.cls = None
  15. self.boxes = None
  16. self.img = None
  17. def set_result(self, img, imgsz=640):
  18. self.results = self.model.predict(img, max_det=3, verbose=False, imgsz=imgsz)
  19. self.img = self.results[0].orig_img
  20. self.boxes = self.results[0].boxes.xyxy.cpu()
  21. self.cls = self.results[0].boxes.cls
  22. def check(self, cls_id: int):
  23. if cls_id in self.cls:
  24. return True
  25. return False
  26. def get_max_img(self, cls_id: int):
  27. # cls_id {card:0, person:1, hand:2}
  28. # 排除没有检测到物体 或 截取的id不存在的图片
  29. if len(self.boxes) == 0 or cls_id not in self.cls:
  30. return self.img
  31. max_area = 0
  32. # 选出最大的卡片框
  33. x1, y1, x2, y2 = 0, 0, 0, 0
  34. for i, box in enumerate(self.boxes):
  35. if self.cls[i] != cls_id:
  36. continue
  37. temp_x1, temp_y1, temp_x2, temp_y2 = box
  38. area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
  39. if area > max_area:
  40. max_area = area
  41. x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
  42. x1 = int(x1)
  43. x2 = int(x2)
  44. y1 = int(y1)
  45. y2 = int(y2)
  46. max_img = self.img[y1:y2, x1:x2, :]
  47. return max_img