|
|
@@ -9,6 +9,7 @@ from utils.Train_resnet import load_param
|
|
|
from utils.MyOnnxYolo import MyOnnxYolo
|
|
|
from utils.MyModel import MyModel
|
|
|
from utils.preview_img import preview_img_dir
|
|
|
+from utils.preprocess_data import preprocess_data
|
|
|
from Config import Config
|
|
|
|
|
|
# 用于存储已创建的目录名称
|
|
|
@@ -24,6 +25,17 @@ global_yolo_model = None
|
|
|
global_cls_model = None
|
|
|
|
|
|
|
|
|
+def refresh_list():
|
|
|
+ global data_list, model_id_list
|
|
|
+ data_list = os.listdir(data_dir)
|
|
|
+ model_id_list = sorted(os.listdir(model_dir))
|
|
|
+
|
|
|
+ data_dropdown = gr.Dropdown(label="选择数据集", choices=data_list,
|
|
|
+ interactive=True)
|
|
|
+ model_dropdown = gr.Dropdown(label="模型ID", choices=model_id_list)
|
|
|
+ return data_dropdown, data_dropdown, data_dropdown, model_dropdown, model_dropdown
|
|
|
+
|
|
|
+
|
|
|
def process_zip(name, zip_file, progress=gr.Progress()):
|
|
|
global data_list
|
|
|
|
|
|
@@ -51,7 +63,7 @@ def process_zip(name, zip_file, progress=gr.Progress()):
|
|
|
|
|
|
# 返回解压后的文件列表和成功消息
|
|
|
files = os.listdir(dir_path)
|
|
|
- return (f"Extraction successful!\n\nFiles extracted to directory '{name}':\n" + "\n".join(files),
|
|
|
+ return (f"解压结束!\n\n数据集保存名称为: '{name}':\n" + "\n".join(files),
|
|
|
gr.update(choices=data_list, value=name))
|
|
|
|
|
|
except zipfile.BadZipFile:
|
|
|
@@ -66,7 +78,7 @@ def validate_name(name):
|
|
|
if name.strip():
|
|
|
return gr.update(visible=False), gr.update(interactive=True)
|
|
|
else:
|
|
|
- return gr.update(visible=True, value="Directory name is required"), gr.update(interactive=False)
|
|
|
+ return gr.update(visible=True, value="输入数据集名称"), gr.update(interactive=False)
|
|
|
|
|
|
|
|
|
# def update_epoch_display(epoch):
|
|
|
@@ -121,6 +133,8 @@ def predict_img(predict_data_dir_name, predict_model_dir_name, img):
|
|
|
# 创建 Gradio 界面
|
|
|
with gr.Blocks() as iface:
|
|
|
gr.Markdown("# 分类训练")
|
|
|
+ with gr.Row():
|
|
|
+ refresh_button = gr.Button("刷新数据列表")
|
|
|
|
|
|
with gr.Tabs():
|
|
|
with gr.TabItem("上传ZIP数据"):
|
|
|
@@ -144,7 +158,7 @@ with gr.Blocks() as iface:
|
|
|
file_input = gr.File(label="解压ZIP数据")
|
|
|
|
|
|
error_output = gr.Markdown(visible=False)
|
|
|
- submit_btn = gr.Button("解压 ZIP", interactive=False)
|
|
|
+ submit_btn = gr.Button("解压 ZIP", variant='primary')
|
|
|
|
|
|
result_output = gr.Textbox(label="结果")
|
|
|
|
|
|
@@ -154,13 +168,16 @@ with gr.Blocks() as iface:
|
|
|
with gr.Row():
|
|
|
with gr.Column():
|
|
|
preview_data_dir_name = gr.Dropdown(label="选择数据集", choices=data_list,
|
|
|
- interactive=True)
|
|
|
+ interactive=True)
|
|
|
preview_button = gr.Button("预览部分图片", variant='primary')
|
|
|
preview_gallery = gr.Gallery(label="部分训练集图片")
|
|
|
|
|
|
+ with gr.Row():
|
|
|
+ preprocess_button = gr.Button("处理数据集", variant='primary')
|
|
|
+ with gr.Row():
|
|
|
+ preprocess_output = gr.Textbox(label="处理结果")
|
|
|
+
|
|
|
with gr.TabItem("训练模型"):
|
|
|
- data_list = os.listdir(data_dir)
|
|
|
- model_id_list = sorted(os.listdir(model_dir))
|
|
|
gr.Markdown("设置参数训练模型, 注意模型ID不选的话则重新训练")
|
|
|
|
|
|
with gr.Row():
|
|
|
@@ -179,7 +196,7 @@ with gr.Blocks() as iface:
|
|
|
decay = gr.Number(label="衰减率", value=0.7)
|
|
|
|
|
|
# epoch_slider.change(update_epoch_display, inputs=[epoch_slider], outputs=[epoch_display])
|
|
|
- train_button = gr.Button("开始训练")
|
|
|
+ train_button = gr.Button("开始训练", variant='primary')
|
|
|
training_result = gr.Textbox(label="训练结果")
|
|
|
|
|
|
with gr.TabItem("预测单张图片"):
|
|
|
@@ -195,12 +212,18 @@ with gr.Blocks() as iface:
|
|
|
interactive=True)
|
|
|
predict_model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
|
|
|
|
|
|
- predict_button = gr.Button("预测图片")
|
|
|
+ predict_button = gr.Button("预测图片", variant='primary')
|
|
|
predict_result = gr.Textbox(label="预测结果")
|
|
|
|
|
|
+ refresh_button.click(refresh_list,
|
|
|
+ outputs=[preview_data_dir_name, data_dir_name, predict_data_dir_name,
|
|
|
+ model_dir_name, predict_model_dir_name])
|
|
|
+
|
|
|
name_input.change(validate_name, inputs=[name_input], outputs=[error_output, submit_btn])
|
|
|
submit_btn.click(process_zip, inputs=[name_input, file_input], outputs=[result_output, data_dir_name])
|
|
|
preview_button.click(preview_img_dir, inputs=preview_data_dir_name, outputs=preview_gallery)
|
|
|
+ preprocess_button.click(preprocess_data, inputs=preview_data_dir_name, outputs=preprocess_output)
|
|
|
+
|
|
|
train_button.click(start_training,
|
|
|
inputs=[epoch_slider, data_dir_name,
|
|
|
model_dir_name, freeze_slider, learn_rate, momentum, decay],
|