# seg_test02.py (1.2 KB)
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
  6. model_dir = r"C:\Code\ML\Model\Card_Seg\segformer_card_hand02_safetensors"
  7. # img_path = r"C:\Users\wow38\Pictures\videoframe_6871967.png"
  8. processor = AutoImageProcessor.from_pretrained(model_dir)
  9. model = AutoModelForSemanticSegmentation.from_pretrained(model_dir)
  10. def show(img_path):
  11. image = Image.open(img_path).convert("RGB")
  12. inputs = processor(images=image, return_tensors="pt")
  13. with torch.no_grad():
  14. outputs = model(**inputs)
  15. logits = outputs.logits
  16. pred = torch.nn.functional.interpolate(
  17. logits,
  18. size=image.size[::-1],
  19. mode="bilinear",
  20. align_corners=False
  21. ).argmax(dim=1)[0].cpu().numpy()
  22. plt.figure(figsize=(10, 5))
  23. plt.subplot(1, 2, 1)
  24. plt.imshow(image)
  25. plt.title("image")
  26. plt.axis("off")
  27. plt.subplot(1, 2, 2)
  28. plt.imshow(pred)
  29. plt.title("mask")
  30. plt.axis("off")
  31. plt.tight_layout()
  32. plt.show()
  33. if __name__ == '__main__':
  34. show(r"C:\Users\wow38\Pictures\videoframe_6871967 - 副本.png")
  35. print()