|
|
@@ -0,0 +1,231 @@
|
|
|
+from __future__ import print_function, division
|
|
|
+import json
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch.optim as optim
|
|
|
+from torch.optim import lr_scheduler
|
|
|
+import torch.backends.cudnn as cudnn
|
|
|
+import numpy as np
|
|
|
+import torchvision
|
|
|
+from torchvision import datasets, models, transforms
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import time
|
|
|
+import os
|
|
|
+import copy
|
|
|
+import platform
|
|
|
+
|
|
|
+cudnn.benchmark = True
|
|
|
+plt.ion() # interactive mode
|
|
|
+
|
|
|
+
|
|
|
+def imshow(inp, title=None):
|
|
|
+ inp = inp.numpy().transpose((1, 2, 0))
|
|
|
+ mean = np.array([0.485, 0.456, 0.406])
|
|
|
+ std = np.array([0.229, 0.224, 0.225])
|
|
|
+ inp = std * inp + mean
|
|
|
+ inp = np.clip(inp, 0, 1)
|
|
|
+ plt.imshow(inp)
|
|
|
+ if title is not None:
|
|
|
+ plt.title(title)
|
|
|
+ plt.pause(0.001) # pause a bit so that plots are updated
|
|
|
+
|
|
|
+
|
|
|
+def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
|
|
|
+ # 计算完成百分比
|
|
|
+ percent_complete = f"{(100 * (iteration / float(total))):.{decimals}f}"
|
|
|
+ # 计算进度条填充长度
|
|
|
+ filled_length = int(length * iteration // total)
|
|
|
+ # 创建进度条字符串
|
|
|
+ bar = fill * filled_length + '-' * (length - filled_length)
|
|
|
+ # 打印进度条
|
|
|
+ print(f'\r{prefix} |{bar}| {percent_complete}% {suffix}', end=print_end)
|
|
|
+ # 完成时打印新行
|
|
|
+ if iteration == total:
|
|
|
+ print()
|
|
|
+
|
|
|
+
|
|
|
+def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
|
|
+ since = time.time()
|
|
|
+
|
|
|
+ best_model_wts = copy.deepcopy(model.state_dict())
|
|
|
+ best_acc = 0.0
|
|
|
+
|
|
|
+ for epoch in range(num_epochs):
|
|
|
+ print(f'Epoch {epoch + 1}/{num_epochs}')
|
|
|
+ print('-' * 10)
|
|
|
+
|
|
|
+ for phase in data_phase:
|
|
|
+ if phase == 'train':
|
|
|
+ model.train()
|
|
|
+ else:
|
|
|
+ model.eval()
|
|
|
+
|
|
|
+ running_loss = 0.0
|
|
|
+ running_corrects = 0
|
|
|
+
|
|
|
+ l = len(dataloaders[phase])
|
|
|
+ print_progress_bar(0, l, prefix='进度:', suffix='完成', length=50)
|
|
|
+
|
|
|
+ # Iterate over data.
|
|
|
+ for i, (inputs, labels) in enumerate(dataloaders[phase]):
|
|
|
+ inputs = inputs.to(device)
|
|
|
+ labels = labels.to(device)
|
|
|
+
|
|
|
+ optimizer.zero_grad()
|
|
|
+
|
|
|
+ with torch.set_grad_enabled(phase == 'train'):
|
|
|
+ outputs = model(inputs)
|
|
|
+ _, preds = torch.max(outputs, 1)
|
|
|
+ loss = criterion(outputs, labels)
|
|
|
+
|
|
|
+ # backward + optimize only if in training phase
|
|
|
+ if phase == 'train':
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+
|
|
|
+ # statistics
|
|
|
+ running_loss += loss.item() * inputs.size(0)
|
|
|
+ running_corrects += torch.sum(preds == labels.data)
|
|
|
+
|
|
|
+ # 更新进度条
|
|
|
+ print_progress_bar(i + 1, l, prefix='进度:', suffix='完成', length=50)
|
|
|
+
|
|
|
+ if phase == 'train':
|
|
|
+ scheduler.step()
|
|
|
+
|
|
|
+ epoch_loss = running_loss / dataset_sizes[phase]
|
|
|
+ epoch_acc = running_corrects.double() / dataset_sizes[phase]
|
|
|
+
|
|
|
+ print(f'{phase} Loss: {epoch_loss:.6f} Acc: {epoch_acc:.6f}')
|
|
|
+
|
|
|
+ # deep copy the model
|
|
|
+ if phase == 'val' and epoch_acc >= best_acc:
|
|
|
+ best_acc = epoch_acc
|
|
|
+ best_model_wts = copy.deepcopy(model.state_dict())
|
|
|
+ print('copy best model')
|
|
|
+
|
|
|
+ # if (epoch + 1) % 10 == 0:
|
|
|
+ # print("save temp model in ", epoch + 1)
|
|
|
+ # torch.save(model.state_dict(), 'card_resnet_temp.pth')
|
|
|
+ #
|
|
|
+ # print()
|
|
|
+
|
|
|
+ time_elapsed = time.time() - since
|
|
|
+ print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
|
|
|
+ print(f'Best val Acc: {best_acc:4f}')
|
|
|
+
|
|
|
+ # 加载在验证集中表现最好的模型
|
|
|
+ model.load_state_dict(best_model_wts)
|
|
|
+
|
|
|
+ return model
|
|
|
+
|
|
|
+
|
|
|
+def train(epoch=30, save_path='effcient.pth', load_my_model=False, model_path=None, is_freeze=False, freeze_num=8):
|
|
|
+
|
|
|
+ # 如果不加载训练过的模型则加载预训练模型
|
|
|
+ if load_my_model:
|
|
|
+ model = models.efficientnet_b7()
|
|
|
+ # 修改最后一层
|
|
|
+ model.classifier = torch.nn.Sequential(
|
|
|
+ torch.nn.Linear(in_features=2560, out_features=class_num, bias=False)
|
|
|
+ )
|
|
|
+ model.load_state_dict(torch.load(model_path))
|
|
|
+
|
|
|
+ else:
|
|
|
+ model = models.efficientnet_b7(pretrained=True)
|
|
|
+ # 修改最后一层
|
|
|
+ model.classifier = torch.nn.Sequential(
|
|
|
+ torch.nn.Linear(in_features=2560, out_features=class_num, bias=False)
|
|
|
+ )
|
|
|
+
|
|
|
+ model = model.to(device)
|
|
|
+
|
|
|
+ criterion = nn.CrossEntropyLoss()
|
|
|
+ optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
|
|
|
+ exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
|
|
|
+
|
|
|
+ # 冻结部分参数
|
|
|
+ if is_freeze:
|
|
|
+ for i, c in enumerate(model.features.children()):
|
|
|
+ if i == freeze_num:
|
|
|
+ break
|
|
|
+ for param in c.parameters():
|
|
|
+ param.requires_grad = False
|
|
|
+
|
|
|
+ for param in model.features.parameters():
|
|
|
+ print(param.requires_grad)
|
|
|
+
|
|
|
+ model = train_model(model, criterion, optimizer_ft,
|
|
|
+ exp_lr_scheduler, num_epochs=epoch)
|
|
|
+
|
|
|
+ torch.save(model.state_dict(), save_path)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ data_dir = "/media/martin/DATA/refraction"
|
|
|
+ model_path = "/media/martin/DATA/effcient_card_out1945_freeze6.pth"
|
|
|
+
|
|
|
+ # if platform.system() == 'Windows':
|
|
|
+ # print('这是Windows系统')
|
|
|
+ # # data_dir = os.path.join('D:', data_dir)
|
|
|
+ # # model_path = os.path.join('D:', model_path)
|
|
|
+ # elif platform.system() == 'Linux':
|
|
|
+ # print('这是Linux系统')
|
|
|
+ # data_dir = os.path.join("/mnt/d", data_dir)
|
|
|
+ # model_path = os.path.join("/mnt/d", model_path)
|
|
|
+ # else:
|
|
|
+ # print(platform.system())
|
|
|
+
|
|
|
+ print('dataset: ', data_dir)
|
|
|
+ print('model_path', model_path)
|
|
|
+
|
|
|
+ data_transforms = {
|
|
|
+ 'train': transforms.Compose([
|
|
|
+ transforms.Resize((224, 224)),
|
|
|
+ # transforms.RandAugment(),
|
|
|
+ transforms.RandomRotation(30),
|
|
|
+ # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
|
|
|
+ transforms.RandomVerticalFlip(p=0.5),
|
|
|
+ transforms.RandomHorizontalFlip(),
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
|
+ ]),
|
|
|
+ 'val': transforms.Compose([
|
|
|
+ transforms.Resize((224, 224)),
|
|
|
+ # transforms.Resize(256),
|
|
|
+ # transforms.CenterCrop(224),
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
|
+ ]),
|
|
|
+ }
|
|
|
+
|
|
|
+ data_phase = ['train', 'val']
|
|
|
+ image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
|
|
|
+ for x in data_phase}
|
|
|
+ dataloaders = {
|
|
|
+ x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8, shuffle=True, num_workers=16, pin_memory=True)
|
|
|
+ for x in data_phase}
|
|
|
+ dataset_sizes = {x: len(image_datasets[x]) for x in data_phase}
|
|
|
+ class_names = image_datasets['train'].classes
|
|
|
+ class_to_idx = {idx:class_name for idx, class_name in enumerate(class_names)}
|
|
|
+ # 将字典写入JSON文件
|
|
|
+ with open('class_indices.json', 'w') as json_file:
|
|
|
+ json.dump(class_to_idx, json_file)
|
|
|
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
+
|
|
|
+ class_num = len(class_names)
|
|
|
+ print('class:', class_num)
|
|
|
+
|
|
|
+ '''
|
|
|
+ epoch: 训练次数
|
|
|
+ save_path: 模型保存路径
|
|
|
+ load_my_model: 是否加载训练过的模型
|
|
|
+ model_path: 加载的模型路径
|
|
|
+ is_freeze: 是否冻结模型部分层数
|
|
|
+ freeze_num: 冻结层数
|
|
|
+ '''
|
|
|
+ data_phase = ['train', 'val']
|
|
|
+ # 数据集路径在本文件上面
|
|
|
+ train(epoch=15, save_path='refraction.pth',
|
|
|
+ load_my_model=False, model_path=model_path,
|
|
|
+ is_freeze=False)
|