predict_preprocess.py 995 B

123456789101112131415161718192021222324252627282930313233343536
  1. import math
  2. import cv2
  3. import numpy as np
  4. import math
  5. import torch
  6. import torch.nn as nn
  7. import torchvision.transforms as transforms
  8. import cv2
  9. from app.utils.data_augmentation import LetterBox
  10. def predict_preprocess(img_bgr, imgSize_train):
  11. device_str = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  12. img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  13. letterBox = LetterBox(imgSize_train)
  14. img_rgb_letterbox = letterBox.handle_single_img(img_rgb)
  15. img_np = np.array(img_rgb_letterbox)
  16. imgTensor = torch.tensor(img_np, dtype=torch.float32, device=device_str)
  17. # 将所有元素值除以255,进行归一化
  18. imgTensor = imgTensor * (1 / 255.0)
  19. # 把形状从[H, W, C] 改为 [C, H, W]
  20. imgTensor_CHW = imgTensor.permute(2, 0, 1)
  21. normaliz_operate_c3 = transforms.Compose([
  22. transforms.Normalize(mean=(0, 0, 0), std=(1, 1, 1)),
  23. ])
  24. imgTensor_CHW_norm = normaliz_operate_c3(imgTensor_CHW)
  25. return imgTensor_CHW_norm