test03.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import towhee
  2. import cv2
  3. from towhee._types.image import Image
  4. import os
  5. import PIL.Image as Image
  6. import numpy as np
  7. from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
  8. from MyModel import MyModel
  9. from MyEfficientNet import MyEfficient
  10. import torch
  11. from transformers import ViTFeatureExtractor, ViTModel
  12. from towhee.types.image_utils import to_image_color
# Connect to the local Milvus server; the collection/utility calls in this file rely on it.
connections.connect(host='127.0.0.1', port='19530')
# Glob patterns of the card images to ingest. get_info() reads the three
# directories above each matched file as <series>/<year>/<num> <player>.
dataset_path = ["D:/Code/ML/images/Mywork3/card_database/prizm/21-22/*/*.JPG",
                "D:/Code/ML/images/Mywork3/card_database/mosaic/*/*/*.JPG"]
# Module-level counters: img_id is advanced by get_id() (so primary keys start
# at 1); vec_num is advanced by img2vec() and only used for progress output.
img_id = 0
vec_num = 0
# Embedding model: img2vec() returns myModel.predict(img) as the image vector.
# NOTE(review): the Milvus collection is created with dim=2048, so predict()
# presumably returns 2048-d features rather than an out_features=764 output --
# confirm in MyModel.
myModel = MyModel(r"D:\Code\ML\model\card_cls\res_card_out764_freeze4.pth", out_features=764)
# Alternative embedding models:
# myModel = MyModel(r"C:\Users\Administrator\.cache\torch\hub\checkpoints\resnet50-0676ba61.pth", out_features=1000)
# myModel = MyEfficient('')
# YOLOv5 detector loaded from a local torch.hub checkout (source='local');
# yolo_detect() uses it to crop the largest detected object before embedding.
yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
                            path="yolov5s.pt", source='local')
# Remote alternative:
# yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
  24. # 生成ID
  25. def get_id(param):
  26. global img_id
  27. img_id += 1
  28. return img_id
  29. # def eff_enbedding(img):
  30. # global vec_num
  31. # vec_num += 1
  32. # print('vec: ', vec_num)
  33. # return myModel.predict(img)
  34. # 生成向量
  35. def img2vec(img):
  36. global vec_num
  37. vec_num += 1
  38. print('vec: ', vec_num)
  39. return myModel.predict(img)
  40. # 生成信息
  41. path_num = 0
  42. def get_info(path):
  43. path = os.path.split(path)[0]
  44. path, num_and_player = os.path.split(path)
  45. num = num_and_player.split(' ')[0]
  46. player = ' '.join(os.path.split(num_and_player)[-1].split(' ')[1:])
  47. path, year = os.path.split(path)
  48. series = os.path.split(path)[1]
  49. rtn = "{} {} {} #{}".format(series, year, player, num)
  50. global path_num
  51. path_num += 1
  52. print(path_num, " loading " + rtn)
  53. return rtn
  54. def read_imgID(results):
  55. imgIDs = []
  56. for re in results:
  57. # 输出结果图片信息
  58. print('---------', re)
  59. imgIDs.append(re.id)
  60. return imgIDs
  61. def yolo_detect(img):
  62. results = yolo_model(img)
  63. pred = results.pred[0][:, :4].cpu().numpy()
  64. boxes = pred.astype(np.int32)
  65. max_img = get_object(img, boxes)
  66. return max_img
  67. def get_object(img, boxes):
  68. if isinstance(img, str):
  69. img = Image.open(img)
  70. if len(boxes) == 0:
  71. return img
  72. max_area = 0
  73. # 选出最大的框
  74. x1, y1, x2, y2 = 0, 0, 0, 0
  75. for box in boxes:
  76. temp_x1, temp_y1, temp_x2, temp_y2 = box
  77. area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
  78. if area > max_area:
  79. max_area = area
  80. x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
  81. max_img = img.crop((x1, y1, x2, y2))
  82. return max_img
  83. # 创建向量数据库
  84. def create_milvus_collection(collection_name, dim):
  85. if utility.has_collection(collection_name):
  86. utility.drop_collection(collection_name)
  87. fields = [
  88. FieldSchema(name='img_id', dtype=DataType.INT64, is_primary=True),
  89. FieldSchema(name='path', dtype=DataType.VARCHAR, max_length=300),
  90. FieldSchema(name="info", dtype=DataType.VARCHAR, max_length=300),
  91. FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='image embedding vectors', dim=dim)
  92. ]
  93. schema = CollectionSchema(fields=fields, description='reverse image search')
  94. collection = Collection(name=collection_name, schema=schema)
  95. index_params = {
  96. 'metric_type': 'L2',
  97. 'index_type': "IVF_FLAT",
  98. 'params': {"nlist": dim}
  99. }
  100. collection.create_index(field_name="embedding", index_params=index_params)
  101. return collection
