import glob
import gradio as gr
import os
import zipfile
import tempfile
import time

from utils.Train_resnet import load_param
from utils.MyOnnxYolo import MyOnnxYolo
from utils.MyModel import MyModel
from utils.preview_img import preview_img_dir
from utils.preprocess_data import preprocess_data
from Config import Config

# Project configuration: locations of the uploaded datasets and trained models.
config = Config()
data_dir = config.data_dir
model_dir = config.model_dir
os.makedirs(data_dir, exist_ok=True)
# model_dir is listed right below; create it too so a fresh setup with no
# trained models yet does not crash with FileNotFoundError at startup.
os.makedirs(model_dir, exist_ok=True)

# Dataset names (sub-directories of data_dir) and model IDs (sub-directories
# of model_dir) offered in the dropdowns.
data_list = os.listdir(data_dir)
model_id_list = sorted(os.listdir(model_dir))

# Inference models, created lazily on the first prediction and dropped
# before training starts.
global_yolo_model = None
global_cls_model = None
# Path of the .pt file currently loaded into global_cls_model, so that
# choosing a different model ID reloads instead of reusing a stale model.
global_cls_model_path = None


def _find_model_file(model_id):
    """Return the first ``*.pt`` file in ``model_dir/model_id``, or None.

    Matches are sorted so the pick is deterministic when a model directory
    holds several checkpoints (``glob`` returns them in arbitrary order).
    """
    matches = sorted(glob.glob(os.path.join(model_dir, model_id, "*.pt")))
    return matches[0] if matches else None


def refresh_list():
    """Re-scan data_dir / model_dir and refresh the dropdown choices.

    Returns five updates in the order wired to ``refresh_button.click``:
    the three dataset dropdowns, then the two model-ID dropdowns. Only
    ``choices`` (and interactivity for datasets) is updated, so every
    dropdown keeps its own label.
    """
    global data_list, model_id_list
    data_list = os.listdir(data_dir)
    model_id_list = sorted(os.listdir(model_dir))
    # One update dict per output; do not share a single dict between outputs.
    return (
        gr.update(choices=data_list, interactive=True),
        gr.update(choices=data_list, interactive=True),
        gr.update(choices=data_list, interactive=True),
        gr.update(choices=model_id_list),
        gr.update(choices=model_id_list),
    )


def process_zip(name, zip_file, progress=gr.Progress()):
    """Extract an uploaded ZIP archive into ``data_dir/<name>``.

    Args:
        name: Dataset name; used (stripped) as the directory name under
            data_dir. Must be a single path component.
        zip_file: Value of the ``gr.File`` component. Either an object with a
            ``.name`` path attribute or the path string itself (which one
            depends on the installed Gradio version; both are accepted).
        progress: Gradio progress tracker, advanced once per extracted member.

    Returns:
        ``(message, dropdown_update)``. On success the update adds the dataset
        to the choices and selects it; on any error it is a no-op update.
    """
    global data_list
    name = (name or "").strip()
    if not name:
        return "Error: 需要输入数据集名称.", gr.update()
    # The name becomes a path component on a server bound to 0.0.0.0:
    # refuse anything that could escape data_dir (e.g. "../x" or "a/b").
    if name in (".", "..") or os.path.basename(name) != name:
        return "Error: 数据集名称不能包含路径分隔符或 '..'.", gr.update()
    if zip_file is None:
        return "Error: 需要上传ZIP文件.", gr.update()
    zip_path = getattr(zip_file, "name", zip_file)

    try:
        # Open (and thereby validate) the archive before creating the dataset
        # directory, so a bad upload does not leave an empty dataset behind.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            dir_path = os.path.join(data_dir, name)
            os.makedirs(dir_path, exist_ok=True)
            file_list = zip_ref.namelist()
            total_files = len(file_list)
            # Extract member by member so the progress bar can be updated.
            # ZipFile.extract strips absolute paths and ".." components.
            for i, file in enumerate(file_list, start=1):
                zip_ref.extract(file, dir_path)
                progress(i / total_files, f"Extracting file {i} of {total_files}")

        if name not in data_list:
            data_list.append(name)

        # Report the top-level entries of the extracted dataset.
        files = os.listdir(dir_path)
        return (f"解压结束!\n\n数据集保存名称为: '{name}':\n" + "\n".join(files),
                gr.update(choices=data_list, value=name))
    except zipfile.BadZipFile:
        return "Error: The uploaded file is not a valid ZIP file.", gr.update()
    except PermissionError:
        return "Error: Permission denied. Unable to create directory or extract files.", gr.update()
    except Exception as e:
        return f"Error: An unexpected error occurred: {str(e)}", gr.update()


