import towhee import cv2 from towhee._types.image import Image import os import PIL.Image as Image import numpy as np from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility from MyEfficientNet import MyModel import torch connections.connect(host='127.0.0.1', port='19530') dataset_path = ["D:\Code\ML\images\Mywork3\card_database_yolo/*/*/*/*"] img_id = 0 yolo_num = 0 vec_num = 0 myModel = MyModel(r"D:\Code\ML\model\card_cls\effcient_card_out854_freeze2.pth", out_features=854) # yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom', # path="yolov5s.pt", source='local') # yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s") # 生成ID def get_id(param): global img_id img_id += 1 return img_id # def eff_enbedding(img): # global vec_num # vec_num += 1 # print('vec: ', vec_num) # return myModel.predict(img) # 生成向量 def img2vec(img): global vec_num vec_num += 1 print('vec: ', vec_num) return myModel.predict(img) # 生成信息 path_num = 0 def get_info(path): path = os.path.split(path)[0] path, num_and_player = os.path.split(path) num = num_and_player.split(' ')[0] player = ' '.join(os.path.split(num_and_player)[-1].split(' ')[1:]) path, year = os.path.split(path) series = os.path.split(path)[1] rtn = "{} {} {} #{}".format(series, year, player, num) global path_num path_num += 1 print(path_num, " loading " + rtn) return rtn def read_imgID(results): imgIDs = [] for re in results: # 输出结果图片信息 print('---------', re) imgIDs.append(re.id) return imgIDs # def yolo_detect(img): # results = yolo_model(img) # # pred = results.pred[0][:, :4].cpu().numpy() # boxes = pred.astype(np.int32) # # max_img = get_object(img, boxes) # # global yolo_num # yolo_num += 1 # print("yolo_num: ", yolo_num) # return max_img # # # def get_object(img, boxes): # if isinstance(img, str): # img = Image.open(img) # # if len(boxes) == 0: # return img # # max_area = 0 # # # 选出最大的框 # x1, y1, x2, y2 = 0, 0, 0, 0 # for box in boxes: # temp_x1, temp_y1, temp_x2, temp_y2 = box # area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1) # if area > max_area: # max_area = area # x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2 # # max_img = img.crop((x1, y1, x2, y2)) # return max_img # 创建向量数据库 def create_milvus_collection(collection_name, dim): if utility.has_collection(collection_name): utility.drop_collection(collection_name) fields = [ FieldSchema(name='img_id', dtype=DataType.INT64, is_primary=True), FieldSchema(name='path', dtype=DataType.VARCHAR, max_length=300), FieldSchema(name="info", dtype=DataType.VARCHAR, max_length=300), FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='image embedding vectors', dim=dim) ] schema = CollectionSchema(fields=fields, description='reverse image search') collection = Collection(name=collection_name, schema=schema) index_params = { 'metric_type': 'L2', 'index_type': "IVF_FLAT", 'params': {"nlist": dim} } collection.create_index(field_name="embedding", index_params=index_params) return collection # 判断是否加载已有数据库,或新创建数据库 def is_creat_collection(have_coll, collection_name): if have_coll: # 连接现有的数据库 collection = Collection(name=collection_name) else: # 新建立数据库 collection = create_milvus_collection(collection_name, 2560) dc = ( towhee.glob['path'](*dataset_path) .runas_op['path', 'img_id'](func=get_id) .runas_op['path', 'info'](func=get_info) # .image_decode['path', 'img']() # .runas_op['path', "object"](yolo_detect) .runas_op['path', 'vec'](func=img2vec) .tensor_normalize['vec', 'vec']() # .image_embedding.timm['img', 'vec'](model_name='resnet50') .ann_insert.milvus[('img_id', 'path', 'info', 'vec'), 'mr'](collection=collection) ) print('Total number of inserted data is {}.'.format(collection.num_entities)) return collection # 通过ID查询 def query_by_imgID(collection, img_id, limit=1): expr = 'img_id == ' + str(img_id) res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=2) return res # 分别返回 编号,年份,系列 def from_path_get_info(path): card_info = [] for i in range(3): path = os.path.split(path)[0] card_info.append(os.path.split(path)[-1]) card_info[0] = card_info[0].split('#')[-1] return card_info def from_query_path_get_info(path): card_info = [] for i in range(3): path = os.path.split(path)[0] card_info.append(os.path.split(path)[-1]) card_info[0] = card_info[0].split(' ')[0] return card_info if __name__ == '__main__': print('start') # 是否存在数据库 have_coll = True # 默认模型 # collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search") # 自定义模型 collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel") # 测试的图片路径 img_path = ["D:/Code/ML/images/test02/test(mosaic,pz)/*/*/*/*"] data = (towhee.glob['path'](*img_path) # image_decode['path', 'img'](). # .runas_op['path', "object"](yolo_detect) .runas_op['path', 'vec'](func=img2vec) .tensor_normalize['vec', 'vec']() # image_embedding.timm['img', 'vec'](model_name='resnet50'). .ann_search.milvus['vec', 'result'](collection=collection, limit=3) .runas_op['result', 'result_imgID'](func=read_imgID) .select['path', 'result_imgID', 'vec']() ) print(data) collection.load() # res = query_by_imgID(collection, data[0].result_imgID[0]) # # print(res[0]) top3_num = 0 top1_num = 0 test_img_num = len(list(data)) # 查询所有测试图片 for i in range(test_img_num): top3_flag = False # 获取图片真正的编号, 年份, 系列 source_code, source_year, source_series = from_path_get_info(data[i].path) # 每个测试图片返回三个最相似的图片ID,一一测试 for j in range(3): res = query_by_imgID(collection, data[i].result_imgID[j]) # 获取预测的图片的编号, 年份, 系列 result_code, result_year, result_series = from_query_path_get_info(res[0]['path']) # 判断top1是否正确 if j == 0 and source_code == result_code and source_year == result_year and source_series == result_series: top1_num += 1 print(top1_num) elif j == 0: print('top_1 错误') # top3中有一个正确的标记为正确 if source_code == result_code and source_year == result_year and source_series == result_series: top3_flag = True print("series: {}, year: {},code: {} === result - series: {}, year: {}, code: {}".format( source_series, source_year, source_code, result_series, result_year, result_code, )) if top3_flag: top3_num += 1 print("====================================") print("测试图片共: ", test_img_num) top1_accuracy = (top1_num / test_img_num) * 100 top3_accuracy = (top3_num / test_img_num) * 100 print("top3 准确率:{} % \n top1 准确率: {} %". format(top3_accuracy, top1_accuracy)) ''' '''