MyModel.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import torch
  2. import torchvision.models as models
  3. import torchvision.transforms as transforms
  4. from PIL import Image
  5. import timm
  6. class MyModel:
  7. def __init__(self, model_dict_path):
  8. self.norm_mean = [0.485, 0.456, 0.406]
  9. self.norm_std = [0.229, 0.224, 0.225]
  10. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  11. # self.model = models.resnet50(pretrained=True)
  12. self.model = timm.create_model('resnet50', num_classes=2048, pretrained=True)
  13. self.model.eval()
  14. # 自定义模型
  15. # print(list(self.model.children()))
  16. features = list(self.model.children())[:-1] # 去掉全连接层
  17. self.model = torch.nn.Sequential(*features).to(self.device)
  18. # self.model.to(self.device)
  19. def inference_transform(self):
  20. inference_transform = transforms.Compose([
  21. transforms.Resize((256, 256)),
  22. transforms.CenterCrop(224),
  23. transforms.ToTensor(),
  24. transforms.Normalize(self.norm_mean, self.norm_std),
  25. ])
  26. return inference_transform
  27. def img_transform(self, img_rgb, transform=None):
  28. # 将数据转换为模型读取的形式
  29. if transform is None:
  30. raise ValueError("找不到transform!必须有transform对img进行处理")
  31. img_t = transform(img_rgb)
  32. return img_t
  33. def get_model(self):
  34. return self.model
  35. # 输出图片路径或者cv2格式的图片数据
  36. def predict(self, img):
  37. if type(img) == type('path'):
  38. img = Image.open(img).convert('RGB')
  39. transform = self.inference_transform()
  40. img_tensor = transform(img)
  41. img_tensor.unsqueeze_(0)
  42. img_tensor = img_tensor.to(self.device)
  43. # print(img.shape)
  44. with torch.no_grad():
  45. outputs = self.model(img_tensor)
  46. return outputs.reshape(2048).cpu().numpy()