| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- import os
- import glob
- from PIL import Image
- import cv2
- from tqdm import tqdm
- import shutil
- from utils.MyOnnxYolo import MyOnnxYolo
- from Config import Config
# Module-level configuration; this module reads config.data_dir and
# config.yolo_model_path from it.
config = Config()

# Side length (pixels) of the square images produced by preprocessing. An image
# that is already this size is treated as processed and skipped, so re-running
# preprocess_data on the same directory does not crop twice.
_TARGET_SIZE = 224


def _crop_train_images(train_data_dir: str) -> None:
    """Replace each training image with a 224x224 resize of its largest class-0 YOLO crop.

    Only files exactly one level below a class folder (``train/<class>/<file>``)
    are considered. Files that are already ``_TARGET_SIZE`` square are skipped.
    Files are overwritten in place.

    Raises:
        OSError: if OpenCV fails to write a processed image.
    """
    yolo_model = MyOnnxYolo(config.yolo_model_path)
    # Escape the directory part: it is built from a caller-supplied name and
    # may contain glob metacharacters such as '[' or '*'.
    pattern = os.path.join(glob.escape(train_data_dir), '*', '*')
    # Sub-directories inside a class folder would make Image.open raise.
    img_paths = [path for path in glob.glob(pattern) if os.path.isfile(path)]
    for img_path in tqdm(img_paths):
        # Image.open is lazy: .size is known without decoding pixels, and the
        # context manager guarantees the file handle is released.
        with Image.open(img_path) as raw_img:
            if raw_img.size == (_TARGET_SIZE, _TARGET_SIZE):
                continue  # already processed on an earlier run
            img = raw_img.convert('RGB')
        yolo_model.set_result(img)
        yolo_img = yolo_model.get_max_img(cls_id=0)
        # NOTE(review): assumes get_max_img returns a numpy array (it is passed
        # to cv2.resize) and may return None/empty when nothing is detected —
        # confirm in MyOnnxYolo. Such images are left untouched rather than
        # crashing cv2.resize.
        if yolo_img is None or yolo_img.size == 0:
            continue
        img_yolo_224 = cv2.resize(yolo_img, (_TARGET_SIZE, _TARGET_SIZE))
        # NOTE(review): `img` was converted to RGB above, while cv2.imwrite
        # treats 3-channel arrays as BGR. If get_max_img returns an unmodified
        # RGB crop, saved files have red/blue swapped; if so, convert with
        # cv2.cvtColor(img_yolo_224, cv2.COLOR_RGB2BGR) before writing.
        # cv2.imwrite reports failure (e.g. non-ASCII paths on Windows) only
        # through its return value; surface it instead of silently keeping the
        # unprocessed image.
        if not cv2.imwrite(img_path, img_yolo_224):
            raise OSError(f"cv2.imwrite failed for {img_path}")


def _build_val_split(train_data_dir: str, val_data_dir: str) -> None:
    """Copy a subset of each training class folder into the matching val folder.

    Classes with more than 11 images contribute one fifth of them. Smaller
    classes contribute all of their images. Files are copied, not moved.
    """
    for img_cls_dir_name in sorted(os.listdir(train_data_dir)):
        train_cls_dir_path = os.path.join(train_data_dir, img_cls_dir_name)
        if not os.path.isdir(train_cls_dir_path):
            continue  # stray files at class level (e.g. .DS_Store) are not classes
        # List once, keep regular files only (shutil.copy cannot copy a
        # directory), and sort so the selection is reproducible across runs
        # and filesystems.
        img_names = sorted(
            name for name in os.listdir(train_cls_dir_path)
            if os.path.isfile(os.path.join(train_cls_dir_path, name))
        )
        img_num = len(img_names)
        select_num = img_num // 5 if img_num > 11 else img_num
        val_cls_dir_path = os.path.join(val_data_dir, img_cls_dir_name)
        os.makedirs(val_cls_dir_path, exist_ok=True)
        # Select part of the images to serve as the validation set.
        # NOTE(review): files are copied, not moved, so every val image also
        # stays in train (train/val overlap). Confirm this is intended.
        for img_name in img_names[:select_num]:
            shutil.copy(
                os.path.join(train_cls_dir_path, img_name),
                os.path.join(val_cls_dir_path, img_name),
            )


def preprocess_data(data_dir_name: str) -> str:
    """Crop/resize the training images of a dataset and derive its val split.

    Expects ``<config.data_dir>/<data_dir_name>/train/<class>/<image>``.
    First, every training image is replaced in place by a 224x224 resize of
    its largest class-0 YOLO detection. Then part of each class is copied into
    ``<config.data_dir>/<data_dir_name>/val/<class>/``.

    Args:
        data_dir_name: name of the dataset folder under ``config.data_dir``.

    Returns:
        "缺失train文件目录" (train directory missing) if there is no train
        directory, otherwise "处理结束" (processing finished).

    Raises:
        OSError: if a processed image cannot be written back to disk.
    """
    train_data_dir = os.path.join(config.data_dir, data_dir_name, 'train')
    val_data_dir = os.path.join(config.data_dir, data_dir_name, 'val')
    if not os.path.isdir(train_data_dir):
        return "缺失train文件目录"
    os.makedirs(val_data_dir, exist_ok=True)
    _crop_train_images(train_data_dir)
    # After cropping and resizing, generate the val split.
    _build_val_split(train_data_dir, val_data_dir)
    return "处理结束"
|