| 1234567891011121314151617181920212223242526272829303132333435363738394041 |
- import torch
- import torchvision.transforms as transforms
- from PIL import Image
class MyModel:
    """Single-image classifier wrapping a fully pickled PyTorch model.

    The model file is loaded once at construction; ``run`` preprocesses one
    image and returns the index of the highest-scoring output.
    """

    def __init__(self, model_path: str) -> None:
        """Load the model stored at *model_path* and put it in eval mode.

        Args:
            model_path: Path to a file written with ``torch.save(model)``,
                i.e. an entire pickled module rather than just a
                ``state_dict`` (the loaded object must support ``.to()``
                and ``.eval()``).

        Security: ``torch.load`` unpickles arbitrary Python objects here, so
        only load model files from trusted sources.
        """
        import inspect  # local import keeps this block self-contained

        # Per-channel mean/std fed to transforms.Normalize (the standard
        # ImageNet normalization constants).
        self.norm_mean = [0.485, 0.456, 0.406]
        self.norm_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        load_kwargs = {"map_location": self.device}
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # a full nn.Module. Opt out explicitly where the parameter exists, so
        # older torch releases (which lack it) keep working.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        self.model = torch.load(model_path, **load_kwargs)
        # run() sends inputs to self.device, so the weights must live there
        # too; otherwise a CPU-saved model fails on a CUDA host.
        self.model.to(self.device)
        self.model.eval()

    def inference_transform(self):
        """Build the preprocessing pipeline applied to every input image.

        Resizes to 224x224 (aspect ratio is not preserved), converts to a
        float tensor in [0, 1] with CHW layout, then normalizes each channel
        with ``self.norm_mean`` / ``self.norm_std``.
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(self.norm_mean, self.norm_std),
        ])

    def run(self, img):
        """Classify one image and return the predicted class index.

        Args:
            img: A filesystem path (``str`` or ``os.PathLike``), a PIL image,
                or an array accepted by ``PIL.Image.fromarray`` (e.g. a uint8
                HxWx3 NumPy array). NOTE(review): arrays are assumed to be in
                RGB channel order; OpenCV images are BGR, so confirm with
                callers.

        Returns:
            The argmax over dimension 1 of the model output, as a Python int.
        """
        import os  # local import keeps this block self-contained

        if isinstance(img, (str, os.PathLike)):
            # The context manager closes the file handle. convert() loads the
            # pixel data into a new image first, so it stays valid afterwards.
            with Image.open(img) as src:
                pil_img = src.convert('RGB')
        elif isinstance(img, Image.Image):
            pil_img = img.convert('RGB')
        else:
            pil_img = Image.fromarray(img).convert('RGB')

        # Add a leading batch dimension of size 1 and move to the model's device.
        batch = self.inference_transform()(pil_img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(batch)
        # Assumes outputs is (batch, num_classes) scores/logits; TODO confirm.
        return int(outputs.argmax(dim=1).item())
|