main.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. import glob
  2. import gradio as gr
  3. import os
  4. import zipfile
  5. import tempfile
  6. import time
  7. from utils.Train_resnet import load_param
  8. from utils.MyOnnxYolo import MyOnnxYolo
  9. from utils.MyModel import MyModel
  10. from utils.preview_img import preview_img_dir
  11. from utils.preprocess_data import preprocess_data
  12. from Config import Config
  13. # 用于存储已创建的目录名称
  14. config = Config()
  15. data_dir = config.data_dir
  16. model_dir = config.model_dir
  17. os.makedirs(data_dir, exist_ok=True)
  18. data_list = os.listdir(data_dir)
  19. model_id_list = sorted(os.listdir(model_dir))
  20. global_yolo_model = None
  21. global_cls_model = None
  22. def refresh_list():
  23. global data_list, model_id_list
  24. data_list = os.listdir(data_dir)
  25. model_id_list = sorted(os.listdir(model_dir))
  26. data_dropdown = gr.Dropdown(label="选择数据集", choices=data_list,
  27. interactive=True)
  28. model_dropdown = gr.Dropdown(label="模型ID", choices=model_id_list)
  29. return data_dropdown, data_dropdown, data_dropdown, model_dropdown, model_dropdown
  30. def process_zip(name, zip_file, progress=gr.Progress()):
  31. global data_list
  32. if not name.strip():
  33. return "Error: 需要输入数据集名称.", gr.update()
  34. try:
  35. # 创建一个以名称命名的目录
  36. dir_path = os.path.join(data_dir, name)
  37. os.makedirs(dir_path, exist_ok=True)
  38. # 获取 ZIP 文件中的文件列表
  39. with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
  40. file_list = zip_ref.namelist()
  41. total_files = len(file_list)
  42. # 解压 ZIP 文件到该目录,同时更新进度条
  43. for i, file in enumerate(file_list):
  44. zip_ref.extract(file, dir_path)
  45. progress((i + 1) / total_files, f"Extracting file {i + 1} of {total_files}")
  46. # 添加新目录到列表中
  47. if name not in data_list:
  48. data_list.append(name)
  49. # 返回解压后的文件列表和成功消息
  50. files = os.listdir(dir_path)
  51. return (f"解压结束!\n\n数据集保存名称为: '{name}':\n" + "\n".join(files),
  52. gr.update(choices=data_list, value=name))
  53. except zipfile.BadZipFile:
  54. return "Error: The uploaded file is not a valid ZIP file.", gr.update()
  55. except PermissionError:
  56. return "Error: Permission denied. Unable to create directory or extract files.", gr.update()
  57. except Exception as e:
  58. return f"Error: An unexpected error occurred: {str(e)}", gr.update()
  59. def validate_name(name):
  60. if name.strip():
  61. return gr.update(visible=False), gr.update(interactive=True)
  62. else:
  63. return gr.update(visible=True, value="输入数据集名称"), gr.update(interactive=False)
  64. # def update_epoch_display(epoch):
  65. # return f"Current epoch: {epoch}"
  66. def start_training(epoch, data_dir_name, model_dir_name, freeze_slider, learn_rate, momentum, decay):
  67. global global_yolo_model, global_cls_model
  68. # 释放内存
  69. global_yolo_model = None
  70. global_cls_model = None
  71. print(epoch, data_dir_name, model_dir_name, freeze_slider, learn_rate, momentum, decay)
  72. train_data_path = os.path.join(data_dir, data_dir_name)
  73. if model_dir_name is not None:
  74. model_dir_path = os.path.join(model_dir, model_dir_name)
  75. model_path = glob.glob(os.path.join(model_dir_path, "*.pt"))[0]
  76. else:
  77. model_path = None
  78. load_param(epoch=epoch, data_dir=train_data_path, model_path=model_path,
  79. freeze=freeze_slider, learn_rate=learn_rate, momentum=momentum, decay=decay)
  80. # time.sleep(epoch * 0.2) # 模拟短暂的训练过程
  81. return "训练结束", gr.update()
  82. def predict_img(predict_data_dir_name, predict_model_dir_name, img):
  83. global global_yolo_model, global_cls_model
  84. print("predict_img: ", data_dir_name, model_dir_name)
  85. if predict_data_dir_name is None or predict_model_dir_name is None:
  86. return "2个参数为必选项"
  87. cls_list = sorted(os.listdir(os.path.join(data_dir, predict_data_dir_name, 'train')))
  88. print(cls_list)
  89. model_dir_path = os.path.join(model_dir, predict_model_dir_name)
  90. model_path = glob.glob(os.path.join(model_dir_path, "*.pt"))[0]
  91. if global_yolo_model is None:
  92. global_yolo_model = MyOnnxYolo(config.yolo_model_path)
  93. if global_cls_model is None:
  94. global_cls_model = MyModel(model_path)
  95. global_yolo_model.set_result(img)
  96. card_img = global_yolo_model.get_max_img(cls_id=0)
  97. cls_id = global_cls_model.run(card_img)
  98. return cls_list[cls_id], gr.update()
