# Milvus_Test_effnet.py
#
# Reverse image search over card images: embeds images with a custom
# EfficientNet model, stores/looks up the vectors in Milvus, and reports
# top-1 / top-3 retrieval accuracy on a test image folder.
import os

import cv2
import numpy as np
import torch
import towhee
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from towhee._types.image import Image
# NOTE: rebinds the name `Image`; kept after the towhee import so the PIL
# module is the final binding, exactly as in the original import order.
import PIL.Image as Image

from MyEfficientNet import MyModel
  10. connections.connect(host='127.0.0.1', port='19530')
  11. dataset_path = ["D:\Code\ML\images\Mywork3\card_database_yolo/*/*/*/*"]
  12. img_id = 0
  13. yolo_num = 0
  14. vec_num = 0
  15. myModel = MyModel(r"D:\Code\ML\model\card_cls\effcient_card_out854_freeze2.pth", out_features=854)
# yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
#                             path="yolov5s.pt", source='local')
# yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
  19. # 生成ID
  20. def get_id(param):
  21. global img_id
  22. img_id += 1
  23. return img_id
# def eff_enbedding(img):
#     global vec_num
#     vec_num += 1
#     print('vec: ', vec_num)
#     return myModel.predict(img)
  29. # 生成向量
  30. def img2vec(img):
  31. global vec_num
  32. vec_num += 1
  33. print('vec: ', vec_num)
  34. return myModel.predict(img)
  35. # 生成信息
  36. path_num = 0
  37. def get_info(path):
  38. path = os.path.split(path)[0]
  39. path, num_and_player = os.path.split(path)
  40. num = num_and_player.split(' ')[0]
  41. player = ' '.join(os.path.split(num_and_player)[-1].split(' ')[1:])
  42. path, year = os.path.split(path)
  43. series = os.path.split(path)[1]
  44. rtn = "{} {} {} #{}".format(series, year, player, num)
  45. global path_num
  46. path_num += 1
  47. print(path_num, " loading " + rtn)
  48. return rtn
  49. def read_imgID(results):
  50. imgIDs = []
  51. for re in results:
  52. # 输出结果图片信息
  53. print('---------', re)
  54. imgIDs.append(re.id)
  55. return imgIDs
