AnlaAnla 1 год назад
Родитель
Сommit
0f814c5c0f
1 измененных файлов с 213 добавлено и 0 удалено
  1. 213 0
      main.py

+ 213 - 0
main.py

@@ -0,0 +1,213 @@
+import glob
+
+import gradio as gr
+import os
+import zipfile
+import tempfile
+import time
+from utils.Train_resnet import load_param
+from utils.MyOnnxYolo import MyOnnxYolo
+from utils.MyModel import MyModel
+from utils.preview_img import preview_img_dir
+from Config import Config
+
+# 用于存储已创建的目录名称
+config = Config()
+data_dir = config.data_dir
+model_dir = config.model_dir
+
+os.makedirs(data_dir, exist_ok=True)
+data_list = os.listdir(data_dir)
+model_id_list = sorted(os.listdir(model_dir))
+
+global_yolo_model = None
+global_cls_model = None
+
+
+def process_zip(name, zip_file, progress=gr.Progress()):
+    global data_list
+
+    if not name.strip():
+        return "Error: 需要输入数据集名称.", gr.update()
+
+    try:
+        # 创建一个以名称命名的目录
+        dir_path = os.path.join(data_dir, name)
+        os.makedirs(dir_path, exist_ok=True)
+
+        # 获取 ZIP 文件中的文件列表
+        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
+            file_list = zip_ref.namelist()
+            total_files = len(file_list)
+
+            # 解压 ZIP 文件到该目录,同时更新进度条
+            for i, file in enumerate(file_list):
+                zip_ref.extract(file, dir_path)
+                progress((i + 1) / total_files, f"Extracting file {i + 1} of {total_files}")
+
+        # 添加新目录到列表中
+        if name not in data_list:
+            data_list.append(name)
+
+        # 返回解压后的文件列表和成功消息
+        files = os.listdir(dir_path)
+        return (f"Extraction successful!\n\nFiles extracted to directory '{name}':\n" + "\n".join(files),
+                gr.update(choices=data_list, value=name))
+
+    except zipfile.BadZipFile:
+        return "Error: The uploaded file is not a valid ZIP file.", gr.update()
+    except PermissionError:
+        return "Error: Permission denied. Unable to create directory or extract files.", gr.update()
+    except Exception as e:
+        return f"Error: An unexpected error occurred: {str(e)}", gr.update()
+
+
+def validate_name(name):
+    if name.strip():
+        return gr.update(visible=False), gr.update(interactive=True)
+    else:
+        return gr.update(visible=True, value="Directory name is required"), gr.update(interactive=False)
+
+
+# def update_epoch_display(epoch):
+#     return f"Current epoch: {epoch}"
+
+def start_training(epoch, data_dir_name, model_dir_name, freeze_slider, learn_rate, momentum, decay):
+    global global_yolo_model, global_cls_model
+    # 释放内存
+    global_yolo_model = None
+    global_cls_model = None
+
+    print(epoch, data_dir_name, model_dir_name, freeze_slider, learn_rate, momentum, decay)
+
+    train_data_path = os.path.join(data_dir, data_dir_name)
+
+    if model_dir_name is not None:
+        model_dir_path = os.path.join(model_dir, model_dir_name)
+        model_path = glob.glob(os.path.join(model_dir_path, "*.pt"))[0]
+    else:
+        model_path = None
+
+    load_param(epoch=epoch, data_dir=train_data_path, model_path=model_path,
+               freeze=freeze_slider, learn_rate=learn_rate, momentum=momentum, decay=decay)
+    # time.sleep(epoch * 0.2)  # 模拟短暂的训练过程
+    return "训练结束", gr.update()
+
+
+def predict_img(predict_data_dir_name, predict_model_dir_name, img):
+    global global_yolo_model, global_cls_model
+    print("predict_img: ", data_dir_name, model_dir_name)
+    if predict_data_dir_name is None or predict_model_dir_name is None:
+        return "2个参数为必选项"
+
+    cls_list = sorted(os.listdir(os.path.join(data_dir, predict_data_dir_name, 'train')))
+    print(cls_list)
+
+    model_dir_path = os.path.join(model_dir, predict_model_dir_name)
+    model_path = glob.glob(os.path.join(model_dir_path, "*.pt"))[0]
+
+    if global_yolo_model is None:
+        global_yolo_model = MyOnnxYolo(config.yolo_model_path)
+    if global_cls_model is None:
+        global_cls_model = MyModel(model_path)
+
+    global_yolo_model.set_result(img)
+    card_img = global_yolo_model.get_max_img(cls_id=0)
+    cls_id = global_cls_model.run(card_img)
+
+    return cls_list[cls_id], gr.update()
+
+
+# 创建 Gradio 界面
+with gr.Blocks() as iface:
+    gr.Markdown("# 分类训练")
+
+    with gr.Tabs():
+        with gr.TabItem("上传ZIP数据"):
+            gr.Markdown(
+                '''
+                输入数数据名称, 并上传zip数据, 点击解压zip <br>
+                如果是原始数据, 请将目录结构准备为: <br>
+                ```python
+                ---数据集名称:
+                   --train:
+                       -分类1:{图片1,图片2, ...},
+                       -分类2:{图片1,图片2, ...},
+                        ......
+                ```
+                之后进行数据处理后训练
+                '''
+            )
+
+            with gr.Row():
+                name_input = gr.Textbox(label="输入数据集名称", placeholder="Enter a name for the directory")
+                file_input = gr.File(label="解压ZIP数据")
+
+            error_output = gr.Markdown(visible=False)
+            submit_btn = gr.Button("解压 ZIP", interactive=False)
+
+            result_output = gr.Textbox(label="结果")
+
+        with gr.TabItem("预览, 处理数据"):
+            gr.Markdown(
+                "预览部分训练数据, 如果数据未处理, 那么进行处理, ")
+            with gr.Row():
+                with gr.Column():
+                    preview_data_dir_name = gr.Dropdown(label="选择数据集", choices=data_list,
+                                                interactive=True)
+                    preview_button = gr.Button("预览部分图片", variant='primary')
+                preview_gallery = gr.Gallery(label="部分训练集图片")
+
+        with gr.TabItem("训练模型"):
+            data_list = os.listdir(data_dir)
+            model_id_list = sorted(os.listdir(model_dir))
+            gr.Markdown("设置参数训练模型, 注意模型ID不选的话则重新训练")
+
+            with gr.Row():
+                data_dir_name = gr.Dropdown(label="训练的数据集", choices=data_list,
+                                            interactive=True)
+                epoch_slider = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="轮次")
+                # epoch_display = gr.Markdown("Current epoch: 10")
+
+            with gr.Row():
+                model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
+                freeze_slider = gr.Slider(minimum=0, maximum=7, value=7, step=1, label="冻结参数层")
+
+            with gr.Row():
+                learn_rate = gr.Number(label="学习率", value=0.01)
+                momentum = gr.Number(label="动量", value=0.9)
+                decay = gr.Number(label="衰减率", value=0.7)
+
+            # epoch_slider.change(update_epoch_display, inputs=[epoch_slider], outputs=[epoch_display])
+            train_button = gr.Button("开始训练")
+            training_result = gr.Textbox(label="训练结果")
+
+        with gr.TabItem("预测单张图片"):
+            data_list = os.listdir(data_dir)
+            model_id_list = sorted(os.listdir(model_dir))
+
+            gr.Markdown("输入一张图片返回预测结果")
+            with gr.Row():
+                img = gr.Image(label="上传图片")
+
+            with gr.Row():
+                predict_data_dir_name = gr.Dropdown(label="训练的数据集", choices=data_list,
+                                                    interactive=True)
+                predict_model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
+
+            predict_button = gr.Button("预测图片")
+            predict_result = gr.Textbox(label="预测结果")
+
+    name_input.change(validate_name, inputs=[name_input], outputs=[error_output, submit_btn])
+    submit_btn.click(process_zip, inputs=[name_input, file_input], outputs=[result_output, data_dir_name])
+    preview_button.click(preview_img_dir, inputs=preview_data_dir_name, outputs=preview_gallery)
+    train_button.click(start_training,
+                       inputs=[epoch_slider, data_dir_name,
+                               model_dir_name, freeze_slider, learn_rate, momentum, decay],
+                       outputs=[training_result])
+    predict_button.click(predict_img,
+                         inputs=[predict_data_dir_name, predict_model_dir_name, img],
+                         outputs=[predict_result])
+
+# 启动界面
+iface.launch(server_name="0.0.0.0")