AnlaAnla 2 years ago
commit afd8c94c61
17 changed files with 907 additions and 0 deletions
  1. .idea/.gitignore (+8 -0)
  2. .idea/inspectionProfiles/profiles_settings.xml (+6 -0)
  3. .idea/misc.xml (+4 -0)
  4. .idea/modules.xml (+8 -0)
  5. .idea/other.xml (+7 -0)
  6. .idea/towhee_test.iml (+11 -0)
  7. .idea/vcs.xml (+6 -0)
  8. Image2YoyoImage.py (+78 -0)
  9. Milvus_Test.py (+279 -0)
  10. MyEfficientNet.py (+64 -0)
  11. MyModel.py (+67 -0)
  12. ResnetTest.py (+9 -0)
  13. test01.py (+10 -0)
  14. test02.py (+0 -0)
  15. test03.py (+285 -0)
  16. test04.py (+7 -0)
  17. yolo_object.py (+58 -0)

+ 8 - 0
.idea/.gitignore

@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 6 - 0
.idea/inspectionProfiles/profiles_settings.xml

@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

+ 4 - 0
.idea/misc.xml

@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="pytorch" project-jdk-type="Python SDK" />
+</project>

+ 8 - 0
.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/towhee_test.iml" filepath="$PROJECT_DIR$/.idea/towhee_test.iml" />
+    </modules>
+  </component>
+</project>

+ 7 - 0
.idea/other.xml

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PySciProjectComponent">
+    <option name="PY_SCI_VIEW" value="true" />
+    <option name="PY_SCI_VIEW_SUGGESTED" value="true" />
+  </component>
+</project>

+ 11 - 0
.idea/towhee_test.iml

@@ -0,0 +1,11 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="pytorch" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="renderExternalDocumentation" value="true" />
+  </component>
+</module>

+ 6 - 0
.idea/vcs.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>

+ 78 - 0
Image2YoyoImage.py

@@ -0,0 +1,78 @@
+import towhee
+import torch
+import os
+import glob
+from PIL import Image, ImageOps
+from MyModel import MyModel
+import numpy as np
+
+vec_num = 0
+yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
+                            path="yolov5s.pt", source='local')
+
+dataset_path = [r"D:\Code\ML\images\Mywork3\card_database\prizm\21-22\*\*"]
+yolo_dataset_dir = r"D:\Code\ML\images\Mywork3\card_database_yolo"
+
+
+def get_save_dir(save_dir, source_path):
+    path02, path01 = os.path.split(source_path)
+    path03, path02 = os.path.split(path02)
+    path04, path03 = os.path.split(path03)
+    path05, path04 = os.path.split(path04)
+
+    return os.path.join(save_dir, path04, path03, path02, path01)
+
+
+def yolo_detect(img_path):
+    dest_path = get_save_dir(yolo_dataset_dir, img_path)
+    save_dir = os.path.split(dest_path)[0]
+
+    # Skip if the YOLO-cropped image already exists
+    if os.path.exists(dest_path):
+        print("---- already exists: ", dest_path)
+        return
+
+    img = Image.open(img_path)
+    img = ImageOps.exif_transpose(img)
+    results = yolo_model(img)
+
+    pred = results.pred[0][:, :4].cpu().numpy()
+    boxes = pred.astype(np.int32)
+
+    max_img = get_object(img_path, boxes)
+
+
+
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    max_img.save(dest_path)
+    print(dest_path)
+
+
+def get_object(img, boxes):
+    if isinstance(img, str):
+        img = Image.open(img)
+
+    if len(boxes) == 0:
+        return img
+
+    max_area = 0
+
+    # Pick the largest bounding box
+    x1, y1, x2, y2 = 0, 0, 0, 0
+    for box in boxes:
+        temp_x1, temp_y1, temp_x2, temp_y2 = box
+        area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
+        if area > max_area:
+            max_area = area
+            x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
+
+    max_img = img.crop((x1, y1, x2, y2))
+    return max_img
+
+
+data = (towhee.glob['path'](*dataset_path)
+        .runas_op['path', ''](yolo_detect)
+        )
+
+print('end')
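
A quick worked example of the path re-rooting that get_save_dir above performs (a self-contained sketch; the "12 Player Name" folder and "front.jpg" file are hypothetical placeholders, only the prizm/21-22 layout comes from dataset_path):

    import os

    # Hypothetical source path following the series/year/"<num> <player>"/file layout.
    src = r"D:/Code/ML/images/Mywork3/card_database/prizm/21-22/12 Player Name/front.jpg"

    # Peel off the last four components, the same way get_save_dir does with repeated os.path.split:
    head, fname = os.path.split(src)      # fname  = "front.jpg"
    head, card = os.path.split(head)      # card   = "12 Player Name"
    head, year = os.path.split(head)      # year   = "21-22"
    head, series = os.path.split(head)    # series = "prizm"

    dst = os.path.join(r"D:/Code/ML/images/Mywork3/card_database_yolo", series, year, card, fname)
    print(dst)  # .../card_database_yolo/prizm/21-22/12 Player Name/front.jpg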

+ 279 - 0
Milvus_Test.py

@@ -0,0 +1,279 @@
+import towhee
+import cv2
+from towhee._types.image import Image
+import os
+import PIL.Image as Image
+import numpy as np
+from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
+
+from MyModel import MyModel
+from MyEfficientNet import MyEfficient
+import torch
+from transformers import ViTFeatureExtractor, ViTModel
+from towhee.types.image_utils import to_image_color
+
+connections.connect(host='127.0.0.1', port='19530')
+dataset_path = ["D:\Code\ML\images\Mywork3\card_database_yolo/*/*/*/*"]
+
+img_id = 0
+vec_num = 0
+myModel = MyModel(r"D:\Code\ML\model\card_cls\res_card_out764_freeze4.pth", out_features=764)
+# myModel = MyModel(r"C:\Users\Administrator\.cache\torch\hub\checkpoints\resnet50-0676ba61.pth", out_features=1000)
+
+# myModel = MyEfficient('')
+
+
+yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
+                            path="yolov5s.pt", source='local')
+
+
+# yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
+
+
+# Generate a sequential image ID
+def get_id(param):
+    global img_id
+    img_id += 1
+    return img_id
+
+
+# def eff_enbedding(img):
+#     global vec_num
+#     vec_num += 1
+#     print('vec: ', vec_num)
+#     return myModel.predict(img)
+
+# Generate the embedding vector
+def img2vec(img):
+    global vec_num
+    vec_num += 1
+    print('vec: ', vec_num)
+    return myModel.predict(img)
+
+
+# Build the card info string
+path_num = 0
+
+
+def get_info(path):
+    path = os.path.split(path)[0]
+
+    path, num_and_player = os.path.split(path)
+    num = num_and_player.split(' ')[0]
+    player = ' '.join(os.path.split(num_and_player)[-1].split(' ')[1:])
+    path, year = os.path.split(path)
+    series = os.path.split(path)[1]
+    rtn = "{} {} {} #{}".format(series, year, player, num)
+
+    global path_num
+    path_num += 1
+    print(path_num, " loading " + rtn)
+    return rtn
+
+
+def read_imgID(results):
+    imgIDs = []
+    for re in results:
+        # Print the matched result info
+        print('---------', re)
+        imgIDs.append(re.id)
+    return imgIDs
+
+
+def yolo_detect(img):
+    results = yolo_model(img)
+
+    pred = results.pred[0][:, :4].cpu().numpy()
+    boxes = pred.astype(np.int32)
+
+    max_img = get_object(img, boxes)
+    return max_img
+
+
+def get_object(img, boxes):
+    if isinstance(img, str):
+        img = Image.open(img)
+
+    if len(boxes) == 0:
+        return img
+
+    max_area = 0
+
+    # Pick the largest bounding box
+    x1, y1, x2, y2 = 0, 0, 0, 0
+    for box in boxes:
+        temp_x1, temp_y1, temp_x2, temp_y2 = box
+        area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
+        if area > max_area:
+            max_area = area
+            x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
+
+    max_img = img.crop((x1, y1, x2, y2))
+    return max_img
+
+
+# Create the Milvus collection (vector database)
+def create_milvus_collection(collection_name, dim):
+    if utility.has_collection(collection_name):
+        utility.drop_collection(collection_name)
+
+    fields = [
+        FieldSchema(name='img_id', dtype=DataType.INT64, is_primary=True),
+        FieldSchema(name='path', dtype=DataType.VARCHAR, max_length=300),
+        FieldSchema(name="info", dtype=DataType.VARCHAR, max_length=300),
+        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding vectors', dim=dim)
+    ]
+    schema = CollectionSchema(fields=fields, description='reverse image search')
+    collection = Collection(name=collection_name, schema=schema)
+
+    index_params = {
+        'metric_type': 'L2',
+        'index_type': "IVF_FLAT",
+        'params': {"nlist": dim}
+    }
+    collection.create_index(field_name="embedding", index_params=index_params)
+    return collection
+
+
+# Load an existing collection, or create and populate a new one
+def is_creat_collection(have_coll, collection_name):
+    if have_coll:
+        # Connect to the existing collection
+        collection = Collection(name=collection_name)
+    else:
+        # Create a new collection
+        collection = create_milvus_collection(collection_name, 2048)
+        dc = (
+            towhee.glob['path'](*dataset_path)
+            .runas_op['path', 'img_id'](func=get_id)
+            .runas_op['path', 'info'](func=get_info)
+            # .image_decode['path', 'img']()
+            # .runas_op['path', "object"](yolo_detect)
+            .runas_op['path', 'vec'](func=img2vec)
+            .tensor_normalize['vec', 'vec']()
+            # .image_embedding.timm['img', 'vec'](model_name='resnet50')
+            .ann_insert.milvus[('img_id', 'path', 'info', 'vec'), 'mr'](collection=collection)
+        )
+
+    print('Total number of inserted data is {}.'.format(collection.num_entities))
+    return collection
+
+
+# Query by image ID
+def query_by_imgID(collection, img_id, limit=1):
+    expr = 'img_id == ' + str(img_id)
+    res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=2)
+    return res
+
+
+def from_path_get_series(path):
+    for i in range(3):
+        path = os.path.split(path)[0]
+    series = os.path.split(path)[-1]
+
+    return series
+
+
+if __name__ == '__main__':
+    print('start')
+
+    # Whether the collection already exists
+    have_coll = False
+
+    # Default model
+    # collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search")
+    # Custom model
+    collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel")
+
+    # Test image paths
+    img_path = ["D:/Code/ML/images/test02/test2/*/*/*/*"]
+
+    data = (towhee.glob['path'](*img_path)
+            # image_decode['path', 'img']().
+            .runas_op['path', "object"](yolo_detect)
+            .runas_op['object', 'vec'](func=img2vec)
+            .tensor_normalize['vec', 'vec']()
+            # image_embedding.timm['img', 'vec'](model_name='resnet50').
+            .ann_search.milvus['vec', 'result'](collection=collection, limit=3)
+            .runas_op['result', 'result_imgID'](func=read_imgID)
+            .select['path', 'result_imgID', 'vec']()
+            )
+
+    print(data)
+
+    collection.load()
+    # res = query_by_imgID(collection, data[0].result_imgID[0])
+    #
+    # print(res[0])
+
+
+    top3_num = 0
+    top1_num = 0
+    test_img_num = len(list(data))
+
+    # Evaluate every test image
+    for i in range(test_img_num):
+        top3_flag = False
+
+        # Ground-truth series of the image
+        source_card_series = from_path_get_series(data[i].path)
+        # Ground-truth card number of the image
+        source_num = os.path.split(os.path.split(data[i].path)[0])[-1].split('#')[-1]
+
+        # Each test image returns its three most similar image IDs; check each one
+        for j in range(3):
+            res = query_by_imgID(collection, data[i].result_imgID[j])
+
+            # Series of the predicted image
+            result_card_series = from_path_get_series(res[0]['path'])
+            # Card number of the predicted image
+            result_num = os.path.split(os.path.split(res[0]['path'])[0])[-1].split(' ')[0].split('#')[-1]
+
+            # Check whether top-1 is correct
+            if j == 0 and source_num == result_num and source_card_series == result_card_series:
+                top1_num += 1
+
+            # Mark correct if any of the top-3 results is correct
+            if source_num == result_num and source_card_series == result_card_series:
+                top3_flag = True
+
+            # Logging
+            if j == 0 and source_num == result_num and source_card_series == result_card_series:
+                print(top1_num)
+            elif j == 0:
+                print('top_1 wrong')
+            print("series: {}, num: {} === result - series: {}, num: {}".format(
+                source_card_series, source_num, result_card_series, result_num
+            ))
+
+        if top3_flag:
+            top3_num += 1
+
+        print("====================================")
+
+    print("测试图片共: ", test_img_num)
+    top1_accuracy = (top1_num / test_img_num) * 100
+    top3_accuracy = (top3_num / test_img_num) * 100
+
+    print("top3 准确率:{} % \n top1 准确率: {} %".
+          format(top3_accuracy, top1_accuracy))
+
+'''
+ Total test images: 168
+ custom resnet50_freeze_out421 + yolo + normalize
+top3 accuracy: 96.42857142857143 %
+ top1 accuracy: 95.23809523809523 %
+
+
+Test images: 773, database images: 5848
+custom resnet50_freeze_out421 + yolo + normalize
+Total test images: 773
+top3 accuracy: 96.63648124191462 %
+ top1 accuracy: 95.60155239327295 %
+
+
+ Test images: 773, database images: 5848
+ custom resnet50_out764_freeze + yolo + normalize
+top3 accuracy: 96.76584734799482 %
+ top1 accuracy: 96.50711513583441 %
+'''
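
One note on the pipeline above: because .tensor_normalize is applied before inserting into and searching the IVF_FLAT index with metric_type 'L2', ranking by L2 distance is the same as ranking by cosine similarity. A tiny self-contained check of that identity:

    import numpy as np

    # For unit vectors, ||a - b||^2 = 2 - 2 * cos(a, b), so the L2 ordering equals the cosine ordering.
    a = np.random.rand(2048).astype(np.float32)
    b = np.random.rand(2048).astype(np.float32)
    a /= np.linalg.norm(a)
    b /= np.linalg.norm(b)
    print(np.sum((a - b) ** 2), 2 - 2 * np.dot(a, b))  # the two numbers agree up to float error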

+ 64 - 0
MyEfficientNet.py

@@ -0,0 +1,64 @@
+import torch
+import torchvision.models as models
+import torchvision.transforms as transforms
+import cv2
+from PIL import Image
+import numpy as np
+import timm
+
+
+class MyEfficient:
+    def __init__(self, model_dict_path, out_features=2560):
+        self.out_features = out_features
+        self.norm_mean = [0.485, 0.456, 0.406]
+        self.norm_std = [0.229, 0.224, 0.225]
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.model = models.efficientnet_b7(pretrained=True)
+        # self.model.fc = torch.nn.Linear(in_features=2048, out_features=self.out_features)
+        # self.model.load_state_dict(torch.load(model_dict_path))
+
+        # Customize the model
+        # print(list(self.model.children()))
+        features = list(self.model.children())[:-1]  # drop the final classifier block
+        self.model = torch.nn.Sequential(*features).to(self.device)
+
+        self.model.eval()
+        # self.model.to(self.device)
+
+    def inference_transform(self):
+        inference_transform = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(self.norm_mean, self.norm_std),
+        ])
+        return inference_transform
+
+    def img_transform(self, img_rgb, transform=None):
+        # Convert the data into the form the model expects
+        if transform is None:
+            raise ValueError("transform not found! A transform is required to preprocess img")
+
+        img_t = transform(img_rgb)
+        return img_t
+
+    def get_model(self):
+        return self.model
+
+    # Accepts an image path or an already-loaded image
+    def predict(self, img):
+        if isinstance(img, str):
+            img = Image.open(img).convert('RGB')
+
+        transform = self.inference_transform()
+
+        img_tensor = transform(img)
+        img_tensor.unsqueeze_(0)
+        img_tensor = img_tensor.to(self.device)
+        # print(img.shape)
+
+        with torch.no_grad():
+            outputs = self.model(img_tensor)
+        return outputs.reshape(2560).cpu().numpy()
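
A minimal shape check for the truncated EfficientNet-B7 backbone above (a sketch; it assumes torchvision's efficientnet_b7 is laid out as features / avgpool / classifier, so dropping the last child leaves a 2560-channel pooled output):

    import torch
    import torchvision.models as models

    # Weights do not matter for a shape check, so skip the pretrained download.
    backbone = torch.nn.Sequential(*list(models.efficientnet_b7(pretrained=False).children())[:-1])
    backbone.eval()

    with torch.no_grad():
        out = backbone(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected torch.Size([1, 2560, 1, 1]) -> reshape(2560) in predict()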

+ 67 - 0
MyModel.py

@@ -0,0 +1,67 @@
+import torch
+import torchvision.models as models
+import torchvision.transforms as transforms
+import cv2
+from PIL import Image
+import numpy as np
+import timm
+
+
+class MyModel:
+    def __init__(self, model_dict_path, out_features=2048):
+        self.out_features = out_features
+        self.norm_mean = [0.485, 0.456, 0.406]
+        self.norm_std = [0.229, 0.224, 0.225]
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.model = models.resnet50(pretrained=False)
+        self.model.fc = torch.nn.Linear(in_features=2048, out_features=self.out_features)
+
+        self.model.load_state_dict(torch.load(model_dict_path, map_location=self.device))
+        # self.model = timm.create_model('resnet50', num_classes=2048, pretrained=True)
+        self.model.eval()
+
+        # Customize the model
+        # print(list(self.model.children()))
+        features = list(self.model.children())[:-1]  # drop the fully connected layer
+        self.model = torch.nn.Sequential(*features).to(self.device)
+
+        # self.model.to(self.device)
+
+
+    def inference_transform(self):
+        inference_transform = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(self.norm_mean, self.norm_std),
+        ])
+        return inference_transform
+
+    def img_transform(self, img_rgb, transform=None):
+        # Convert the data into the form the model expects
+        if transform is None:
+            raise ValueError("transform not found! A transform is required to preprocess img")
+
+        img_t = transform(img_rgb)
+        return img_t
+
+    def get_model(self):
+        return self.model
+
+    # Accepts an image path or an already-loaded image
+    def predict(self, img):
+        if isinstance(img, str):
+            img = Image.open(img).convert('RGB')
+        img = img.convert('RGB')
+        transform = self.inference_transform()
+
+        img_tensor = transform(img)
+        img_tensor.unsqueeze_(0)
+        img_tensor = img_tensor.to(self.device)
+        # print(img.shape)
+
+        with torch.no_grad():
+            outputs = self.model(img_tensor)
+        return outputs.reshape(2048).cpu().numpy()
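
A hedged usage sketch of MyModel as an embedding extractor. It assumes the repo's environment (torch, torchvision, timm, cv2, PIL), the checkpoint path and out_features=764 used in Milvus_Test.py above, and a hypothetical "card.jpg" in the working directory:

    from MyModel import MyModel

    # Checkpoint path and out_features come from Milvus_Test.py; "card.jpg" is a placeholder.
    model = MyModel(r"D:\Code\ML\model\card_cls\res_card_out764_freeze4.pth", out_features=764)
    vec = model.predict("card.jpg")   # numpy array from the truncated ResNet-50
    print(vec.shape)                  # (2048,)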

+ 9 - 0
ResnetTest.py

@@ -0,0 +1,9 @@
+import torch
+import torchvision.models as models
+
+model = models.resnet50(pretrained=False)
+model.fc = torch.nn.Linear(in_features=2048, out_features=314)
+
+model.load_state_dict(torch.load(r"D:\Code\ML\model\card_cls\res_card_freeze2.pth"))
+
+print(model)

+ 10 - 0
test01.py

@@ -0,0 +1,10 @@
+import os
+
+dir_path = r"D:\Code\ML\images\Mywork3\card_database_no_compress\prizm\21-22"
+
+for name in os.listdir(dir_path):
+    dir_name = os.path.join(dir_path, name)
+    new_name = os.path.join(dir_path, name.split(' ')[0])
+
+    os.rename(dir_name, new_name)
+    print(name, ' ====>> ', name.split(' ')[0])

+ 0 - 0
test02.py


+ 285 - 0
test03.py

@@ -0,0 +1,285 @@
+import towhee
+import cv2
+from towhee._types.image import Image
+import os
+import PIL.Image as Image
+import numpy as np
+from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
+
+from MyModel import MyModel
+from MyEfficientNet import MyEfficient
+import torch
+from transformers import ViTFeatureExtractor, ViTModel
+from towhee.types.image_utils import to_image_color
+
+connections.connect(host='127.0.0.1', port='19530')
+dataset_path = ["D:/Code/ML/images/Mywork3/card_database/prizm/21-22/*/*.JPG",
+                "D:/Code/ML/images/Mywork3/card_database/mosaic/*/*/*.JPG"]
+
+img_id = 0
+vec_num = 0
+myModel = MyModel(r"D:\Code\ML\model\card_cls\res_card_out764_freeze4.pth", out_features=764)
+# myModel = MyModel(r"C:\Users\Administrator\.cache\torch\hub\checkpoints\resnet50-0676ba61.pth", out_features=1000)
+
+# myModel = MyEfficient('')
+
+
+yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
+                            path="yolov5s.pt", source='local')
+
+
+# yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
+
+
+# Generate a sequential image ID
+def get_id(param):
+    global img_id
+    img_id += 1
+    return img_id
+
+
+# def eff_enbedding(img):
+#     global vec_num
+#     vec_num += 1
+#     print('vec: ', vec_num)
+#     return myModel.predict(img)
+
+# Generate the embedding vector
+def img2vec(img):
+    global vec_num
+    vec_num += 1
+    print('vec: ', vec_num)
+    return myModel.predict(img)
+
+
+# Build the card info string
+path_num = 0
+
+
+def get_info(path):
+    path = os.path.split(path)[0]
+
+    path, num_and_player = os.path.split(path)
+    num = num_and_player.split(' ')[0]
+    player = ' '.join(os.path.split(num_and_player)[-1].split(' ')[1:])
+    path, year = os.path.split(path)
+    series = os.path.split(path)[1]
+    rtn = "{} {} {} #{}".format(series, year, player, num)
+
+    global path_num
+    path_num += 1
+    print(path_num, " loading " + rtn)
+    return rtn
+
+def read_imgID(results):
+    imgIDs = []
+    for re in results:
+        # Print the matched result info
+        print('---------', re)
+        imgIDs.append(re.id)
+    return imgIDs
+
+
+def yolo_detect(img):
+    results = yolo_model(img)
+
+    pred = results.pred[0][:, :4].cpu().numpy()
+    boxes = pred.astype(np.int32)
+
+    max_img = get_object(img, boxes)
+    return max_img
+
+
+def get_object(img, boxes):
+    if isinstance(img, str):
+        img = Image.open(img)
+
+    if len(boxes) == 0:
+        return img
+
+    max_area = 0
+
+    # Pick the largest bounding box
+    x1, y1, x2, y2 = 0, 0, 0, 0
+    for box in boxes:
+        temp_x1, temp_y1, temp_x2, temp_y2 = box
+        area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
+        if area > max_area:
+            max_area = area
+            x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
+
+    max_img = img.crop((x1, y1, x2, y2))
+    return max_img
+
+
+# Create the Milvus collection (vector database)
+def create_milvus_collection(collection_name, dim):
+    if utility.has_collection(collection_name):
+        utility.drop_collection(collection_name)
+
+    fields = [
+        FieldSchema(name='img_id', dtype=DataType.INT64, is_primary=True),
+        FieldSchema(name='path', dtype=DataType.VARCHAR, max_length=300),
+        FieldSchema(name="info", dtype=DataType.VARCHAR, max_length=300),
+        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding vectors', dim=dim)
+    ]
+    schema = CollectionSchema(fields=fields, description='reverse image search')
+    collection = Collection(name=collection_name, schema=schema)
+
+    index_params = {
+        'metric_type': 'L2',
+        'index_type': "IVF_FLAT",
+        'params': {"nlist": dim}
+    }
+    collection.create_index(field_name="embedding", index_params=index_params)
+    return collection
+
+
+# Load an existing collection, or create and populate a new one
+def is_creat_collection(have_coll, collection_name):
+    if have_coll:
+        # Connect to the existing collection
+        collection = Collection(name=collection_name)
+    else:
+        # Create a new collection
+        collection = create_milvus_collection(collection_name, 2048)
+        dc = (
+            towhee.glob['path'](*dataset_path)
+            .runas_op['path', 'img_id'](func=get_id)
+            .runas_op['path', 'info'](func=get_info)
+            # .image_decode['path', 'img']()
+            .runas_op['path', "object"](yolo_detect)
+            .runas_op['object', 'vec'](func=img2vec)
+            .tensor_normalize['vec', 'vec']()
+            # .image_embedding.timm['img', 'vec'](model_name='resnet50')
+            .ann_insert.milvus[('img_id', 'path', 'info', 'vec'), 'mr'](collection=collection)
+        )
+
+    print('Total number of inserted data is {}.'.format(collection.num_entities))
+    return collection
+
+
+# Query by image ID
+def query_by_imgID(collection, img_id, limit=1):
+    expr = 'img_id == ' + str(img_id)
+    res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=2)
+    return res
+
+
+def from_path_get_series(path):
+    for i in range(3):
+        path = os.path.split(path)[0]
+    series = os.path.split(path)[-1]
+
+    return series
+
+
+if __name__ == '__main__':
+    print('start')
+
+    # Whether the collection already exists
+    have_coll = True
+
+    # Default model
+    # collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search")
+    # Custom model
+    collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel")
+
+
+    # Test image paths
+    img_path = ["D:/Code/ML/images/test02/test2/prizm/base 21-22/*/*.jpg",
+                "D:/Code/ML/images/test02/test2/prizm/base 21-22/*/*.jpeg",
+                "D:/Code/ML/images/test02/test2/prizm/base 21-22/*/*.png",
+                "D:/Code/ML/images/test02/test2/mosaic/20-21/*/*.jpg"]
+
+    data = (towhee.glob['path'](*img_path)
+            # image_decode['path', 'img']().
+            .runas_op['path', "object"](yolo_detect)
+            .runas_op['object', 'vec'](func=img2vec)
+            .tensor_normalize['vec', 'vec']()
+            # image_embedding.timm['img', 'vec'](model_name='resnet50').
+            .ann_search.milvus['vec', 'result'](collection=collection, limit=3)
+            .runas_op['result', 'result_imgID'](func=read_imgID)
+            .select['path', 'result_imgID', 'vec']()
+            )
+
+    print(data)
+
+    collection.load()
+    # res = query_by_imgID(collection, data[0].result_imgID[0])
+    #
+    # print(res[0])
+
+    top3_num = 0
+    top1_num = 0
+    test_img_num = len(list(data))
+
+    # Evaluate every test image
+    for i in range(test_img_num):
+        top3_flag = False
+
+        # Ground-truth series of the image
+        source_card_series = from_path_get_series(data[i].path)
+        # Ground-truth card number of the image
+        source_num = os.path.split(os.path.split(data[i].path)[0])[-1].split('#')[-1]
+
+        # Each test image returns its three most similar image IDs; check each one
+        for j in range(3):
+            res = query_by_imgID(collection, data[i].result_imgID[j])
+
+            # Series of the predicted image
+            result_card_series = from_path_get_series(res[0]['path'])
+            # Card number of the predicted image
+            result_num = os.path.split(os.path.split(res[0]['path'])[0])[-1].split('#')[-1]
+
+            # Check whether top-1 is correct
+            if j == 0 and source_num == result_num and source_card_series == result_card_series:
+                top1_num += 1
+
+            # Mark correct if any of the top-3 results is correct
+            if source_num == result_num and source_card_series == result_card_series:
+                top3_flag = True
+
+
+            # Logging
+            if j == 0 and source_num == result_num and source_card_series == result_card_series:
+                print(top1_num)
+            elif j == 0:
+                print('top_1 wrong')
+            print("series: {}, num: {} === result - series: {}, num: {}".format(
+                source_card_series, source_num, result_card_series, result_num
+            ))
+
+        if top3_flag:
+            top3_num += 1
+
+        print("====================================")
+
+    print("测试图片共: ", test_img_num)
+    top1_accuracy = (top1_num / test_img_num) * 100
+    top3_accuracy = (top3_num / test_img_num) * 100
+
+    print("top3 准确率:{} % \n top1 准确率: {} %".
+          format(top3_accuracy, top1_accuracy))
+
+'''
+148 images
+ default resnet50 + yolo
+top3 accuracy: 100.0 %
+ top1 accuracy: 85.11904761904762 %
+
+ 148 images
+ custom resnet50_freeze_out217 + yolo
+top3 accuracy: 94.04761904761905 %
+ top1 accuracy: 93.45238095238095 %
+
+ Total test images: 168
+ custom resnet50_freeze_out421 + yolo + normalize
+top3 accuracy: 96.42857142857143 %
+ top1 accuracy: 95.23809523809523 %
+
+ Total test images: 168
+ custom resnet50_out764_freeze + yolo + normalize
+top3 accuracy: 95.23809523809523 %
+ top1 accuracy: 94.04761904761905 %
+'''
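
The evaluation loop above boils down to comparing the (series, number) pair parsed from the ground-truth path against the pair parsed from each of the three returned paths. A compact, self-contained sketch of the same top-1 / top-3 bookkeeping over hypothetical results:

    # Hypothetical ground truth and top-3 predictions as (series, number) pairs.
    truth = ("prizm", "12")
    top3 = [("prizm", "12"), ("mosaic", "12"), ("prizm", "7")]

    top1_correct = top3[0] == truth                      # only the first hit counts for top-1
    top3_correct = any(pred == truth for pred in top3)   # any hit counts for top-3
    print(top1_correct, top3_correct)                    # True True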

+ 7 - 0
test04.py

@@ -0,0 +1,7 @@
+import glob
+import os
+
+path = r"D:\Code\ML\images\Mywork3\card_database_yolo\mosaic\20-21"
+
+for name in os.listdir(path):
+    os.rename(os.path.join(path, name), os.path.join(path, name.split('#')[-1]))

+ 58 - 0
yolo_object.py

@@ -0,0 +1,58 @@
+import towhee
+import torch
+import cv2
+import numpy as np
+from PIL import Image, ImageOps
+
+yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom', path="yolov5s.pt", source='local')
+# yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
+
+def yolo_detect(img):
+    results = yolo_model(img)
+
+    pred = results.pred[0][:, :4].cpu().numpy()
+    boxes = pred.astype(np.int32)
+
+    max_img = get_object(img, boxes)
+    return max_img
+
+
+def get_object(img, boxes):
+    if isinstance(img, str):
+        img = Image.open(img)
+
+    if len(boxes) == 0:
+        return img
+
+    max_area = 0
+
+    # Pick the largest bounding box
+    x1, y1, x2, y2 = 0, 0, 0, 0
+    for box in boxes:
+        temp_x1, temp_y1, temp_x2, temp_y2 = box
+        area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
+        if area > max_area:
+            max_area = area
+            x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
+
+    max_img = img.crop((x1, y1, x2, y2))
+    return max_img
+
+
+img = Image.open(r"D:\Code\ML\images\Mywork3\train_data\train\1\IMG_5024.JPG").convert("RGB")
+img = ImageOps.exif_transpose(img)
+results = yolo_model(img)
+results.show()
+# print(img)
+# result = yolo_detect(img)
+
+# result.show()
+
+dc = (
+    towhee.glob['path'](r"D:\Code\ML\images\test02\test\prizm\26-1.jpg")
+    .runas_op['path', "object"](yolo_detect)
+
+)
+
+print(dc)
+
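
As a side note, the largest-box selection in get_object (repeated in several files above) can also be written with max(). A short equivalent sketch, not part of the commit:

    import numpy as np
    from PIL import Image

    def crop_largest_box(img: Image.Image, boxes: np.ndarray) -> Image.Image:
        # boxes is an (N, 4) array of x1, y1, x2, y2; return img unchanged when nothing was detected.
        if len(boxes) == 0:
            return img
        x1, y1, x2, y2 = max(boxes, key=lambda b: (b[2] - b[0]) * (b[3] - b[1]))
        return img.crop((x1, y1, x2, y2))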