train.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from __future__ import print_function, division
  2. import json
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from torch.optim import lr_scheduler
  7. import torch.backends.cudnn as cudnn
  8. import numpy as np
  9. import torchvision
  10. from torchvision import datasets, models, transforms
  11. import matplotlib.pyplot as plt
  12. import time
  13. import os
  14. import copy
  15. import platform
  16. cudnn.benchmark = True
  17. plt.ion() # interactive mode
  18. def imshow(inp, title=None):
  19. inp = inp.numpy().transpose((1, 2, 0))
  20. mean = np.array([0.485, 0.456, 0.406])
  21. std = np.array([0.229, 0.224, 0.225])
  22. inp = std * inp + mean
  23. inp = np.clip(inp, 0, 1)
  24. plt.imshow(inp)
  25. if title is not None:
  26. plt.title(title)
  27. plt.pause(0.001) # pause a bit so that plots are updated
  28. def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
  29. # 计算完成百分比
  30. percent_complete = f"{(100 * (iteration / float(total))):.{decimals}f}"
  31. # 计算进度条填充长度
  32. filled_length = int(length * iteration // total)
  33. # 创建进度条字符串
  34. bar = fill * filled_length + '-' * (length - filled_length)
  35. # 打印进度条
  36. print(f'\r{prefix} |{bar}| {percent_complete}% {suffix}', end=print_end)
  37. # 完成时打印新行
  38. if iteration == total:
  39. print()
  40. def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
  41. since = time.time()
  42. best_model_wts = copy.deepcopy(model.state_dict())
  43. best_acc = 0.0
  44. for epoch in range(num_epochs):
  45. print(f'Epoch {epoch + 1}/{num_epochs}')
  46. print('-' * 10)
  47. for phase in data_phase:
  48. if phase == 'train':
  49. model.train()
  50. else:
  51. model.eval()
  52. running_loss = 0.0
  53. running_corrects = 0
  54. l = len(dataloaders[phase])
  55. print_progress_bar(0, l, prefix='进度:', suffix='完成', length=50)
  56. # Iterate over data.
  57. for i, (inputs, labels) in enumerate(dataloaders[phase]):
  58. inputs = inputs.to(device)
  59. labels = labels.to(device)
  60. optimizer.zero_grad()
  61. with torch.set_grad_enabled(phase == 'train'):
  62. outputs = model(inputs)
  63. _, preds = torch.max(outputs, 1)
  64. loss = criterion(outputs, labels)
  65. # backward + optimize only if in training phase
  66. if phase == 'train':
  67. loss.backward()
  68. optimizer.step()
  69. # statistics
  70. running_loss += loss.item() * inputs.size(0)
  71. running_corrects += torch.sum(preds == labels.data)
  72. # 更新进度条
  73. print_progress_bar(i + 1, l, prefix='进度:', suffix='完成', length=50)
  74. if phase == 'train':
  75. scheduler.step()
  76. epoch_loss = running_loss / dataset_sizes[phase]
  77. epoch_acc = running_corrects.double() / dataset_sizes[phase]
  78. print(f'{phase} Loss: {epoch_loss:.6f} Acc: {epoch_acc:.6f}')
  79. # deep copy the model
  80. if phase == 'val' and epoch_acc >= best_acc:
  81. best_acc = epoch_acc
  82. best_model_wts = copy.deepcopy(model.state_dict())
  83. print('copy best model')
  84. # if (epoch + 1) % 10 == 0:
  85. # print("save temp model in ", epoch + 1)
  86. # torch.save(model.state_dict(), 'card_resnet_temp.pth')
  87. #
  88. # print()
  89. time_elapsed = time.time() - since
  90. print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
  91. print(f'Best val Acc: {best_acc:4f}')
  92. # 加载在验证集中表现最好的模型
  93. model.load_state_dict(best_model_wts)
  94. return model
  95. def train(epoch=30, save_path='effcient.pth', load_my_model=False, model_path=None, is_freeze=False, freeze_num=8):
  96. # 如果不加载训练过的模型则加载预训练模型
  97. if load_my_model:
  98. model = models.efficientnet_b7()
  99. # 修改最后一层
  100. model.classifier = torch.nn.Sequential(
  101. torch.nn.Linear(in_features=2560, out_features=class_num, bias=False)
  102. )
  103. model.load_state_dict(torch.load(model_path))
  104. else:
  105. model = models.efficientnet_b7(pretrained=True)
  106. # 修改最后一层
  107. model.classifier = torch.nn.Sequential(
  108. torch.nn.Linear(in_features=2560, out_features=class_num, bias=False)
  109. )
  110. model = model.to(device)
  111. criterion = nn.CrossEntropyLoss()
  112. optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  113. exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
  114. # 冻结部分参数
  115. if is_freeze:
  116. for i, c in enumerate(model.features.children()):
  117. if i == freeze_num:
  118. break
  119. for param in c.parameters():
  120. param.requires_grad = False
  121. for param in model.features.parameters():
  122. print(param.requires_grad)
  123. model = train_model(model, criterion, optimizer_ft,
  124. exp_lr_scheduler, num_epochs=epoch)
  125. torch.save(model.state_dict(), save_path)
  126. if __name__ == '__main__':
  127. data_dir = "/media/martin/DATA/refraction"
  128. model_path = "/media/martin/DATA/effcient_card_out1945_freeze6.pth"
  129. # if platform.system() == 'Windows':
  130. # print('这是Windows系统')
  131. # # data_dir = os.path.join('D:', data_dir)
  132. # # model_path = os.path.join('D:', model_path)
  133. # elif platform.system() == 'Linux':
  134. # print('这是Linux系统')
  135. # data_dir = os.path.join("/mnt/d", data_dir)
  136. # model_path = os.path.join("/mnt/d", model_path)
  137. # else:
  138. # print(platform.system())
  139. print('dataset: ', data_dir)
  140. print('model_path', model_path)
  141. data_transforms = {
  142. 'train': transforms.Compose([
  143. transforms.Resize((224, 224)),
  144. # transforms.RandAugment(),
  145. transforms.RandomRotation(30),
  146. # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
  147. transforms.RandomVerticalFlip(p=0.5),
  148. transforms.RandomHorizontalFlip(),
  149. transforms.ToTensor(),
  150. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  151. ]),
  152. 'val': transforms.Compose([
  153. transforms.Resize((224, 224)),
  154. # transforms.Resize(256),
  155. # transforms.CenterCrop(224),
  156. transforms.ToTensor(),
  157. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  158. ]),
  159. }
  160. data_phase = ['train', 'val']
  161. image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
  162. for x in data_phase}
  163. dataloaders = {
  164. x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8, shuffle=True, num_workers=16, pin_memory=True)
  165. for x in data_phase}
  166. dataset_sizes = {x: len(image_datasets[x]) for x in data_phase}
  167. class_names = image_datasets['train'].classes
  168. class_to_idx = {idx:class_name for idx, class_name in enumerate(class_names)}
  169. # 将字典写入JSON文件
  170. with open('class_indices.json', 'w') as json_file:
  171. json.dump(class_to_idx, json_file)
  172. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  173. class_num = len(class_names)
  174. print('class:', class_num)
  175. '''
  176. epoch: 训练次数
  177. save_path: 模型保存路径
  178. load_my_model: 是否加载训练过的模型
  179. model_path: 加载的模型路径
  180. is_freeze: 是否冻结模型部分层数
  181. freeze_num: 冻结层数
  182. '''
  183. data_phase = ['train', 'val']
  184. # 数据集路径在本文件上面
  185. train(epoch=15, save_path='refraction.pth',
  186. load_my_model=False, model_path=model_path,
  187. is_freeze=False)