# def yolo_detect(img):
#     results = yolo_model(img)
#
#     pred = results.pred[0][:, :4].cpu().numpy()
#     boxes = pred.astype(np.int32)
#
#     max_img = get_object(img, boxes)
#
#     global yolo_num
#     yolo_num += 1
#     print("yolo_num: ", yolo_num)
#     return max_img
#
#
# def get_object(img, boxes):
#     if isinstance(img, str):
#         img = Image.open(img)
#
#     if len(boxes) == 0:
#         return img
#
#     max_area = 0
#
#     # Pick the largest box
#     x1, y1, x2, y2 = 0, 0, 0, 0
#     for box in boxes:
#         temp_x1, temp_y1, temp_x2, temp_y2 = box
#         area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
#         if area > max_area:
#             max_area = area
#             x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
#
#     max_img = img.crop((x1, y1, x2, y2))
#     return max_img
  90. # 创建向量数据库
  91. def create_milvus_collection(collection_name, dim):
  92. if utility.has_collection(collection_name):
  93. utility.drop_collection(collection_name)
  94. fields = [
  95. FieldSchema(name='img_id', dtype=DataType.INT64, is_primary=True),
  96. FieldSchema(name='path', dtype=DataType.VARCHAR, max_length=300),
  97. FieldSchema(name="info", dtype=DataType.VARCHAR, max_length=300),
  98. FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='image embedding vectors', dim=dim)
  99. ]
  100. schema = CollectionSchema(fields=fields, description='reverse image search')
  101. collection = Collection(name=collection_name, schema=schema)
  102. index_params = {
  103. 'metric_type': 'L2',
  104. 'index_type': "IVF_FLAT",
  105. 'params': {"nlist": dim}
  106. }
  107. collection.create_index(field_name="embedding", index_params=index_params)
  108. return collection
  109. # 判断是否加载已有数据库,或新创建数据库
  110. def is_creat_collection(have_coll, collection_name):
  111. if have_coll:
  112. # 连接现有的数据库
  113. collection = Collection(name=collection_name)
  114. else:
  115. # 新建立数据库
  116. collection = create_milvus_collection(collection_name, 2560)
  117. dc = (
  118. towhee.glob['path'](*dataset_path)
  119. .runas_op['path', 'img_id'](func=get_id)
  120. .runas_op['path', 'info'](func=get_info)
  121. # .image_decode['path', 'img']()
  122. # .runas_op['path', "object"](yolo_detect)
  123. .runas_op['path', 'vec'](func=img2vec)
  124. .tensor_normalize['vec', 'vec']()
  125. # .image_embedding.timm['img', 'vec'](model_name='resnet50')
  126. .ann_insert.milvus[('img_id', 'path', 'info', 'vec'), 'mr'](collection=collection)
  127. )
  128. print('Total number of inserted data is {}.'.format(collection.num_entities))
  129. return collection
  130. # 通过ID查询
  131. def query_by_imgID(collection, img_id, limit=1):
  132. expr = 'img_id == ' + str(img_id)
  133. res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=2)
  134. return res
  135. # 分别返回 编号,年份,系列
  136. def from_path_get_info(path):
  137. card_info = []
  138. for i in range(3):
  139. path = os.path.split(path)[0]
  140. card_info.append(os.path.split(path)[-1])
  141. card_info[0] = card_info[0].split('#')[-1]
  142. return card_info
  143. def from_query_path_get_info(path):
  144. card_info = []
  145. for i in range(3):
  146. path = os.path.split(path)[0]
  147. card_info.append(os.path.split(path)[-1])
  148. card_info[0] = card_info[0].split(' ')[0]
  149. return card_info
  150. if __name__ == '__main__':
  151. print('start')
  152. # 是否存在数据库
  153. have_coll = True
  154. # 默认模型
  155. # collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search")
  156. # 自定义模型
  157. collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel")
  158. # 测试的图片路径
  159. img_path = ["D:/Code/ML/images/test02/test(mosaic,pz)/*/*/*/*"]
  160. data = (towhee.glob['path'](*img_path)
  161. # image_decode['path', 'img']().
  162. # .runas_op['path', "object"](yolo_detect)
  163. .runas_op['path', 'vec'](func=img2vec)
  164. .tensor_normalize['vec', 'vec']()
  165. # image_embedding.timm['img', 'vec'](model_name='resnet50').
  166. .ann_search.milvus['vec', 'result'](collection=collection, limit=3)
  167. .runas_op['result', 'result_imgID'](func=read_imgID)
  168. .select['path', 'result_imgID', 'vec']()
  169. )
  170. print(data)
  171. collection.load()
  172. # res = query_by_imgID(collection, data[0].result_imgID[0])
  173. #
  174. # print(res[0])
  175. top3_num = 0
  176. top1_num = 0
  177. test_img_num = len(list(data))
  178. # 查询所有测试图片
  179. for i in range(test_img_num):
  180. top3_flag = False
  181. # 获取图片真正的编号, 年份, 系列
  182. source_code, source_year, source_series = from_path_get_info(data[i].path)
  183. # 每个测试图片返回三个最相似的图片ID,一一测试
  184. for j in range(3):
  185. res = query_by_imgID(collection, data[i].result_imgID[j])
  186. # 获取预测的图片的编号, 年份, 系列
  187. result_code, result_year, result_series = from_query_path_get_info(res[0]['path'])
  188. # 判断top1是否正确
  189. if j == 0 and source_code == result_code and source_year == result_year and source_series == result_series:
  190. top1_num += 1
  191. print(top1_num)
  192. elif j == 0:
  193. print('top_1 错误')
  194. # top3中有一个正确的标记为正确
  195. if source_code == result_code and source_year == result_year and source_series == result_series:
  196. top3_flag = True
  197. print("series: {}, year: {},code: {} === result - series: {}, year: {}, code: {}".format(
  198. source_series, source_year, source_code, result_series, result_year, result_code,
  199. ))
  200. if top3_flag:
  201. top3_num += 1
  202. print("====================================")
  203. print("测试图片共: ", test_img_num)
  204. top1_accuracy = (top1_num / test_img_num) * 100
  205. top3_accuracy = (top3_num / test_img_num) * 100
  206. print("top3 准确率:{} % \n top1 准确率: {} %".
  207. format(top3_accuracy, top1_accuracy))
  208. '''
  209. '''