# ResnetTest.py (254 B)

123456789
  1. import torch
  2. import torchvision.models as models
  3. model = models.resnet50(pretrained=False)
  4. model.fc = torch.nn.Linear(in_features=2048, out_features=314)
  5. model.load_state_dict(torch.load(r"D:\Code\ML\model\card_cls\res_card_freeze2.pth"))
  6. print(model)