MyModel.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import torch
  2. import torchvision.transforms as transforms
  3. from PIL import Image
  4. class MyModel:
  5. def __init__(self, model_path: str) -> None:
  6. self.norm_mean = [0.485, 0.456, 0.406]
  7. self.norm_std = [0.229, 0.224, 0.225]
  8. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  9. self.model = torch.load(model_path)
  10. self.model.eval()
  11. def inference_transform(self):
  12. inference_transform = transforms.Compose([
  13. transforms.Resize((224, 224)),
  14. transforms.ToTensor(),
  15. transforms.Normalize(self.norm_mean, self.norm_std),
  16. ])
  17. return inference_transform
  18. # 输入图片, 获取图片特征向量
  19. def run(self, img):
  20. if type(img) == type('path'):
  21. img = Image.open(img).convert('RGB')
  22. else:
  23. img = Image.fromarray(img)
  24. img = img.convert('RGB')
  25. transform = self.inference_transform()
  26. img_tensor = transform(img)
  27. img_tensor = img_tensor.unsqueeze(0).to(self.device)
  28. # Perform prediction
  29. with torch.no_grad():
  30. outputs = self.model(img_tensor)
  31. _, predicted = torch.max(outputs.data, 1)
  32. return int(predicted)