def validate_name(name):
    """Show a hint and disable the extract button while the name is blank.

    Returns ``(error_markdown_update, submit_button_update)``.
    """
    if name and name.strip():
        return gr.update(visible=False), gr.update(interactive=True)
    else:
        return gr.update(visible=True, value="输入数据集名称"), gr.update(interactive=False)


def start_training(epoch, data_dir_name, model_dir_name, freeze_slider, learn_rate, momentum, decay):
    """Train a classifier on the selected dataset via ``load_param``.

    Args:
        epoch: Number of training epochs.
        data_dir_name: Dataset name under data_dir (required).
        model_dir_name: Model ID under model_dir whose ``*.pt`` checkpoint is
            passed to ``load_param``; None/empty means train from scratch.
        freeze_slider, learn_rate, momentum, decay: Forwarded unchanged to
            ``load_param`` as freeze / learn_rate / momentum / decay.

    Returns:
        A single status string for the ``training_result`` textbox.
    """
    global global_yolo_model, global_cls_model, global_cls_model_path
    if not data_dir_name:
        return "Error: 需要选择训练的数据集."

    # Drop the cached inference models to release their memory before training.
    global_yolo_model = None
    global_cls_model = None
    global_cls_model_path = None
    print(epoch, data_dir_name, model_dir_name, freeze_slider, learn_rate, momentum, decay)

    train_data_path = os.path.join(data_dir, data_dir_name)
    model_path = None
    if model_dir_name:
        model_path = _find_model_file(model_dir_name)
        if model_path is None:
            return f"Error: 模型 '{model_dir_name}' 中没有找到 .pt 文件."

    load_param(epoch=epoch, data_dir=train_data_path, model_path=model_path,
               freeze=freeze_slider, learn_rate=learn_rate, momentum=momentum, decay=decay)
    # Wired to a single output, so return one value (a tuple would be shown
    # verbatim in the textbox).
    return "训练结束"


def predict_img(predict_data_dir_name, predict_model_dir_name, img):
    """Detect the main object in ``img`` and classify it.

    The YOLO detector crops the largest detection of class 0 (via
    ``get_max_img(cls_id=0)``), and the selected classifier maps that crop to
    an index into the sorted class-folder names of ``<dataset>/train``.

    Returns:
        A single string for the ``predict_result`` textbox: the predicted
        class name, or an error message.
    """
    global global_yolo_model, global_cls_model, global_cls_model_path
    print("predict_img: ", predict_data_dir_name, predict_model_dir_name)
    if not predict_data_dir_name or not predict_model_dir_name:
        return "2个参数为必选项"
    if img is None:
        return "Error: 需要上传图片."

    train_dir = os.path.join(data_dir, predict_data_dir_name, 'train')
    if not os.path.isdir(train_dir):
        return f"Error: 数据集 '{predict_data_dir_name}' 中没有 train 目录."
    # Assumes the classifier's output index follows sorted folder order
    # (as in an ImageFolder-style dataset) - TODO confirm against training code.
    cls_list = sorted(os.listdir(train_dir))
    print(cls_list)

    model_path = _find_model_file(predict_model_dir_name)
    if model_path is None:
        return f"Error: 模型 '{predict_model_dir_name}' 中没有找到 .pt 文件."

    if global_yolo_model is None:
        global_yolo_model = MyOnnxYolo(config.yolo_model_path)
    # Reload the classifier whenever a different model is selected.
    if global_cls_model is None or global_cls_model_path != model_path:
        global_cls_model = MyModel(model_path)
        global_cls_model_path = model_path

    global_yolo_model.set_result(img)
    # NOTE(review): behaviour when nothing is detected depends on MyOnnxYolo - confirm.
    card_img = global_yolo_model.get_max_img(cls_id=0)
    cls_id = global_cls_model.run(card_img)
    return cls_list[cls_id]


