Sfoglia il codice sorgente

折射分类预测

shan.wan 2 anni fa
parent
commit
fe65d210a0
1 ha cambiato i file con 100 aggiunte e 0 eliminazioni
  1. 100 0
      refractionPre.py

+ 100 - 0
refractionPre.py

@@ -0,0 +1,100 @@
+import torch
+import os
+import json
+from torchvision import transforms,models
+from PIL import Image
+from pathlib import Path
+import matplotlib.pyplot as plt
+import numpy as np
+
+# 加载分类模型
+def create_efficientnet_b7(in_features, out_features, pretrained=False):
+    model = models.efficientnet_b7(pretrained)
+    model.classifier = torch.nn.Sequential(
+        torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+    )
+    return model
+
+def load_classification_model(weights_path=""):
+    model = create_efficientnet_b7(in_features=2560, out_features=5, pretrained=False).to(device)
+    model.load_state_dict(torch.load(weights_path, map_location=device))
+    model.eval()
+    return model
+
+def yolo_detect(img):
+    results = yolo_model(img)
+
+    pred = results.pred[0].cpu().numpy()
+    pred = pred[pred[:, 5] == 0][:, :4]
+    boxes = pred.astype(np.int32)
+
+    max_img = get_object(img, boxes)
+
+    return max_img
+
+
+def get_object(img, boxes):
+    if isinstance(img, str):
+        img = Image.open(img)
+
+    if len(boxes) == 0:
+        return img
+
+    max_area = 0
+
+    # 选出最大的框
+    x1, y1, x2, y2 = 0, 0, 0, 0
+    for box in boxes:
+        temp_x1, temp_y1, temp_x2, temp_y2 = box
+        area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
+        if area > max_area:
+            max_area = area
+            x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
+
+    max_img = img.crop((x1, y1, x2, y2))
+    return max_img
+
+# 分类模型推理
+def classify_card(classification_model, card_region_tensor):
+    with torch.no_grad():
+        output = classification_model(card_region_tensor)
+        probabilities = torch.nn.functional.softmax(output[0], dim=0)
+        predicted_class = torch.argmax(probabilities).item()
+        prob = probabilities[predicted_class].numpy()
+    return predicted_class,prob
+
+if __name__ == '__main__':
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # yolo_weights_path = 'best.pt'
+    yolo_model = torch.hub.load(r'D:\img2img\apiserver\ultralytics_yolov5_master',
+                                'custom', path=r"D:\img2img\apiserver\ultralytics_yolov5_master\best.pt",
+                                source='local')
+    classification_weights_path = "D:/refraction.pth"
+    image_path = "C:/Users/WS/Desktop/prizm/red/1d1ede8d05a241d3acf65d2badcb1e78.jpg"
+
+
+    classification_model = load_classification_model(classification_weights_path)
+
+    json_path = './class_indices.json'
+    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
+
+    with open(json_path, "r") as f:
+        class_indict = json.load(f)
+
+    card_region = yolo_detect(image_path)
+
+    data_transform = transforms.Compose([
+         transforms.Resize((224,224)),
+         # transforms.CenterCrop(img_size[num_model]),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+    card_region_tensor = data_transform(card_region).unsqueeze(0)
+
+    # 分类模型推理
+    predicted_class,prob = classify_card(classification_model, card_region_tensor)
+    predict= "class: {}   prob: {:.3}".format(class_indict[str(predicted_class)],
+                                                 prob)
+
+
+    print(f'The detected card belongs to : {predict}')