refractionPre.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import torch
  2. import os
  3. import json
  4. from torchvision import transforms,models
  5. from PIL import Image
  6. from pathlib import Path
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. # 加载分类模型
  10. def create_efficientnet_b7(in_features, out_features, pretrained=False):
  11. model = models.efficientnet_b7(pretrained)
  12. model.classifier = torch.nn.Sequential(
  13. torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
  14. )
  15. return model
  16. def load_classification_model(weights_path=""):
  17. model = create_efficientnet_b7(in_features=2560, out_features=5, pretrained=False).to(device)
  18. model.load_state_dict(torch.load(weights_path, map_location=device))
  19. model.eval()
  20. return model
  21. def yolo_detect(img):
  22. results = yolo_model(img)
  23. pred = results.pred[0].cpu().numpy()
  24. pred = pred[pred[:, 5] == 0][:, :4]
  25. boxes = pred.astype(np.int32)
  26. max_img = get_object(img, boxes)
  27. return max_img
  28. def get_object(img, boxes):
  29. if isinstance(img, str):
  30. img = Image.open(img)
  31. if len(boxes) == 0:
  32. return img
  33. max_area = 0
  34. # 选出最大的框
  35. x1, y1, x2, y2 = 0, 0, 0, 0
  36. for box in boxes:
  37. temp_x1, temp_y1, temp_x2, temp_y2 = box
  38. area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
  39. if area > max_area:
  40. max_area = area
  41. x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
  42. max_img = img.crop((x1, y1, x2, y2))
  43. return max_img
  44. # 分类模型推理
  45. def classify_card(classification_model, card_region_tensor):
  46. with torch.no_grad():
  47. output = classification_model(card_region_tensor)
  48. probabilities = torch.nn.functional.softmax(output[0], dim=0)
  49. predicted_class = torch.argmax(probabilities).item()
  50. prob = probabilities[predicted_class].numpy()
  51. return predicted_class,prob
  52. if __name__ == '__main__':
  53. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  54. # yolo_weights_path = 'best.pt'
  55. yolo_model = torch.hub.load(r'D:\img2img\apiserver\ultralytics_yolov5_master',
  56. 'custom', path=r"D:\img2img\apiserver\ultralytics_yolov5_master\best.pt",
  57. source='local')
  58. classification_weights_path = "D:/refraction.pth"
  59. image_path = "C:/Users/WS/Desktop/prizm/red/1d1ede8d05a241d3acf65d2badcb1e78.jpg"
  60. classification_model = load_classification_model(classification_weights_path)
  61. json_path = './class_indices.json'
  62. assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
  63. with open(json_path, "r") as f:
  64. class_indict = json.load(f)
  65. card_region = yolo_detect(image_path)
  66. data_transform = transforms.Compose([
  67. transforms.Resize((224,224)),
  68. # transforms.CenterCrop(img_size[num_model]),
  69. transforms.ToTensor(),
  70. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  71. ])
  72. card_region_tensor = data_transform(card_region).unsqueeze(0)
  73. # 分类模型推理
  74. predicted_class,prob = classify_card(classification_model, card_region_tensor)
  75. predict= "class: {} prob: {:.3}".format(class_indict[str(predicted_class)],
  76. prob)
  77. print(f'The detected card belongs to : {predict}')