# Build the Gradio interface: a refresh button above four tabs (upload,
# preview/preprocess, train, predict), followed by the event wiring.
# NOTE(review): the container nesting below was reconstructed from a
# flattened listing - confirm the Row/Column membership against the original.
with gr.Blocks() as iface:
    gr.Markdown("# 分类训练")
    with gr.Row():
        refresh_button = gr.Button("刷新数据列表")
    with gr.Tabs():
        # Tab 1: name a dataset and upload a ZIP archive to extract.
        with gr.TabItem("上传ZIP数据"):
            gr.Markdown(
                '''
                输入数数据名称, 并上传zip数据, 点击解压zip <br>
                如果是原始数据, 请将目录结构准备为: <br>
                ```python
                ---数据集名称:
                --train:
                -分类1:{图片1,图片2, ...},
                -分类2:{图片1,图片2, ...},
                ......
                ```
                之后进行数据处理后训练
                '''
            )
            with gr.Row():
                name_input = gr.Textbox(label="输入数据集名称", placeholder="Enter a name for the directory")
                file_input = gr.File(label="解压ZIP数据")
            # Hidden hint; validate_name() shows it while the name is blank.
            error_output = gr.Markdown(visible=False)
            submit_btn = gr.Button("解压 ZIP", variant='primary')
            result_output = gr.Textbox(label="结果")
        # Tab 2: preview images of a dataset and run preprocessing on it.
        with gr.TabItem("预览, 处理数据"):
            gr.Markdown(
                "预览部分训练数据, 如果数据未处理, 那么进行处理, ")
            with gr.Row():
                with gr.Column():
                    preview_data_dir_name = gr.Dropdown(label="选择数据集", choices=data_list,
                                                        interactive=True)
                    preview_button = gr.Button("预览部分图片", variant='primary')
                preview_gallery = gr.Gallery(label="部分训练集图片")
            with gr.Row():
                preprocess_button = gr.Button("处理数据集", variant='primary')
            with gr.Row():
                preprocess_output = gr.Textbox(label="处理结果")
        # Tab 3: choose dataset, optional base model and hyper-parameters.
        with gr.TabItem("训练模型"):
            gr.Markdown("设置参数训练模型, 注意模型ID不选的话则重新训练")
            with gr.Row():
                data_dir_name = gr.Dropdown(label="训练的数据集", choices=data_list,
                                            interactive=True)
                epoch_slider = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="轮次")
                # epoch_display = gr.Markdown("Current epoch: 10")
            with gr.Row():
                model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
                freeze_slider = gr.Slider(minimum=0, maximum=7, value=7, step=1, label="冻结参数层")
            with gr.Row():
                learn_rate = gr.Number(label="学习率", value=0.01)
                momentum = gr.Number(label="动量", value=0.9)
                decay = gr.Number(label="衰减率", value=0.7)
            # epoch_slider.change(update_epoch_display, inputs=[epoch_slider], outputs=[epoch_display])
            train_button = gr.Button("开始训练", variant='primary')
            training_result = gr.Textbox(label="训练结果")
        # Tab 4: classify a single uploaded image.
        with gr.TabItem("预测单张图片"):
            # The directories are listed again here, rebinding the
            # module-level lists just before the prediction dropdowns are built.
            data_list = os.listdir(data_dir)
            model_id_list = sorted(os.listdir(model_dir))
            gr.Markdown("输入一张图片返回预测结果")
            with gr.Row():
                img = gr.Image(label="上传图片")
            with gr.Row():
                predict_data_dir_name = gr.Dropdown(label="训练的数据集", choices=data_list,
                                                    interactive=True)
                predict_model_dir_name = gr.Dropdown(label="模型ID", choices=model_id_list)
            predict_button = gr.Button("预测图片", variant='primary')
            predict_result = gr.Textbox(label="预测结果")
    # Event wiring.  The refresh button targets three dataset dropdowns and
    # two model-ID dropdowns, in the order refresh_list() returns its values.
    refresh_button.click(refresh_list,
                         outputs=[preview_data_dir_name, data_dir_name, predict_data_dir_name,
                                  model_dir_name, predict_model_dir_name])
    # Re-validate the dataset name on every change of the textbox.
    name_input.change(validate_name, inputs=[name_input], outputs=[error_output, submit_btn])
    # Extraction reports into the result box and updates the training-tab
    # dataset dropdown only.
    submit_btn.click(process_zip, inputs=[name_input, file_input], outputs=[result_output, data_dir_name])
    preview_button.click(preview_img_dir, inputs=preview_data_dir_name, outputs=preview_gallery)
    preprocess_button.click(preprocess_data, inputs=preview_data_dir_name, outputs=preprocess_output)
    train_button.click(start_training,
                       inputs=[epoch_slider, data_dir_name,
                               model_dir_name, freeze_slider, learn_rate, momentum, decay],
                       outputs=[training_result])
    predict_button.click(predict_img,
                         inputs=[predict_data_dir_name, predict_model_dir_name, img],
                         outputs=[predict_result])
# Launch the interface; server_name="0.0.0.0" listens on all network
# interfaces rather than only on localhost.
iface.launch(server_name="0.0.0.0")