|
|
@@ -0,0 +1,100 @@
|
|
|
+import torch
|
|
|
+import os
|
|
|
+import json
|
|
|
+from torchvision import transforms,models
|
|
|
+from PIL import Image
|
|
|
+from pathlib import Path
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+# 加载分类模型
|
|
|
+def create_efficientnet_b7(in_features, out_features, pretrained=False):
|
|
|
+ model = models.efficientnet_b7(pretrained)
|
|
|
+ model.classifier = torch.nn.Sequential(
|
|
|
+ torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
|
|
|
+ )
|
|
|
+ return model
|
|
|
+
|
|
|
+def load_classification_model(weights_path=""):
|
|
|
+ model = create_efficientnet_b7(in_features=2560, out_features=5, pretrained=False).to(device)
|
|
|
+ model.load_state_dict(torch.load(weights_path, map_location=device))
|
|
|
+ model.eval()
|
|
|
+ return model
|
|
|
+
|
|
|
+def yolo_detect(img):
|
|
|
+ results = yolo_model(img)
|
|
|
+
|
|
|
+ pred = results.pred[0].cpu().numpy()
|
|
|
+ pred = pred[pred[:, 5] == 0][:, :4]
|
|
|
+ boxes = pred.astype(np.int32)
|
|
|
+
|
|
|
+ max_img = get_object(img, boxes)
|
|
|
+
|
|
|
+ return max_img
|
|
|
+
|
|
|
+
|
|
|
+def get_object(img, boxes):
|
|
|
+ if isinstance(img, str):
|
|
|
+ img = Image.open(img)
|
|
|
+
|
|
|
+ if len(boxes) == 0:
|
|
|
+ return img
|
|
|
+
|
|
|
+ max_area = 0
|
|
|
+
|
|
|
+ # 选出最大的框
|
|
|
+ x1, y1, x2, y2 = 0, 0, 0, 0
|
|
|
+ for box in boxes:
|
|
|
+ temp_x1, temp_y1, temp_x2, temp_y2 = box
|
|
|
+ area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
|
|
|
+ if area > max_area:
|
|
|
+ max_area = area
|
|
|
+ x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
|
|
|
+
|
|
|
+ max_img = img.crop((x1, y1, x2, y2))
|
|
|
+ return max_img
|
|
|
+
|
|
|
+# 分类模型推理
|
|
|
+def classify_card(classification_model, card_region_tensor):
|
|
|
+ with torch.no_grad():
|
|
|
+ output = classification_model(card_region_tensor)
|
|
|
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
|
|
|
+ predicted_class = torch.argmax(probabilities).item()
|
|
|
+ prob = probabilities[predicted_class].numpy()
|
|
|
+ return predicted_class,prob
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
+ # yolo_weights_path = 'best.pt'
|
|
|
+ yolo_model = torch.hub.load(r'D:\img2img\apiserver\ultralytics_yolov5_master',
|
|
|
+ 'custom', path=r"D:\img2img\apiserver\ultralytics_yolov5_master\best.pt",
|
|
|
+ source='local')
|
|
|
+ classification_weights_path = "D:/refraction.pth"
|
|
|
+ image_path = "C:/Users/WS/Desktop/prizm/red/1d1ede8d05a241d3acf65d2badcb1e78.jpg"
|
|
|
+
|
|
|
+
|
|
|
+ classification_model = load_classification_model(classification_weights_path)
|
|
|
+
|
|
|
+ json_path = './class_indices.json'
|
|
|
+ assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
|
|
|
+
|
|
|
+ with open(json_path, "r") as f:
|
|
|
+ class_indict = json.load(f)
|
|
|
+
|
|
|
+ card_region = yolo_detect(image_path)
|
|
|
+
|
|
|
+ data_transform = transforms.Compose([
|
|
|
+ transforms.Resize((224,224)),
|
|
|
+ # transforms.CenterCrop(img_size[num_model]),
|
|
|
+ transforms.ToTensor(),
|
|
|
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
|
+ ])
|
|
|
+ card_region_tensor = data_transform(card_region).unsqueeze(0)
|
|
|
+
|
|
|
+ # 分类模型推理
|
|
|
+ predicted_class,prob = classify_card(classification_model, card_region_tensor)
|
|
|
+ predict= "class: {} prob: {:.3}".format(class_indict[str(predicted_class)],
|
|
|
+ prob)
|
|
|
+
|
|
|
+
|
|
|
+ print(f'The detected card belongs to : {predict}')
|