| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- from __future__ import print_function, division
- import json
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.optim import lr_scheduler
- import torch.backends.cudnn as cudnn
- import numpy as np
- import torchvision
- from torchvision import datasets, models, transforms
- import matplotlib.pyplot as plt
- import time
- import os
- import copy
- import platform
# Interactive matplotlib mode: plt.imshow/plt.pause calls update the window
# without blocking the script.
plt.ion()
# Let cuDNN auto-tune convolution algorithms; every image is resized to a
# fixed 224x224 in data_transforms, so the tuning result can be reused.
cudnn.benchmark = True
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a normalized image tensor with matplotlib.

    Reverses ``transforms.Normalize(mean, std)``, clips the result to [0, 1],
    draws it in the current figure and pauses briefly so an interactive
    window refreshes.

    Args:
        inp: image tensor in (channels, height, width) layout; the
            ``transpose((1, 2, 0))`` below assumes exactly that order.
            CPU, CUDA and grad-tracking tensors are all accepted.
        title: optional figure title.
        mean: per-channel mean that was subtracted during normalization.
            The default equals the values passed to ``transforms.Normalize``
            in this script.
        std: per-channel std that was divided out during normalization.
    """
    # detach()/cpu() so tensors that require grad or live on the GPU can be converted.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
    """Draw a one-line text progress bar on stdout.

    Args:
        iteration: number of steps completed so far.
        total: total number of steps. A non-positive total (e.g. an empty
            DataLoader) is reported as 100% complete instead of raising
            ZeroDivisionError.
        prefix: text printed before the bar.
        suffix: text printed after the percentage.
        decimals: number of decimal places in the percentage.
        length: width of the bar in characters.
        fill: character used for the completed part of the bar.
        print_end: line terminator; the default ``"\\r"`` lets the next call
            overwrite the same line.

    When ``iteration == total`` a newline is printed so later output starts
    on a fresh line.
    """
    if total > 0:
        # Percentage complete and how many bar cells to fill.
        percent_complete = f"{(100 * (iteration / float(total))):.{decimals}f}"
        filled_length = int(length * iteration // total)
    else:
        # Nothing to do: show a full bar rather than dividing by zero.
        percent_complete = f"{100.0:.{decimals}f}"
        filled_length = length
    bar = fill * filled_length + '-' * (length - filled_length)
    # Leading '\r' returns to the start of the line so the bar redraws in place.
    print(f'\r{prefix} |{bar}| {percent_complete}% {suffix}', end=print_end)
    if iteration == total:
        print()
def _run_phase(model, criterion, optimizer, phase):
    """Run one pass over ``dataloaders[phase]`` and return (mean loss, accuracy) as floats.

    Gradients are computed and the optimizer is stepped only when
    ``phase == 'train'``. An empty loader yields ``(0.0, 0.0)``.
    Reads the module-level globals ``dataloaders`` and ``device``.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is what model.eval() does
    running_loss = 0.0
    running_corrects = 0
    seen = 0
    n_batches = len(dataloaders[phase])
    print_progress_bar(0, n_batches, prefix='进度:', suffix='完成', length=50)
    for i, (inputs, labels) in enumerate(dataloaders[phase]):
        inputs = inputs.to(device)
        labels = labels.to(device)
        if is_train:
            optimizer.zero_grad()
        # Build the autograd graph only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()
        # loss.item() is a batch mean; weight it by batch size to get a per-sample average.
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        seen += batch_size
        print_progress_bar(i + 1, n_batches, prefix='进度:', suffix='完成', length=50)
    if seen == 0:
        return 0.0, 0.0
    return running_loss / seen, running_corrects / seen


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it with the weights of its best validation epoch.

    Each epoch runs every phase in the module-level global ``data_phase``, in
    order. Batches come from the global ``dataloaders`` and are moved to the
    global ``device``.

    Args:
        model: network to train; it is modified in place.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: stepped once per training batch.
        scheduler: LR scheduler, stepped once per epoch after the 'train' phase.
        num_epochs: number of epochs.

    Returns:
        The same ``model`` object, loaded with the state dict of the epoch with
        the highest 'val' accuracy (ties go to the later epoch). If no 'val'
        phase ever ran, the final weights are kept rather than discarded.
    """
    since = time.time()
    best_model_wts = None  # only snapshot once a validation result exists
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)
        for phase in data_phase:
            epoch_loss, epoch_acc = _run_phase(model, criterion, optimizer, phase)
            if phase == 'train':
                scheduler.step()
            print(f'{phase} Loss: {epoch_loss:.6f} Acc: {epoch_acc:.6f}')
            # Snapshot the weights whenever validation accuracy matches or beats the best so far.
            if phase == 'val' and epoch_acc >= best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                print('copy best model')
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    # Restore the weights that performed best on the validation set.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    return model
def _build_efficientnet_b7(pretrained):
    """Create an EfficientNet-B7 whose classifier is one bias-free Linear layer with ``class_num`` outputs."""
    # Only pass `pretrained` when True, to match the original no-argument call otherwise.
    model = models.efficientnet_b7(pretrained=True) if pretrained else models.efficientnet_b7()
    in_features = model.classifier[-1].in_features
    # Keep this exact Sequential(Linear) layout: saved checkpoints store it
    # under the key 'classifier.0.weight'.
    model.classifier = torch.nn.Sequential(
        torch.nn.Linear(in_features=in_features, out_features=class_num, bias=False)
    )
    return model


def train(epoch=30, save_path='effcient.pth', load_my_model=False, model_path=None, is_freeze=False, freeze_num=8):
    """Build an EfficientNet-B7 classifier, fine-tune it with ``train_model`` and save its weights.

    Reads the module-level globals ``class_num`` (output size of the new
    head) and ``device``.

    Args:
        epoch: number of training epochs.
        save_path: file the trained model's ``state_dict`` is written to.
        load_my_model: if True, start from a checkpoint previously saved by
            this script (``model_path``); otherwise start from torchvision's
            pretrained weights.
        model_path: checkpoint path, required when ``load_my_model`` is True.
        is_freeze: if True, freeze the first ``freeze_num`` children of
            ``model.features``.
        freeze_num: number of leading ``model.features`` children to freeze.

    Raises:
        ValueError: if ``load_my_model`` is True but ``model_path`` is None.
    """
    if load_my_model:
        if model_path is None:
            raise ValueError('model_path is required when load_my_model=True')
        model = _build_efficientnet_b7(pretrained=False)
        # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
        # machine; the model is moved to `device` right after.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    else:
        model = _build_efficientnet_b7(pretrained=True)
    model = model.to(device)

    # Freeze the first `freeze_num` feature stages before building the optimizer.
    if is_freeze:
        for i, child in enumerate(model.features.children()):
            if i == freeze_num:
                break
            for param in child.parameters():
                param.requires_grad = False
    feature_params = list(model.features.parameters())
    n_frozen = sum(1 for p in feature_params if not p.requires_grad)
    print(f'frozen feature parameter tensors: {n_frozen}/{len(feature_params)}')

    criterion = nn.CrossEntropyLoss()
    # Hand the optimizer only the parameters that will receive gradients.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer_ft = optim.SGD(trainable, lr=0.001, momentum=0.9)
    # Multiply the learning rate by 0.1 every 7 epochs.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    model = train_model(model, criterion, optimizer_ft,
                        exp_lr_scheduler, num_epochs=epoch)
    torch.save(model.state_dict(), save_path)
if __name__ == '__main__':
    data_dir = "/media/martin/DATA/refraction"
    model_path = "/media/martin/DATA/effcient_card_out1945_freeze6.pth"
    print('dataset: ', data_dir)
    print('model_path', model_path)

    # Both phases resize to 224x224 and use the same Normalize statistics;
    # only 'train' adds random augmentation.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            # transforms.RandAugment(),
            transforms.RandomRotation(30),
            # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            # transforms.Resize(256),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # The functions above read several of these names (data_phase,
    # dataloaders, dataset_sizes, device, class_num) as module-level globals,
    # so they must be assigned here before train() is called.
    data_phase = ['train', 'val']
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in data_phase}
    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=8,
            shuffle=(x == 'train'),  # evaluation order does not change val metrics
            num_workers=min(16, os.cpu_count() or 1),  # don't oversubscribe smaller machines
            pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
        )
        for x in data_phase}
    dataset_sizes = {x: len(image_datasets[x]) for x in data_phase}
    class_names = image_datasets['train'].classes

    # Write the index -> class-name mapping; json stores the int keys as strings.
    idx_to_class = dict(enumerate(class_names))
    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json.dump(idx_to_class, json_file, ensure_ascii=False, indent=4)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    class_num = len(class_names)
    print('class:', class_num)

    # train() arguments:
    #   epoch:         number of training epochs
    #   save_path:     where the trained weights are saved
    #   load_my_model: whether to start from a previously trained checkpoint
    #   model_path:    checkpoint to load when load_my_model is True
    #   is_freeze:     whether to freeze the leading feature stages
    #   freeze_num:    how many leading feature stages to freeze
    # The dataset path is set at the top of this section.
    train(epoch=15, save_path='refraction.pth',
          load_my_model=False, model_path=model_path,
          is_freeze=False)
|