import torch import os import json from torchvision import transforms,models from PIL import Image from pathlib import Path import matplotlib.pyplot as plt import numpy as np # 加载分类模型 def create_efficientnet_b7(in_features, out_features, pretrained=False): model = models.efficientnet_b7(pretrained) model.classifier = torch.nn.Sequential( torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False) ) return model def load_classification_model(weights_path=""): model = create_efficientnet_b7(in_features=2560, out_features=5, pretrained=False).to(device) model.load_state_dict(torch.load(weights_path, map_location=device)) model.eval() return model def yolo_detect(img): results = yolo_model(img) pred = results.pred[0].cpu().numpy() pred = pred[pred[:, 5] == 0][:, :4] boxes = pred.astype(np.int32) max_img = get_object(img, boxes) return max_img def get_object(img, boxes): if isinstance(img, str): img = Image.open(img) if len(boxes) == 0: return img max_area = 0 # 选出最大的框 x1, y1, x2, y2 = 0, 0, 0, 0 for box in boxes: temp_x1, temp_y1, temp_x2, temp_y2 = box area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1) if area > max_area: max_area = area x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2 max_img = img.crop((x1, y1, x2, y2)) return max_img # 分类模型推理 def classify_card(classification_model, card_region_tensor): with torch.no_grad(): output = classification_model(card_region_tensor) probabilities = torch.nn.functional.softmax(output[0], dim=0) predicted_class = torch.argmax(probabilities).item() prob = probabilities[predicted_class].numpy() return predicted_class,prob if __name__ == '__main__': device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # yolo_weights_path = 'best.pt' yolo_model = torch.hub.load(r'D:\img2img\apiserver\ultralytics_yolov5_master', 'custom', path=r"D:\img2img\apiserver\ultralytics_yolov5_master\best.pt", source='local') classification_weights_path = "D:/refraction.pth" image_path = "C:/Users/WS/Desktop/prizm/red/1d1ede8d05a241d3acf65d2badcb1e78.jpg" classification_model = load_classification_model(classification_weights_path) json_path = './class_indices.json' assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) with open(json_path, "r") as f: class_indict = json.load(f) card_region = yolo_detect(image_path) data_transform = transforms.Compose([ transforms.Resize((224,224)), # transforms.CenterCrop(img_size[num_model]), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) card_region_tensor = data_transform(card_region).unsqueeze(0) # 分类模型推理 predicted_class,prob = classify_card(classification_model, card_region_tensor) predict= "class: {} prob: {:.3}".format(class_indict[str(predicted_class)], prob) print(f'The detected card belongs to : {predict}')