MyViTFeatureExtractor.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import torch
  2. from transformers import ViTModel, ViTConfig
  3. import torchvision.transforms as transforms
  4. from PIL import Image
  5. import numpy as np
  6. from typing import List, Union
  7. import logging
  8. import os
  9. import torch.nn.functional as F
  10. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  11. class MyViTFeatureExtractor:
  12. def __init__(self, local_model_path: str) -> None:
  13. """
  14. 初始化特征提取器。
  15. 适配: 能够加载由 MetricViT 训练并保存的 backbone 模型。
  16. """
  17. if not os.path.isdir(local_model_path):
  18. raise NotADirectoryError(f"Model path not found: {local_model_path}")
  19. logging.info(f"Loading model from: {local_model_path}")
  20. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  21. logging.info(f"Using device: {self.device}")
  22. try:
  23. # 加载配置和模型 (这里加载的是纯 ViTModel)
  24. self.config = ViTConfig.from_pretrained(local_model_path, local_files_only=True)
  25. self.model = ViTModel.from_pretrained(local_model_path, config=self.config, local_files_only=True)
  26. except Exception as e:
  27. logging.error(f"Failed to load model: {e}")
  28. raise
  29. self.model.to(self.device)
  30. self.model.eval()
  31. # 定义预处理 (与训练时的 Val Transform 保持一致)
  32. self.transform = transforms.Compose([
  33. transforms.Resize((224, 224)),
  34. transforms.ToTensor(),
  35. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
  36. ])
  37. self.feature_dim = self.model.config.hidden_size
  38. logging.info(f"Feature dimension: {self.feature_dim}")
  39. def run(self, imgs: List[Union[str, np.ndarray, Image.Image]], normalize: bool = True) -> np.ndarray:
  40. """
  41. 处理一批图像并返回特征向量。
  42. Args:
  43. imgs: 图片路径、numpy数组或PIL Image的列表
  44. normalize: 是否进行 L2 归一化 (强烈建议为 True,适配 Milvus/Cosine 搜索)
  45. """
  46. if not imgs:
  47. return np.empty((0, self.feature_dim), dtype=np.float32)
  48. processed_tensors = []
  49. valid_indices = []
  50. # 1. 预处理
  51. for i, img_input in enumerate(imgs):
  52. try:
  53. # --- 图像读取逻辑 (保持你原有的健壮性逻辑) ---
  54. if isinstance(img_input, str):
  55. img = Image.open(img_input).convert('RGB')
  56. elif isinstance(img_input, np.ndarray):
  57. if img_input.dtype != np.uint8:
  58. if img_input.max() <= 1.0:
  59. img_input = (img_input * 255).astype(np.uint8)
  60. else:
  61. img_input = img_input.astype(np.uint8)
  62. img = Image.fromarray(img_input, 'RGB')
  63. elif isinstance(img_input, Image.Image):
  64. img = img_input.convert('RGB')
  65. else:
  66. continue
  67. # ----------------------------------------
  68. img_tensor = self.transform(img)
  69. processed_tensors.append(img_tensor)
  70. valid_indices.append(i)
  71. except Exception as e:
  72. logging.error(f"Error processing image index {i}: {e}")
  73. if not processed_tensors:
  74. return np.empty((0, self.feature_dim), dtype=np.float32)
  75. # 2. 推理
  76. batch_tensor = torch.stack(processed_tensors, dim=0).to(self.device)
  77. with torch.no_grad():
  78. outputs = self.model(batch_tensor)
  79. # 【关键】提取 last_hidden_state 的 [CLS] token (Index 0)
  80. # 这与训练时的 MetricViT 保持完全一致
  81. features = outputs.last_hidden_state[:, 0, :]
  82. # 3. 后处理
  83. if normalize:
  84. # 使用 PyTorch 的 normalize 更精确,或者保持 numpy 实现
  85. features = F.normalize(features, p=2, dim=1)
  86. output_np = features.cpu().numpy()
  87. else:
  88. output_np = features.cpu().numpy()
  89. # 4. 填充结果 (保持列表长度一致)
  90. if len(output_np) != len(imgs):
  91. final_output = np.full((len(imgs), self.feature_dim), np.nan, dtype=np.float32)
  92. for idx, vec in zip(valid_indices, output_np):
  93. final_output[idx] = vec
  94. return final_output
  95. return output_np
  96. def compare_images(extractor, img_path_A, img_path_B):
  97. """
  98. 计算两张图片的相似度
  99. """
  100. if not os.path.exists(img_path_A) or not os.path.exists(img_path_B):
  101. print(f"❌ 错误: 找不到图片路径。\nA: {img_path_A}\nB: {img_path_B}")
  102. return
  103. print(f"🔍 正在对比:")
  104. print(f" 图 A: {os.path.basename(img_path_A)}")
  105. print(f" 图 B: {os.path.basename(img_path_B)}")
  106. # 1. 提取特征 (一次传入两张图,效率更高)
  107. # run 方法返回的是已经归一化过的 numpy 数组
  108. vectors = extractor.run([img_path_A, img_path_B], normalize=True)
  109. vec_a = vectors[0]
  110. vec_b = vectors[1]
  111. # 2. 计算余弦相似度 (Cosine Similarity)
  112. # 因为 vec_a 和 vec_b 模长都为 1,所以点积就是余弦相似度
  113. similarity = np.dot(vec_a, vec_b)
  114. # 3. 计算欧氏距离 (Euclidean Distance) - 辅助参考
  115. # 距离越小越相似
  116. distance = np.linalg.norm(vec_a - vec_b)
  117. # 4. 打印结果
  118. print("-" * 30)
  119. print(f"📊 相似度结果:")
  120. print(f" ★ 余弦相似度 (Cosine): {similarity:.4f} (越接近 1.0 越相似)")
  121. print(f" ☆ 欧氏距离 (L2 Dist): {distance:.4f} (越接近 0.0 越相似)")
  122. print("-" * 30)
  123. # 5. 简单判定建议
  124. threshold = 0.85 # 这个阈值可以根据实际情况调整
  125. if similarity > threshold:
  126. print("✅ 结论: 它们极有可能是同一张卡 (或同一宝可梦的不同语言版本)")
  127. else:
  128. print("❌ 结论: 它们看起来是不同的卡片")
  129. print("\n")
  130. if __name__ == "__main__":
  131. # ================= 配置 =================
  132. # 你的模型保存路径
  133. MODEL_PATH = "/home/martin/ML/Model/pokemon_cls/vit-base-patch16-224-Pokemon02"
  134. # 这里填入你想测试的两张图片的绝对路径
  135. # 建议测试:
  136. # 1. 一张中文卡 vs 同一张的英文卡
  137. # 2. 一张卡 vs 一张完全不同的卡
  138. IMG_1 = r"/home/martin/ML/RemoteProject/untitled10/uploads/伊布us1.png"
  139. IMG_2 = r"/home/martin/ML/RemoteProject/untitled10/uploads/伊布tc1.png"
  140. # ================= 运行 =================
  141. try:
  142. print("正在加载模型,请稍候...")
  143. # 初始化提取器
  144. extractor = MyViTFeatureExtractor(MODEL_PATH)
  145. # 执行对比
  146. compare_images(extractor, IMG_1, IMG_2)
  147. except Exception as e:
  148. print(f"运行出错: {e}")