Parcourir la source

更新了数据预处理和数据列表刷新

AnlaAnla il y a 1 an
Parent
commit
661bb711af
5 fichiers modifiés avec 90 ajouts et 15 suppressions
  1. 1 1
      .idea/train_model_server.iml
  2. 0 1
      Config.py
  3. 31 8
      main.py
  4. 4 5
      requirements.txt
  5. 54 0
      utils/preprocess_data.py

+ 1 - 1
.idea/train_model_server.iml

@@ -6,7 +6,7 @@
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="PackageRequirementsSettings">
-    <option name="versionSpecifier" value="Greater or equal (&gt;=x.y.z)" />
+    <option name="versionSpecifier" value="大于或等于 (&gt;=x.y.z)" />
     <option name="removeUnused" value="true" />
   </component>
 </module>

+ 0 - 1
Config.py

@@ -1,5 +1,4 @@
 class Config:
-    data = 1
     def __init__(self):
         # 用于存储已创建的目录名称
         self.data_dir = './data'

+ 31 - 8
main.py

@@ -9,6 +9,7 @@ from utils.Train_resnet import load_param
 from utils.MyOnnxYolo import MyOnnxYolo
 from utils.MyModel import MyModel
 from utils.preview_img import preview_img_dir
+from utils.preprocess_data import preprocess_data
 from Config import Config
 
 # 用于存储已创建的目录名称
@@ -24,6 +25,17 @@ global_yolo_model = None
 global_cls_model = None
 
 
+def refresh_list():
+    global data_list, model_id_list
+    data_list = os.listdir(data_dir)
+    model_id_list = sorted(os.listdir(model_dir))
+
+    data_dropdown = gr.Dropdown(label="选择数据集", choices=data_list,
+                                interactive=True)
+    model_dropdown = gr.Dropdown(label="模型ID", choices=model_id_list)
+    return data_dropdown, data_dropdown, data_dropdown, model_dropdown, model_dropdown
+
+
 def process_zip(name, zip_file, progress=gr.Progress()):
     global data_list
 
@@ -51,7 +63,7 @@ def process_zip(name, zip_file, progress=gr.Progress()):
 
         # 返回解压后的文件列表和成功消息
         files = os.listdir(dir_path)
-        return (f"Extraction successful!\n\nFiles extracted to directory '{name}':\n" + "\n".join(files),
+        return (f"解压结束!\n\n数据集保存名称为: '{name}':\n" + "\n".join(files),
                 gr.update(choices=data_list, value=name))
 
     except zipfile.BadZipFile:
@@ -66,7 +78,7 @@ def validate_name(name):
     if name.strip():
         return gr.update(visible=False), gr.update(interactive=True)
     else:
-        return gr.update(visible=True, value="Directory name is required"), gr.update(interactive=False)
+        return gr.update(visible=True, value="输入数据集名称"), gr.update(interactive=False)
 
 
 # def update_epoch_display(epoch):
@@ -121,6 +133,8 @@ def predict_img(predict_data_dir_name, predict_model_dir_name, img):
 # 创建 Gradio 界面
 with gr.Blocks() as iface:
     gr.Markdown("# 分类训练")
+    with gr.Row():
+        refresh_button = gr.Button("刷新数据列表")
 
     with gr.Tabs():
         with gr.TabItem("上传ZIP数据"):
@@ -144,7 +158,7 @@ with gr.Blocks() as iface:
                 file_input = gr.File(label="解压ZIP数据")
 
             error_output = gr.Markdown(visible=False)
-            submit_btn = gr.Button("解压 ZIP", interactive=False)
+            submit_btn = gr.Button("解压 ZIP", variant='primary')
 
             result_output = gr.Textbox(label="结果")
 
@@ -154,13 +168,16 @@ with gr.Blocks() as iface:
             with gr.Row():
                 with gr.Column():
                     preview_data_dir_name = gr.Dropdown(label="选择数据集", choices=data_list,
-                                                interactive=True)
+                                                        interactive=True)
                     preview_button = gr.Button("预览部分图片", variant='primary')
                 preview_gallery = gr.Gallery(label="部分训练集图片")
 
+            with gr.Row():
+                preprocess_button = gr.Button("处理数据集", variant='primary')
+            with gr.Row():
+                preprocess_output = gr.Textbox(label="处理结果")
+
         with gr.TabItem("训练模型"):