# Build the Gradio interface.
with gr.Blocks() as iface:
    gr.Markdown("# 分类训练")
    with gr.Row():
        refresh_button = gr.Button("刷新数据列表")
    with gr.Tabs():
        with gr.TabItem("上传ZIP数据"):
            gr.Markdown(
                '''
输入数据集名称, 并上传zip数据, 点击解压zip

如果是原始数据, 请将目录结构准备为:

```python
---数据集名称:
    --train:
        -分类1:{图片1,图片2, ...},
        -分类2:{图片1,图片2, ...},
        ......
```

之后进行数据处理后训练
'''
            )
            with gr.Row():
                name_input = gr.Textbox(label="输入数据集名称", placeholder="Enter a name for the directory")
                file_input = gr.File(label="解压ZIP数据")
            error_output = gr.Markdown(visible=False)
            submit_btn = gr.Button("解压 ZIP", variant='primary')
            result_output = gr.Textbox(label="结果")

        with gr.TabItem("预览, 处理数据"):
            gr.Markdown("预览部分训练数据, 如果数据未处理, 那么进行处理, ")
            with gr.Row():
                with gr.Column():
                    preview_data_dir_name = gr.Dropdown(label="选择数据集", choices=data_list, interactive=True)
                    preview_button = gr.Button("预览部分图片", variant='primary')
                preview_gallery = gr.Gallery(label="部分训练集图片")
            with gr.Row():
                preprocess_button = gr.Button("处理数据集", variant='primary')
            with gr.Row():
                preprocess_output = gr.Textbox(label="处理结果")

        with gr.TabItem("训练模型"):
            gr.Markdown("设置参数训练模型, 注意模型ID不选的话则重新训练")
            with gr.Row():
                data_dir_name = gr.Dropdown(label="训练的数据集", choices=data_list, interactive=True)
                epoch_slider = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="轮次")
            with gr.Row():
                model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
                freeze_slider = gr.Slider(minimum=0, maximum=7, value=7, step=1, label="冻结参数层")
            with gr.Row():
                learn_rate = gr.Number(label="学习率", value=0.01)
                momentum = gr.Number(label="动量", value=0.9)
                decay = gr.Number(label="衰减率", value=0.7)
            train_button = gr.Button("开始训练", variant='primary')
            training_result = gr.Textbox(label="训练结果")

        with gr.TabItem("预测单张图片"):
            gr.Markdown("输入一张图片返回预测结果")
            with gr.Row():
                img = gr.Image(label="上传图片")
            with gr.Row():
                predict_data_dir_name = gr.Dropdown(label="训练的数据集", choices=data_list, interactive=True)
                predict_model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
            predict_button = gr.Button("预测图片", variant='primary')
            predict_result = gr.Textbox(label="预测结果")

    # Event wiring (must stay inside the Blocks context).
    refresh_button.click(refresh_list,
                         outputs=[preview_data_dir_name, data_dir_name, predict_data_dir_name,
                                  model_dir_name, predict_model_dir_name])
    name_input.change(validate_name, inputs=[name_input], outputs=[error_output, submit_btn])
    submit_btn.click(process_zip, inputs=[name_input, file_input], outputs=[result_output, data_dir_name])
    preview_button.click(preview_img_dir, inputs=preview_data_dir_name, outputs=preview_gallery)
    preprocess_button.click(preprocess_data, inputs=preview_data_dir_name, outputs=preprocess_output)
    train_button.click(start_training,
                       inputs=[epoch_slider, data_dir_name, model_dir_name, freeze_slider,
                               learn_rate, momentum, decay],
                       outputs=[training_result])
    predict_button.click(predict_img,
                         inputs=[predict_data_dir_name, predict_model_dir_name, img],
                         outputs=[predict_result])

# Launch the interface, listening on all network interfaces.
iface.launch(server_name="0.0.0.0")