| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
import os

import numpy as np
import timm
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
class MyModel:
    """ResNet-50 feature extractor that maps one image to a 2048-element vector.

    The backbone is timm's pretrained ``resnet50`` with its last child module
    (the fully connected layer) removed, so the network output is a feature
    vector rather than class scores.
    """

    def __init__(self, model_dict_path=None):
        """Build the pretrained backbone and move it to the available device.

        Args:
            model_dict_path: Kept for interface compatibility; this class does
                not read it (weights come from timm's pretrained checkpoint).
        """
        # Per-channel mean/std handed to transforms.Normalize
        # (the standard ImageNet statistics).
        self.norm_mean = [0.485, 0.456, 0.406]
        self.norm_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        backbone = timm.create_model('resnet50', num_classes=2048, pretrained=True)
        # Drop the last child (the fully connected layer) to expose features.
        features = list(backbone.children())[:-1]
        self.model = torch.nn.Sequential(*features).to(self.device)
        # Switch the final wrapper (and thus every child) to inference mode.
        self.model.eval()

    def inference_transform(self):
        """Return the preprocessing pipeline applied before inference.

        Resize to 256x256, center-crop to 224, convert to a tensor and
        normalize with ``self.norm_mean`` / ``self.norm_std``.
        """
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(self.norm_mean, self.norm_std),
        ])

    def img_transform(self, img_rgb, transform=None):
        """Apply ``transform`` to ``img_rgb`` and return the result.

        Raises:
            ValueError: If ``transform`` is None.
        """
        if transform is None:
            raise ValueError("找不到transform!必须有transform对img进行处理")
        return transform(img_rgb)

    def get_model(self):
        """Return the underlying feature-extraction module."""
        return self.model

    @staticmethod
    def _to_rgb_pil(img):
        """Normalize supported inputs to an RGB PIL image.

        Accepts a filesystem path, a cv2-style array (grayscale, BGR or BGRA)
        or a PIL image. Any other object is returned unchanged so that the
        transform pipeline reports the type error itself.
        """
        if isinstance(img, (str, os.PathLike)):
            # The context manager releases the file handle once pixels are loaded.
            with Image.open(img) as opened:
                return opened.convert('RGB')
        if isinstance(img, np.ndarray):
            if img.ndim == 2:
                return Image.fromarray(img).convert('RGB')
            if img.ndim == 3 and img.shape[2] in (3, 4):
                # cv2 stores channels as BGR(A); take B, G, R in reverse order
                # to get RGB and drop any alpha channel.
                return Image.fromarray(np.ascontiguousarray(img[:, :, 2::-1]))
            raise ValueError(
                "unsupported image array shape: {}".format(img.shape)
            )
        if isinstance(img, Image.Image) and img.mode != 'RGB':
            # Normalize expects exactly three channels.
            return img.convert('RGB')
        return img

    def predict(self, img):
        """Extract the feature vector for a single image.

        Args:
            img: An image path, a cv2-format image array, or a PIL image.

        Returns:
            A 1-D NumPy array with 2048 elements.
        """
        pil_img = self._to_rgb_pil(img)
        # Add the batch dimension the network expects, then move to the device.
        img_tensor = self.inference_transform()(pil_img).unsqueeze(0)
        img_tensor = img_tensor.to(self.device)
        with torch.no_grad():
            outputs = self.model(img_tensor)
        return outputs.reshape(2048).cpu().numpy()
|