| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
- import towhee
- import cv2
- from towhee._types.image import Image
- import os
- import PIL.Image as Image
- import numpy as np
- from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
- from MyEfficientNet import MyModel
- import torch
# Connect to the local Milvus server used by every collection operation below.
connections.connect(host='127.0.0.1', port='19530')
# Glob pattern for the card image dataset (<series>/<year>/<num> <player>/<image>).
# Raw string: the backslashes are Windows path separators, not escape sequences.
dataset_path = [r"D:\Code\ML\images\Mywork3\card_database_yolo/*/*/*/*"]
# Module-level counters: img_id backs get_id (row primary keys), vec_num counts
# embeddings in img2vec, yolo_num is only used by the commented-out YOLO helper.
img_id = 0
yolo_num = 0
vec_num = 0
# Custom EfficientNet model (out_features=854) used as the image embedder.
myModel = MyModel(r"D:\Code\ML\model\card_cls\effcient_card_out854_freeze2.pth", out_features=854)
- # yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
- # path="yolov5s.pt", source='local')
- # yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
# Primary-key generator for rows inserted into Milvus.
def get_id(param):
    """Advance the module-level ``img_id`` counter and return its new value.

    ``param`` receives the image path from the towhee pipeline and is not
    used; ids are sequential, starting from 1 when ``img_id`` starts at 0.
    """
    global img_id
    next_id = img_id + 1
    img_id = next_id
    return next_id
- # def eff_enbedding(img):
- # global vec_num
- # vec_num += 1
- # print('vec: ', vec_num)
- # return myModel.predict(img)
# Embed an image with the custom model.
def img2vec(img):
    """Return ``myModel.predict(img)`` and print a running embedding count.

    The towhee pipelines in this file feed this function the ``path`` column,
    so ``img`` is presumably an image file path -- confirm against
    ``MyModel.predict``.
    """
    global vec_num
    vec_num = vec_num + 1
    print('vec: ', vec_num)
    embedding = myModel.predict(img)
    return embedding
# Counter of dataset images labelled so far; only used for progress output.
path_num = 0


# Build the text label ("info" field) stored for a dataset image.
def get_info(path):
    """Return the label ``"<series> <year> <player> #<num>"`` for a dataset image.

    ``path`` is decomposed as ``.../<series>/<year>/<num> <player>/<image>``:
    the card directory name is split on single spaces, its first token is the
    card number and the remaining tokens (re-joined with spaces) are the
    player. A running progress line is printed as a side effect.
    """
    global path_num
    card_dir = os.path.dirname(path)
    year_dir, num_and_player = os.path.split(card_dir)
    num, *player_words = num_and_player.split(' ')
    player = ' '.join(player_words)
    series_dir, year = os.path.split(year_dir)
    series = os.path.basename(series_dir)
    label = f"{series} {year} {player} #{num}"
    path_num += 1
    print(path_num, " loading " + label)
    return label
def read_imgID(results):
    """Return the ``id`` of every Milvus search hit in ``results``, in order.

    Each hit is printed before its id is taken (debug output).
    """
    def _log_and_get_id(hit):
        # Print the raw hit so the search result can be inspected.
        print('---------', hit)
        return hit.id

    return [_log_and_get_id(hit) for hit in results]
