# MyModel2.py
import os

import cv2
import numpy as np
import timm
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
  8. class MyModel:
  9. def __init__(self, model_dict_path, out_features=2048):
  10. self.out_features = out_features
  11. self.norm_mean = [0.485, 0.456, 0.406]
  12. self.norm_std = [0.229, 0.224, 0.225]
  13. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  14. self.model = models.resnet50(pretrained=False)
  15. self.model.fc = torch.nn.Linear(in_features=2048, out_features=self.out_features)
  16. self.model.load_state_dict(torch.load(model_dict_path, map_location=self.device))
  17. # self.model = timm.create_model('resnet50', num_classes=2048, pretrained=True)
  18. self.model.eval()
  19. # 自定义模型
  20. # print(list(self.model.children()))
  21. features = list(self.model.children())[:-1] # 去掉dropout 和 Linear
  22. self.model = torch.nn.Sequential(*features).to(self.device)
  23. # self.model.to(self.device)
  24. def inference_transform(self):
  25. inference_transform = transforms.Compose([
  26. transforms.Resize((224, 224)),
  27. transforms.ToTensor(),
  28. transforms.Normalize(self.norm_mean, self.norm_std),
  29. ])
  30. return inference_transform
  31. def img_transform(self, img_rgb, transform=None):
  32. # 将数据转换为模型读取的形式
  33. if transform is None:
  34. raise ValueError("找不到transform!必须有transform对img进行处理")
  35. img_t = transform(img_rgb)
  36. return img_t
  37. def get_model(self):
  38. return self.model
  39. # 输出图片路径或者cv2格式的图片数据
  40. def predict(self, img):
  41. if type(img) == type('path'):
  42. img = Image.open(img).convert('RGB')
  43. else:
  44. img = img.convert('RGB')
  45. transform = self.inference_transform()
  46. img_tensor = transform(img)
  47. img_tensor.unsqueeze_(0)
  48. img_tensor = img_tensor.to(self.device)
  49. # print(img.shape)
  50. with torch.no_grad():
  51. outputs = self.model(img_tensor)
  52. return outputs.reshape(2048).cpu().numpy()