yolo_crop_img.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import torch
  2. from PIL import Image, ImageOps, ImageFile
  3. import numpy as np
  4. import os
  5. import glob
  6. import time
  7. from concurrent.futures import ThreadPoolExecutor
  8. ImageFile.LOAD_TRUNCATED_IMAGES = True
  9. yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master",
  10. 'custom', path=r"D:\Code\ML\item2\towhee_test\yolov5s.pt", source='local')
  11. num = 0
  12. def yolo_detect(img):
  13. results = yolo_model(img)
  14. pred = results.pred[0].cpu().numpy()
  15. # 这是第五列等于0的行,0为card,也就是截出卡的图片
  16. pred = pred[pred[:, 5] == 0][:, :4]
  17. boxes = pred.astype(np.int32)
  18. max_img = get_object(img, boxes)
  19. return max_img
  20. def get_object(img, boxes):
  21. if isinstance(img, str):
  22. img = Image.open(img)
  23. if len(boxes) == 0:
  24. return img
  25. max_area = 0
  26. # 选出最大的人框
  27. x1, y1, x2, y2 = 0, 0, 0, 0
  28. for box in boxes:
  29. temp_x1, temp_y1, temp_x2, temp_y2 = box
  30. area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
  31. if area > max_area:
  32. max_area = area
  33. x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
  34. max_img = img.crop((x1, y1, x2, y2))
  35. return max_img
  36. def yolo_crop_img(img_path):
  37. global num
  38. num += 1
  39. print("{}, {}".format(num, img_path))
  40. try:
  41. max_img = yolo_detect(img_path)
  42. max_img.save(img_path)
  43. except Exception as e:
  44. print("Error processing image {}: {}".format(img_path, e))
  45. if __name__ == '__main__':
  46. source_dir = r"D:\Code\ML\images\Mywork3\card_dataset_yolo"
  47. img_paths = glob.glob(os.path.join(source_dir, "*", "*"))
  48. t1 = time.time()
  49. # with ThreadPoolExecutor(max_workers=6) as executor:
  50. # executor.map(yolo_crop_img, img_paths)
  51. for img_path in img_paths:
  52. yolo_crop_img(img_path)
  53. t2 = time.time()
  54. print('end, time:', (t2 - t1))