- # def yolo_detect(img):
- # results = yolo_model(img)
- #
- # pred = results.pred[0][:, :4].cpu().numpy()
- # boxes = pred.astype(np.int32)
- #
- # max_img = get_object(img, boxes)
- #
- # global yolo_num
- # yolo_num += 1
- # print("yolo_num: ", yolo_num)
- # return max_img
- #
- #
- # def get_object(img, boxes):
- # if isinstance(img, str):
- # img = Image.open(img)
- #
- # if len(boxes) == 0:
- # return img
- #
- # max_area = 0
- #
- # # 选出最大的框
- # x1, y1, x2, y2 = 0, 0, 0, 0
- # for box in boxes:
- # temp_x1, temp_y1, temp_x2, temp_y2 = box
- # area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
- # if area > max_area:
- # max_area = area
- # x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
- #
- # max_img = img.crop((x1, y1, x2, y2))
- # return max_img
# Create the vector database (Milvus collection).
def create_milvus_collection(collection_name, dim, nlist=None):
    """Create (or re-create) a Milvus collection for image embeddings.

    Any existing collection with the same name is dropped first, so its data
    is lost.

    Schema: ``img_id`` INT64 primary key, ``path`` and ``info`` VARCHAR(300),
    and ``embedding`` FLOAT_VECTOR of dimension ``dim``. An IVF_FLAT index
    with L2 distance is built on ``embedding``.

    Args:
        collection_name: name of the collection to create.
        dim: dimensionality of the embedding vectors.
        nlist: number of IVF cluster units for the index. Defaults to ``dim``
            to keep the previous behaviour; nlist is not tied to the vector
            dimension, so pass a value suited to the dataset size.

    Returns:
        The newly created ``Collection``.
    """
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    if nlist is None:
        nlist = dim
    fields = [
        FieldSchema(name='img_id', dtype=DataType.INT64, is_primary=True),
        FieldSchema(name='path', dtype=DataType.VARCHAR, max_length=300),
        FieldSchema(name="info", dtype=DataType.VARCHAR, max_length=300),
        # `description` was previously misspelled `descrition`, so it was
        # never applied to the field.
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR,
                    description='image embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)
    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": nlist},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection
# Either attach to an existing collection or build a new one from the dataset.
def is_creat_collection(have_coll, collection_name, dim=2560):
    """Return the Milvus collection to search, building it first if requested.

    Args:
        have_coll: when true, attach to the existing collection
            ``collection_name``. When false, (re)create it with
            ``create_milvus_collection`` (dropping any previous data) and
            insert every image matched by the module-level ``dataset_path``.
        collection_name: name of the Milvus collection.
        dim: embedding dimension used when the collection is created.
            Defaults to 2560, the value previously hard-coded here; ignored
            when ``have_coll`` is true.

    Returns:
        The ``Collection`` object.
    """
    if have_coll:
        # Connect to the existing collection.
        collection = Collection(name=collection_name)
    else:
        # Build a new collection and fill it. For each dataset image the
        # pipeline assigns a sequential id (get_id), a text label (get_info)
        # and an embedding (img2vec, then tensor_normalize), and inserts the
        # (img_id, path, info, vec) columns via ann_insert.milvus.
        collection = create_milvus_collection(collection_name, dim)
        (
            towhee.glob['path'](*dataset_path)
            .runas_op['path', 'img_id'](func=get_id)
            .runas_op['path', 'info'](func=get_info)
            .runas_op['path', 'vec'](func=img2vec)
            .tensor_normalize['vec', 'vec']()
            .ann_insert.milvus[('img_id', 'path', 'info', 'vec'), 'mr'](collection=collection)
        )
    # NOTE(review): on newer pymilvus releases num_entities may not reflect
    # freshly inserted rows until the collection is flushed -- confirm.
    print('Total number of inserted data is {}.'.format(collection.num_entities))
    return collection
# Look up stored rows by primary key.
def query_by_imgID(collection, img_id, limit=1):
    """Query ``collection`` for rows whose ``img_id`` equals ``img_id``.

    Returns whatever ``collection.query`` returns for the filter
    ``img_id == <img_id>``, with the ``path`` and ``info`` output fields,
    offset 0, at most ``limit`` rows and ``timeout=2``.
    """
    return collection.query(
        f'img_id == {img_id!s}',
        output_fields=["path", "info"],
        offset=0,
        limit=limit,
        timeout=2,
    )
# Ground truth for a test image: returns [number, year, series].
def from_path_get_info(path):
    """Return ``[number, year, series]`` parsed from a test-image path.

    The three directories above the image file are read innermost first.
    The card number is whatever follows the last ``#`` in the innermost
    directory name (the whole name if it contains no ``#``).
    """
    card_dir = os.path.dirname(path)
    year_dir = os.path.dirname(card_dir)
    series_dir = os.path.dirname(year_dir)
    number = os.path.basename(card_dir).split('#')[-1]
    return [number, os.path.basename(year_dir), os.path.basename(series_dir)]
def from_query_path_get_info(path):
    """Return ``[number, year, series]`` parsed from a dataset-image path.

    The three directories above the image file are read innermost first;
    the card number is the text before the first space of the innermost
    directory name (e.g. ``"12 Mike Trout"`` -> ``"12"``).
    """
    folders = []
    current = path
    while len(folders) < 3:
        current = os.path.dirname(current)
        folders.append(os.path.basename(current))
    card_dir, year, series = folders
    return [card_dir.split(' ')[0], year, series]
if __name__ == '__main__':
    print('start')
    # Whether the collection already exists: True attaches to it, False
    # re-creates it and inserts the whole dataset.
    have_coll = True
    # Collection for the default model:
    # collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search")
    # Collection for the custom model:
    collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel")
    # Milvus runs searches and queries only against a loaded collection, so
    # load it before the search pipeline below.
    collection.load()
    # Number of nearest neighbours retrieved per test image.
    top_k = 3
    # Glob pattern for the test images.
    img_path = ["D:/Code/ML/images/test02/test(mosaic,pz)/*/*/*/*"]
    # Embed each test image, search Milvus for its nearest neighbours and keep
    # the ids of the hits.
    data = (
        towhee.glob['path'](*img_path)
        .runas_op['path', 'vec'](func=img2vec)
        .tensor_normalize['vec', 'vec']()
        .ann_search.milvus['vec', 'result'](collection=collection, limit=top_k)
        .runas_op['result', 'result_imgID'](func=read_imgID)
        .select['path', 'result_imgID', 'vec']()
    )
    print(data)

    top3_num = 0
    top1_num = 0
    # Materialize the results once and reuse them for counting and scoring.
    records = list(data)
    test_img_num = len(records)
    # Score every test image.
    for record in records:
        top3_flag = False
        # Ground-truth (number, year, series) taken from the test image's path.
        source_info = from_path_get_info(record.path)
        source_code, source_year, source_series = source_info
        # Compare against each returned neighbour; fewer than top_k hits can
        # come back, so iterate over what was actually returned.
        for rank, hit_id in enumerate(record.result_imgID[:top_k]):
            res = query_by_imgID(collection, hit_id)
            if res:
                # Predicted (number, year, series) from the matched dataset image.
                result_code, result_year, result_series = from_query_path_get_info(res[0]['path'])
                is_match = [result_code, result_year, result_series] == source_info
            else:
                # The hit's id could not be looked up; count it as a miss.
                is_match = False
            # Top-1 accuracy only considers the best hit.
            if rank == 0:
                if is_match:
                    top1_num += 1
                    print(top1_num)
                else:
                    print('top_1 错误')
            # One correct hit anywhere in the top results marks the image as a top-3 hit.
            if is_match:
                top3_flag = True
                print("series: {}, year: {},code: {} === result - series: {}, year: {}, code: {}".format(
                    source_series, source_year, source_code, result_series, result_year, result_code,
                ))
        if top3_flag:
            top3_num += 1
    print("====================================")
    print("测试图片共: ", test_img_num)
    if test_img_num == 0:
        # Nothing matched the glob; avoid dividing by zero.
        print("No test images found for {}".format(img_path))
    else:
        top1_accuracy = (top1_num / test_img_num) * 100
        top3_accuracy = (top3_num / test_img_num) * 100
        print("top3 准确率:{} % \n top1 准确率: {} %".
              format(top3_accuracy, top1_accuracy))
|