-            data_list = os.listdir(data_dir)
-            model_id_list = sorted(os.listdir(model_dir))
             gr.Markdown("设置参数训练模型, 注意模型ID不选的话则重新训练")
 
             with gr.Row():
@@ -179,7 +196,7 @@ with gr.Blocks() as iface:
                 decay = gr.Number(label="衰减率", value=0.7)
 
             # epoch_slider.change(update_epoch_display, inputs=[epoch_slider], outputs=[epoch_display])
-            train_button = gr.Button("开始训练")
+            train_button = gr.Button("开始训练", variant='primary')
             training_result = gr.Textbox(label="训练结果")
 
         with gr.TabItem("预测单张图片"):
@@ -195,12 +212,18 @@ with gr.Blocks() as iface:
                                                     interactive=True)
                 predict_model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
 
-            predict_button = gr.Button("预测图片")
+            predict_button = gr.Button("预测图片", variant='primary')
             predict_result = gr.Textbox(label="预测结果")
 
+    refresh_button.click(refresh_list,
+                         outputs=[preview_data_dir_name, data_dir_name, predict_data_dir_name,
+                                  model_dir_name, predict_model_dir_name])
+
     name_input.change(validate_name, inputs=[name_input], outputs=[error_output, submit_btn])
     submit_btn.click(process_zip, inputs=[name_input, file_input], outputs=[result_output, data_dir_name])
     preview_button.click(preview_img_dir, inputs=preview_data_dir_name, outputs=preview_gallery)
+    preprocess_button.click(preprocess_data, inputs=preview_data_dir_name, outputs=preprocess_output)
+
     train_button.click(start_training,
                        inputs=[epoch_slider, data_dir_name,
                                model_dir_name, freeze_slider, learn_rate, momentum, decay],

+ 4 - 5
requirements.txt

@@ -2,8 +2,7 @@ torch>=2.3.0
 numpy>=1.22.3
 torchvision>=0.18.0
 gradio>=4.32.2
-psutil>=5.9.8
-uvicorn>=0.28.0
-fastapi>=0.110.0
-starlette>=0.36.3
-resnest --pre
+pillow>=10.2.0
+ultralytics>=8.2.73
+opencv-python>=4.9.0.80
+tqdm>=4.66.5

+ 54 - 0
utils/preprocess_data.py

@@ -0,0 +1,54 @@
+import os
+import glob
+from PIL import Image
+import cv2
+from tqdm import tqdm
+import shutil
+from utils.MyOnnxYolo import MyOnnxYolo
+from Config import Config
+
+# Module-level configuration instance; preprocess_data reads
+# config.data_dir and config.yolo_model_path from it.
+config = Config()
+
+
+def preprocess_data(data_dir_name: str):
+    train_data_dir = os.path.join(config.data_dir, data_dir_name, 'train')
+    val_data_dir = os.path.join(config.data_dir, data_dir_name, 'val')
+    if not os.path.exists(train_data_dir):
+        return "缺失train文件目录"
+    if not os.path.exists(val_data_dir):
+        os.makedirs(val_data_dir)
+
+    yolo_model = MyOnnxYolo(config.yolo_model_path)
+    img_paths = glob.glob(train_data_dir + '/*/*')
+
+    for img_path in tqdm(img_paths):
+        # 排除224*224
+        img = Image.open(img_path).convert('RGB')
+        if img.size[0] == img.size[1] == 224:
+            continue
+
+        yolo_model.set_result(img)
+        yolo_img = yolo_model.get_max_img(cls_id=0)
+        img_yolo_224 = cv2.resize(yolo_img, (224, 224))
+
+        cv2.imwrite(img_path, img_yolo_224)
+    # 处理结束
+
+    # 截取并压缩图片后, 生成val
+    for img_cls_dir_name in os.listdir(train_data_dir):
+        train_cls_dir_path = os.path.join(train_data_dir, img_cls_dir_name)
+
+        img_num = len(os.listdir(train_cls_dir_path))
+        select_num = img_num // 5 if img_num > 11 else img_num
+
+        val_cls_dir_path = os.path.join(val_data_dir, img_cls_dir_name)
+        if not os.path.exists(val_cls_dir_path):
+            os.makedirs(val_cls_dir_path)
+
+        # 选出部分图片用于做测试集
+        for img_name in os.listdir(train_cls_dir_path)[:select_num]:
+            img_path = os.path.join(train_cls_dir_path, img_name)
+            save_path = os.path.join(val_cls_dir_path, img_name)
+            shutil.copy(img_path, save_path)
+
+    return "处理结束"