preprocess_data.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import os
  2. import glob
  3. from PIL import Image
  4. import cv2
  5. from tqdm import tqdm
  6. import shutil
  7. from utils.MyOnnxYolo import MyOnnxYolo
  8. from Config import Config
# Module-level configuration instance; preprocess_data reads
# config.data_dir and config.yolo_model_path from it.
config: Config = Config()
  10. def preprocess_data(data_dir_name: str):
  11. train_data_dir = os.path.join(config.data_dir, data_dir_name, 'train')
  12. val_data_dir = os.path.join(config.data_dir, data_dir_name, 'val')
  13. if not os.path.exists(train_data_dir):
  14. return "缺失train文件目录"
  15. if not os.path.exists(val_data_dir):
  16. os.makedirs(val_data_dir)
  17. yolo_model = MyOnnxYolo(config.yolo_model_path)
  18. img_paths = glob.glob(train_data_dir + '/*/*')
  19. for img_path in tqdm(img_paths):
  20. # 排除224*224
  21. img = Image.open(img_path).convert('RGB')
  22. if img.size[0] == img.size[1] == 224:
  23. continue
  24. yolo_model.set_result(img)
  25. yolo_img = yolo_model.get_max_img(cls_id=0)
  26. img_yolo_224 = cv2.resize(yolo_img, (224, 224))
  27. cv2.imwrite(img_path, img_yolo_224)
  28. # 处理结束
  29. # 截取并压缩图片后, 生成val
  30. for img_cls_dir_name in os.listdir(train_data_dir):
  31. train_cls_dir_path = os.path.join(train_data_dir, img_cls_dir_name)
  32. img_num = len(os.listdir(train_cls_dir_path))
  33. select_num = img_num // 5 if img_num > 11 else img_num
  34. val_cls_dir_path = os.path.join(val_data_dir, img_cls_dir_name)
  35. if not os.path.exists(val_cls_dir_path):
  36. os.makedirs(val_cls_dir_path)
  37. # 选出部分图片用于做测试集
  38. for img_name in os.listdir(train_cls_dir_path)[:select_num]:
  39. img_path = os.path.join(train_cls_dir_path, img_name)
  40. save_path = os.path.join(val_cls_dir_path, img_name)
  41. shutil.copy(img_path, save_path)
  42. return "处理结束"