# Train_resnet.py

from __future__ import print_function, division
import shutil
import torch
import torch.utils.data  # for torch.utils.data.DataLoader used in load_param
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
from torchvision import datasets, models, transforms
import time
import os
import zipfile
import copy
import platform
from torch.utils.tensorboard import SummaryWriter

cudnn.benchmark = True  # let cuDNN autotune conv algorithms for the fixed 224x224 inputs

data_phase = ['train', 'val']
database_dir = './data'
output_dir = './runs'  # root directory for saved models and log backups
newest_log = './newest_log'  # directory holding the most recent log
log_port = 6667  # tensorboard log port
writer: SummaryWriter

# used when handling asynchronously received requests
task_list = ['train', 'data_process']
running_task = set()
exit_flag = False

def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
    # percentage complete
    percent_complete = f"{(100 * (iteration / float(total))):.{decimals}f}"
    # how much of the bar is filled
    filled_length = int(length * iteration // total)
    # build the bar string
    bar = fill * filled_length + '-' * (length - filled_length)
    # print the bar, overwriting the current line
    print(f'\r{prefix} |{bar}| {percent_complete}% {suffix}', end=print_end)
    # print a newline when finished
    if iteration == total:
        print()
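
# Illustrative call (values are arbitrary): step 30 of 100 with a
# 50-character bar prints something like
#   Progress: |███████████████-----------------------------------| 30.0% done
# print_progress_bar(30, 100, prefix='Progress:', suffix='done', length=50)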

# back up the TensorBoard log into the model's save directory
def move_log(model_save_dir):
    # assumes newest_log holds a single event file (load_param clears it first)
    log_name = os.listdir(newest_log)[0]
    log_path = os.path.join(newest_log, log_name)
    save_path = os.path.join(model_save_dir, log_name)
    shutil.copy(log_path, save_path)
    print('log backed up')

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)
        for phase in data_phase:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            num_batches = len(dataloaders[phase])
            print_progress_bar(0, num_batches, prefix='Progress:', suffix='done', length=50)
            # iterate over the data
            for i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                # update the progress bar
                print_progress_bar(i + 1, num_batches, prefix='Progress:', suffix='done', length=50)
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            writer.add_scalar(f"{phase} loss", epoch_loss, epoch + 1)
            writer.add_scalar(f"{phase} accuracy", epoch_acc, epoch + 1)
            print(f'{phase} Loss: {epoch_loss:.6f} Acc: {epoch_acc:.6f}')
            # keep a deep copy of the best model so far
            if phase == 'val' and epoch_acc >= best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                print('copy best model')
        print()
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    # load the weights that performed best on the validation set
    model.load_state_dict(best_model_wts)
    return model

def train(epoch=30, save_path='resnet.pt', model_path=None,
          freeze=7, learn_rate=0.01, momentum=0.9, decay=0.7):
    '''
    epoch: number of training epochs
    save_path: path the trained model is saved to
    model_path: path of a previously trained model to load (None/'' starts from the pretrained resnet50)
    freeze: number of leading child modules to freeze
    learn_rate, momentum: SGD hyperparameters
    decay: StepLR gamma (learning-rate decay factor)
    '''
    # load the ImageNet-pretrained model unless a previously trained one is given
    if model_path is None or model_path == '':
        # note: pretrained=True is deprecated in newer torchvision in favor of
        # weights=models.ResNet50_Weights.IMAGENET1K_V1
        model = models.resnet50(pretrained=True)
        # replace the final fully connected layer
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, class_num)
    else:
        model = torch.load(model_path)
        old_cls_num = model.fc.out_features
        if class_num == old_cls_num:
            print('classification head matches, training directly')
        else:
            # replace the final fully connected layer
            num_features = model.fc.in_features
            model.fc = nn.Linear(num_features, class_num)
            print(f"replacing classification head: {old_cls_num} --> {class_num}")
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    # use the passed-in hyperparameters (these were hard-coded to 0.01 / 0.9 / 0.1
    # in the original, silently ignoring learn_rate, momentum and decay)
    optimizer_ft = optim.SGD(model.parameters(), lr=learn_rate, momentum=momentum)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=decay)
    # freeze the first `freeze` child modules; for resnet50 the children are
    # conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc
    for i, c in enumerate(model.children()):
        if i == freeze:
            break
        for param in c.parameters():
            param.requires_grad = False
    # debug output: show which parameters remain trainable
    for param in model.parameters():
        print(param.requires_grad)
    model = train_model(model, criterion, optimizer_ft,
                        exp_lr_scheduler, num_epochs=epoch)
    torch.save(model, save_path)
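
# Note: train() and train_model() rely on the module-level globals (device,
# class_num, dataloaders, dataset_sizes, writer) that load_param() below sets
# up, so neither is meant to be called on its own.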

def load_param(epoch, data_dir, model_path, freeze, learn_rate, momentum, decay):
    global device, class_num, dataloaders, dataset_sizes, output_dir, writer
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandAugment(),
            transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.RandomApply([transforms.GaussianBlur(5)], p=0.3),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in data_phase}
    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=16, pin_memory=True)
        for x in data_phase}
    dataset_sizes = {x: len(image_datasets[x]) for x in data_phase}
    class_names = image_datasets['train'].classes
    class_num = len(class_names)
    print('class:', class_num)
    print(f"input parameters: epochs: {epoch}, model path: {model_path}")
    # number this run after the existing runs in output_dir
    os.makedirs(output_dir, exist_ok=True)
    model_id = len(os.listdir(output_dir)) + 1
    model_save_dir = os.path.join(output_dir, str(model_id))
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)
    # clear old logs (the original removed only the first file, but
    # SummaryWriter can leave several event files behind)
    os.makedirs(newest_log, exist_ok=True)
    for old_log in os.listdir(newest_log):
        os.remove(os.path.join(newest_log, old_log))
    writer = SummaryWriter(newest_log)
    writer.add_text('model', "model id: " + str(model_id))
    save_path = os.path.join(model_save_dir, 'resnet50_out' + str(class_num) + '.pt')
    train(epoch=epoch,
          save_path=save_path,
          model_path=model_path,
          freeze=freeze,
          learn_rate=learn_rate,
          momentum=momentum,
          decay=decay)
    writer.flush()
    writer.close()
    move_log(model_save_dir)
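
# ---------------------------------------------------------------------------
# Hypothetical entry point (not part of the original file): the task_list /
# running_task globals above suggest this module is normally driven by an
# async request handler elsewhere. A minimal direct invocation, assuming
# ./data holds train/ and val/ subfolders in ImageFolder layout, might look
# like this; the argument values simply mirror train()'s defaults:
if __name__ == '__main__':
    load_param(epoch=30,
               data_dir=database_dir,
               model_path=None,   # None -> start from the pretrained resnet50
               freeze=7,          # freeze conv1 .. layer3; fine-tune layer4 + fc
               learn_rate=0.01,
               momentum=0.9,
               decay=0.7)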