Refraction classification training

shan.wan 2 years ago
commit fe8f275b11
1 changed file with 231 additions and 0 deletions
train.py

@@ -0,0 +1,231 @@
+from __future__ import print_function, division
+import json
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+import torch.backends.cudnn as cudnn
+import numpy as np
+import torchvision
+from torchvision import datasets, models, transforms
+import matplotlib.pyplot as plt
+import time
+import os
+import copy
+import platform
+
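+# let cuDNN benchmark and select the fastest convolution algorithms for fixed-size inputs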
+cudnn.benchmark = True
+plt.ion()  # interactive mode
+
+
+def imshow(inp, title=None):
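+    # undo the ImageNet mean/std normalization so the tensor can be shown with matplotlib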
+    inp = inp.numpy().transpose((1, 2, 0))
+    mean = np.array([0.485, 0.456, 0.406])
+    std = np.array([0.229, 0.224, 0.225])
+    inp = std * inp + mean
+    inp = np.clip(inp, 0, 1)
+    plt.imshow(inp)
+    if title is not None:
+        plt.title(title)
+    plt.pause(0.001)  # pause a bit so that plots are updated
+
+
+def print_progress_bar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█', print_end="\r"):
+    # percentage of work completed
+    percent_complete = f"{(100 * (iteration / float(total))):.{decimals}f}"
+    # number of filled cells in the bar
+    filled_length = int(length * iteration // total)
+    # build the bar string
+    bar = fill * filled_length + '-' * (length - filled_length)
+    # print the bar, overwriting the current console line
+    print(f'\r{prefix} |{bar}| {percent_complete}% {suffix}', end=print_end)
+    # print a newline once complete
+    if iteration == total:
+        print()
+
+
+def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
+    since = time.time()
+
+    best_model_wts = copy.deepcopy(model.state_dict())
+    best_acc = 0.0
+
+    for epoch in range(num_epochs):
+        print(f'Epoch {epoch + 1}/{num_epochs}')
+        print('-' * 10)
+
+        for phase in data_phase:
+            if phase == 'train':
+                model.train()
+            else:
+                model.eval()
+
+            running_loss = 0.0
+            running_corrects = 0
+
+            num_batches = len(dataloaders[phase])
+            print_progress_bar(0, num_batches, prefix='Progress:', suffix='complete', length=50)
+
+            # Iterate over data.
+            for i, (inputs, labels) in enumerate(dataloaders[phase]):
+                inputs = inputs.to(device)
+                labels = labels.to(device)
+
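+                # reset gradients accumulated from the previous batch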
+                optimizer.zero_grad()
+
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
+                    _, preds = torch.max(outputs, 1)
+                    loss = criterion(outputs, labels)
+
+                    # backward + optimize only if in training phase
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
+
+                # statistics
+                running_loss += loss.item() * inputs.size(0)
+                running_corrects += torch.sum(preds == labels.data)
+
+                # update the progress bar
+                print_progress_bar(i + 1, num_batches, prefix='Progress:', suffix='complete', length=50)
+
+            if phase == 'train':
+                scheduler.step()
+
+            epoch_loss = running_loss / dataset_sizes[phase]
+            epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+            print(f'{phase} Loss: {epoch_loss:.6f} Acc: {epoch_acc:.6f}')
+
+            # deep copy the model
+            if phase == 'val' and epoch_acc >= best_acc:
+                best_acc = epoch_acc
+                best_model_wts = copy.deepcopy(model.state_dict())
+                print('copy best model')
+
+        # if (epoch + 1) % 10 == 0:
+        #     print("save temp model in ", epoch + 1)
+        #     torch.save(model.state_dict(), 'card_resnet_temp.pth')
+        #
+        # print()
+
+    time_elapsed = time.time() - since
+    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
+    print(f'Best val Acc: {best_acc:.4f}')
+
+    # reload the weights that performed best on the validation set
+    model.load_state_dict(best_model_wts)
+
+    return model
+
+
+def train(epoch=30, save_path='effcient.pth', load_my_model=False, model_path=None, is_freeze=False, freeze_num=8):
+
+    # load a previously trained checkpoint if requested; otherwise start from ImageNet-pretrained weights
+    if load_my_model:
+        model = models.efficientnet_b7()
+        # replace the final classifier layer (EfficientNet-B7 outputs 2560 features)
+        model.classifier = torch.nn.Sequential(
+            torch.nn.Linear(in_features=2560, out_features=class_num, bias=False)
+        )
+        model.load_state_dict(torch.load(model_path))
+
+    else:
+        model = models.efficientnet_b7(pretrained=True)
+        # replace the final classifier layer (EfficientNet-B7 outputs 2560 features)
+        model.classifier = torch.nn.Sequential(
+            torch.nn.Linear(in_features=2560, out_features=class_num, bias=False)
+        )
+
+    model = model.to(device)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
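+    # decay the learning rate by a factor of 10 every 7 epochs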
+    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
+
+    # freeze the first freeze_num blocks of model.features so only the later layers are fine-tuned
+    if is_freeze:
+        for i, c in enumerate(model.features.children()):
+            if i == freeze_num:
+                break
+            for param in c.parameters():
+                param.requires_grad = False
+
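+        # sanity check: show which feature parameters remain trainable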
+        for param in model.features.parameters():
+            print(param.requires_grad)
+
+    model = train_model(model, criterion, optimizer_ft,
+                        exp_lr_scheduler, num_epochs=epoch)
+
+    torch.save(model.state_dict(), save_path)
+
+
+if __name__ == '__main__':
+    data_dir = "/media/martin/DATA/refraction"
+    model_path = "/media/martin/DATA/effcient_card_out1945_freeze6.pth"
+
+    # if platform.system() == 'Windows':
+    #     print('This is a Windows system')
+    #     # data_dir = os.path.join('D:', data_dir)
+    #     # model_path = os.path.join('D:', model_path)
+    # elif platform.system() == 'Linux':
+    #     print('This is a Linux system')
+    #     data_dir = os.path.join("/mnt/d", data_dir)
+    #     model_path = os.path.join("/mnt/d", model_path)
+    # else:
+    #     print(platform.system())
+
+    print('dataset: ', data_dir)
+    print('model_path:', model_path)
+
+    data_transforms = {
+        'train': transforms.Compose([
+            transforms.Resize((224, 224)),
+            # transforms.RandAugment(),
+            transforms.RandomRotation(30),
+            # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
+            transforms.RandomVerticalFlip(p=0.5),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ]),
+        'val': transforms.Compose([
+            transforms.Resize((224, 224)),
+            # transforms.Resize(256),
+            # transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ]),
+    }
+
+    data_phase = ['train', 'val']
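+    # each phase directory (train/, val/) follows the ImageFolder layout: one subdirectory per class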
+    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
+                      for x in data_phase}
+    dataloaders = {
+        x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8, shuffle=True, num_workers=16, pin_memory=True)
+        for x in data_phase}
+    dataset_sizes = {x: len(image_datasets[x]) for x in data_phase}
+    class_names = image_datasets['train'].classes
+    # build an index -> class-name mapping (ImageFolder assigns indices to classes alphabetically)
+    idx_to_class = {idx: class_name for idx, class_name in enumerate(class_names)}
+    # write the mapping to a JSON file for decoding predictions at inference time
+    with open('class_indices.json', 'w') as json_file:
+        json.dump(idx_to_class, json_file)
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    class_num = len(class_names)
+    print('class:', class_num)
+
+    '''
+    epoch: number of training epochs
+    save_path: path to save the trained model
+    load_my_model: whether to load a previously trained model
+    model_path: path of the model to load
+    is_freeze: whether to freeze part of the model's layers
+    freeze_num: number of feature blocks to freeze
+    '''
+    # the dataset path is set at the top of this block
+    train(epoch=15, save_path='refraction.pth',
+          load_my_model=False, model_path=model_path,
+          is_freeze=False)