def is_creat_collection(have_coll, collection_name):
    """Return the Milvus collection ``collection_name``, building it if requested.

    Args:
        have_coll: True to attach to an already-existing collection; False to
            (re)create it via ``create_milvus_collection`` -- which drops any
            existing collection of the same name -- and ingest every image
            matched by ``dataset_path``.
        collection_name: name of the Milvus collection.

    Returns:
        The ``pymilvus.Collection`` object.
    """
    # NOTE(review): indentation was lost in the copy this was restored from; the
    # ingestion pipeline and the count print are assumed to belong to the
    # ``else`` branch (only a newly created collection is filled). Confirm
    # against the original file.
    if have_coll:
        # Attach to the existing collection.
        collection = Collection(name=collection_name)
    else:
        # Create a fresh collection whose embedding field has 2048 components.
        collection = create_milvus_collection(collection_name, 2048)
        dc = (
            towhee.glob['path'](*dataset_path)
            # Sequential primary key per image.
            .runas_op['path', 'img_id'](func=get_id)
            # "<series> <year> <player> #<num>" label parsed from the path.
            .runas_op['path', 'info'](func=get_info)
            # Alternative: decode the whole image instead of cropping with YOLO.
            # .image_decode['path', 'img']()
            # Crop of the largest YOLO detection.
            .runas_op['path', "object"](yolo_detect)
            # Embedding from myModel.
            .runas_op['object', 'vec'](func=img2vec)
            # Presumably L2-normalises each vector (towhee tensor_normalize op) -- confirm.
            .tensor_normalize['vec', 'vec']()
            # Alternative: timm resnet50 embeddings of the decoded image.
            # .image_embedding.timm['img', 'vec'](model_name='resnet50')
            # Insert (img_id, path, info, vec) rows; 'mr' holds the insert result.
            .ann_insert.milvus[('img_id', 'path', 'info', 'vec'), 'mr'](collection=collection)
        )
        print('Total number of inserted data is {}.'.format(collection.num_entities))
    return collection
  123. # 通过ID查询
  124. def query_by_imgID(collection, img_id, limit=1):
  125. expr = 'img_id == ' + str(img_id)
  126. res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=2)
  127. return res
  128. def from_path_get_series(path):
  129. for i in range(3):
  130. path = os.path.split(path)[0]
  131. series = os.path.split(path)[-1]
  132. return series
  133. if __name__ == '__main__':
  134. print('start')
  135. # 是否存在数据库
  136. have_coll = True
  137. # 默认模型
  138. # collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search")
  139. # 自定义模型
  140. collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel")
  141. # 测试的图片路径
  142. img_path = ["D:/Code/ML/images/test02/test2/prizm/base 21-22/*/*.jpg",
  143. "D:/Code/ML/images/test02/test2/prizm/base 21-22/*/*.jpeg",
  144. "D:/Code/ML/images/test02/test2/prizm/base 21-22/*/*.png",
  145. "D:/Code/ML/images/test02/test2/mosaic/20-21/*/*.jpg"]
  146. data = (towhee.glob['path'](*img_path)
  147. # image_decode['path', 'img']().
  148. .runas_op['path', "object"](yolo_detect)
  149. .runas_op['object', 'vec'](func=img2vec)
  150. .tensor_normalize['vec', 'vec']()
  151. # image_embedding.timm['img', 'vec'](model_name='resnet50').
  152. .ann_search.milvus['vec', 'result'](collection=collection, limit=3)
  153. .runas_op['result', 'result_imgID'](func=read_imgID)
  154. .select['path', 'result_imgID', 'vec']()
  155. )
  156. print(data)
  157. collection.load()
  158. # res = query_by_imgID(collection, data[0].result_imgID[0])
  159. #
  160. # print(res[0])
  161. top3_num = 0
  162. top1_num = 0
  163. test_img_num = len(list(data))
  164. # 查询所有测试图片
  165. for i in range(test_img_num):
  166. top3_flag = False
  167. # 获取图片真正的系列
  168. source_card_series = from_path_get_series(data[i].path)
  169. # 获取图片真正的编号
  170. source_num = os.path.split(os.path.split(data[i].path)[0])[-1].split('#')[-1]
  171. # 每个测试图片返回三个最相似的图片ID,一一测试
  172. for j in range(3):
  173. res = query_by_imgID(collection, data[i].result_imgID[j])
  174. # 获取预测的图片的系列
  175. result_card_series = from_path_get_series(res[0]['path'])
  176. # 获取预测的图片的编号
  177. result_num = os.path.split(os.path.split(res[0]['path'])[0])[-1].split('#')[-1]
  178. # 判断top1是否正确
  179. if j == 0 and source_num == result_num and source_card_series == result_card_series:
  180. top1_num += 1
  181. # top3中有一个正确的标记为正确
  182. if source_num == result_num and source_card_series == result_card_series:
  183. top3_flag = True
  184. # 日志
  185. if j == 0 and source_num == result_num and source_card_series == result_card_series:
  186. print(top1_num)
  187. elif j == 0:
  188. print('top_1 错误')
  189. print("series: {}, num: {} === result - series: {}, num: {}".format(
  190. source_card_series, source_num, result_card_series, result_num
  191. ))
  192. if top3_flag:
  193. top3_num += 1
  194. print("====================================")
  195. print("测试图片共: ", test_img_num)
  196. top1_accuracy = (top1_num / test_img_num) * 100
  197. top3_accuracy = (top3_num / test_img_num) * 100
  198. print("top3 准确率:{} % \n top1 准确率: {} %".
  199. format(top3_accuracy, top1_accuracy))
'''
Recorded evaluation results (168 test images per run):
168 张图片
默认resnet50 + yolo
top3 准确率:100.0 %
top1 准确率: 85.11904761904762 %
168 张图片
自定义resnet50_freeze_out217 + yolo
top3 准确率:94.04761904761905 %
top1 准确率: 93.45238095238095 %
测试图片共: 168
自定义resnet50_freeze_out421 + yolo + normalize
top3 准确率:96.42857142857143 %
top1 准确率: 95.23809523809523 %
测试图片共: 168
自定义resnet50_out764_freeze + yolo + normalize
top3 准确率:95.23809523809523 %
top1 准确率: 94.04761904761905 %
'''