|
@@ -6,8 +6,8 @@ import PIL.Image as Image
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
|
|
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
|
|
|
|
|
|
|
|
-from MyModel import MyModel
|
|
|
|
|
-from MyEfficientNet import MyEfficient
|
|
|
|
|
|
|
+from MyModel2 import MyModel
|
|
|
|
|
+
|
|
|
import torch
|
|
import torch
|
|
|
from transformers import ViTFeatureExtractor, ViTModel
|
|
from transformers import ViTFeatureExtractor, ViTModel
|
|
|
from towhee.types.image_utils import to_image_color
|
|
from towhee.types.image_utils import to_image_color
|
|
@@ -16,15 +16,16 @@ connections.connect(host='127.0.0.1', port='19530')
|
|
|
dataset_path = ["D:\Code\ML\images\Mywork3\card_database_yolo/*/*/*/*"]
|
|
dataset_path = ["D:\Code\ML\images\Mywork3\card_database_yolo/*/*/*/*"]
|
|
|
|
|
|
|
|
img_id = 0
|
|
img_id = 0
|
|
|
|
|
+yolo_num = 0
|
|
|
vec_num = 0
|
|
vec_num = 0
|
|
|
-myModel = MyModel(r"D:\Code\ML\model\card_cls\res_card_out764_freeze4.pth", out_features=764)
|
|
|
|
|
|
|
+myModel = MyModel(r"D:\Code\ML\model\card_cls\res_card_out764_freeze5.pth", out_features=764)
|
|
|
# myModel = MyModel(r"C:\Users\Administrator\.cache\torch\hub\checkpoints\resnet50-0676ba61.pth", out_features=1000)
|
|
# myModel = MyModel(r"C:\Users\Administrator\.cache\torch\hub\checkpoints\resnet50-0676ba61.pth", out_features=1000)
|
|
|
|
|
|
|
|
# myModel = MyEfficient('')
|
|
# myModel = MyEfficient('')
|
|
|
|
|
|
|
|
|
|
|
|
|
-yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
|
|
|
|
|
- path="yolov5s.pt", source='local')
|
|
|
|
|
|
|
+# yolo_model = torch.hub.load(r"C:\Users\Administrator\.cache\torch\hub\ultralytics_yolov5_master", 'custom',
|
|
|
|
|
+# path="yolov5s.pt", source='local')
|
|
|
|
|
|
|
|
|
|
|
|
|
# yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
|
|
# yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s")
|
|
@@ -80,36 +81,40 @@ def read_imgID(results):
|
|
|
return imgIDs
|
|
return imgIDs
|
|
|
|
|
|
|
|
|
|
|
|
|
-def yolo_detect(img):
|
|
|
|
|
- results = yolo_model(img)
|
|
|
|
|
-
|
|
|
|
|
- pred = results.pred[0][:, :4].cpu().numpy()
|
|
|
|
|
- boxes = pred.astype(np.int32)
|
|
|
|
|
-
|
|
|
|
|
- max_img = get_object(img, boxes)
|
|
|
|
|
- return max_img
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def get_object(img, boxes):
|
|
|
|
|
- if isinstance(img, str):
|
|
|
|
|
- img = Image.open(img)
|
|
|
|
|
-
|
|
|
|
|
- if len(boxes) == 0:
|
|
|
|
|
- return img
|
|
|
|
|
-
|
|
|
|
|
- max_area = 0
|
|
|
|
|
-
|
|
|
|
|
- # 选出最大的框
|
|
|
|
|
- x1, y1, x2, y2 = 0, 0, 0, 0
|
|
|
|
|
- for box in boxes:
|
|
|
|
|
- temp_x1, temp_y1, temp_x2, temp_y2 = box
|
|
|
|
|
- area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
|
|
|
|
|
- if area > max_area:
|
|
|
|
|
- max_area = area
|
|
|
|
|
- x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
|
|
|
|
|
-
|
|
|
|
|
- max_img = img.crop((x1, y1, x2, y2))
|
|
|
|
|
- return max_img
|
|
|
|
|
|
|
+# def yolo_detect(img):
|
|
|
|
|
+# results = yolo_model(img)
|
|
|
|
|
+#
|
|
|
|
|
+# pred = results.pred[0][:, :4].cpu().numpy()
|
|
|
|
|
+# boxes = pred.astype(np.int32)
|
|
|
|
|
+#
|
|
|
|
|
+# max_img = get_object(img, boxes)
|
|
|
|
|
+#
|
|
|
|
|
+# global yolo_num
|
|
|
|
|
+# yolo_num += 1
|
|
|
|
|
+# print("yolo_num: ", yolo_num)
|
|
|
|
|
+# return max_img
|
|
|
|
|
+#
|
|
|
|
|
+#
|
|
|
|
|
+# def get_object(img, boxes):
|
|
|
|
|
+# if isinstance(img, str):
|
|
|
|
|
+# img = Image.open(img)
|
|
|
|
|
+#
|
|
|
|
|
+# if len(boxes) == 0:
|
|
|
|
|
+# return img
|
|
|
|
|
+#
|
|
|
|
|
+# max_area = 0
|
|
|
|
|
+#
|
|
|
|
|
+# # 选出最大的框
|
|
|
|
|
+# x1, y1, x2, y2 = 0, 0, 0, 0
|
|
|
|
|
+# for box in boxes:
|
|
|
|
|
+# temp_x1, temp_y1, temp_x2, temp_y2 = box
|
|
|
|
|
+# area = (temp_x2 - temp_x1) * (temp_y2 - temp_y1)
|
|
|
|
|
+# if area > max_area:
|
|
|
|
|
+# max_area = area
|
|
|
|
|
+# x1, y1, x2, y2 = temp_x1, temp_y1, temp_x2, temp_y2
|
|
|
|
|
+#
|
|
|
|
|
+# max_img = img.crop((x1, y1, x2, y2))
|
|
|
|
|
+# return max_img
|
|
|
|
|
|
|
|
|
|
|
|
|
# 创建向量数据库
|
|
# 创建向量数据库
|
|
@@ -166,12 +171,23 @@ def query_by_imgID(collection, img_id, limit=1):
|
|
|
return res
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
-def from_path_get_series(path):
|
|
|
|
|
|
|
+# 分别返回 编号,年份,系列
|
|
|
|
|
+def from_path_get_info(path):
|
|
|
|
|
+ card_info = []
|
|
|
for i in range(3):
|
|
for i in range(3):
|
|
|
path = os.path.split(path)[0]
|
|
path = os.path.split(path)[0]
|
|
|
- series = os.path.split(path)[-1]
|
|
|
|
|
|
|
+ card_info.append(os.path.split(path)[-1])
|
|
|
|
|
+ card_info[0] = card_info[0].split('#')[-1]
|
|
|
|
|
+ return card_info
|
|
|
|
|
|
|
|
- return series
|
|
|
|
|
|
|
+
|
|
|
|
|
+def from_query_path_get_info(path):
|
|
|
|
|
+ card_info = []
|
|
|
|
|
+ for i in range(3):
|
|
|
|
|
+ path = os.path.split(path)[0]
|
|
|
|
|
+ card_info.append(os.path.split(path)[-1])
|
|
|
|
|
+ card_info[0] = card_info[0].split(' ')[0]
|
|
|
|
|
+ return card_info
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
@@ -186,12 +202,12 @@ if __name__ == '__main__':
|
|
|
collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel")
|
|
collection = is_creat_collection(have_coll=have_coll, collection_name="reverse_image_search_myModel")
|
|
|
|
|
|
|
|
# 测试的图片路径
|
|
# 测试的图片路径
|
|
|
- img_path = ["D:/Code/ML/images/test02/test2/*/*/*/*"]
|
|
|
|
|
|
|
+ img_path = ["D:/Code/ML/images/test02/test(mosaic,pz)/*/*/*/*"]
|
|
|
|
|
|
|
|
data = (towhee.glob['path'](*img_path)
|
|
data = (towhee.glob['path'](*img_path)
|
|
|
# image_decode['path', 'img']().
|
|
# image_decode['path', 'img']().
|
|
|
- .runas_op['path', "object"](yolo_detect)
|
|
|
|
|
- .runas_op['object', 'vec'](func=img2vec)
|
|
|
|
|
|
|
+ # .runas_op['path', "object"](yolo_detect)
|
|
|
|
|
+ .runas_op['path', 'vec'](func=img2vec)
|
|
|
.tensor_normalize['vec', 'vec']()
|
|
.tensor_normalize['vec', 'vec']()
|
|
|
# image_embedding.timm['img', 'vec'](model_name='resnet50').
|
|
# image_embedding.timm['img', 'vec'](model_name='resnet50').
|
|
|
.ann_search.milvus['vec', 'result'](collection=collection, limit=3)
|
|
.ann_search.milvus['vec', 'result'](collection=collection, limit=3)
|
|
@@ -206,7 +222,6 @@ if __name__ == '__main__':
|
|
|
#
|
|
#
|
|
|
# print(res[0])
|
|
# print(res[0])
|
|
|
|
|
|
|
|
-
|
|
|
|
|
top3_num = 0
|
|
top3_num = 0
|
|
|
top1_num = 0
|
|
top1_num = 0
|
|
|
test_img_num = len(list(data))
|
|
test_img_num = len(list(data))
|
|
@@ -215,35 +230,29 @@ if __name__ == '__main__':
|
|
|
for i in range(test_img_num):
|
|
for i in range(test_img_num):
|
|
|
top3_flag = False
|
|
top3_flag = False
|
|
|
|
|
|
|
|
- # 获取图片真正的系列
|
|
|
|
|
- source_card_series = from_path_get_series(data[i].path)
|
|
|
|
|
- # 获取图片真正的编号
|
|
|
|
|
- source_num = os.path.split(os.path.split(data[i].path)[0])[-1].split('#')[-1]
|
|
|
|
|
|
|
+ # 获取图片真正的编号, 年份, 系列
|
|
|
|
|
+ source_code, source_year, source_series = from_path_get_info(data[i].path)
|
|
|
|
|
|
|
|
# 每个测试图片返回三个最相似的图片ID,一一测试
|
|
# 每个测试图片返回三个最相似的图片ID,一一测试
|
|
|
for j in range(3):
|
|
for j in range(3):
|
|
|
res = query_by_imgID(collection, data[i].result_imgID[j])
|
|
res = query_by_imgID(collection, data[i].result_imgID[j])
|
|
|
|
|
|
|
|
- # 获取预测的图片的系列
|
|
|
|
|
- result_card_series = from_path_get_series(res[0]['path'])
|
|
|
|
|
- # 获取预测的图片的编号
|
|
|
|
|
- result_num = os.path.split(os.path.split(res[0]['path'])[0])[-1].split(' ')[0].split('#')[-1]
|
|
|
|
|
|
|
+ # 获取预测的图片的编号, 年份, 系列
|
|
|
|
|
+ result_code, result_year, result_series = from_query_path_get_info(res[0]['path'])
|
|
|
|
|
|
|
|
# 判断top1是否正确
|
|
# 判断top1是否正确
|
|
|
- if j == 0 and source_num == result_num and source_card_series == result_card_series:
|
|
|
|
|
|
|
+ if j == 0 and source_code == result_code and source_year == result_year and source_series == result_series:
|
|
|
top1_num += 1
|
|
top1_num += 1
|
|
|
|
|
+ print(top1_num)
|
|
|
|
|
+ elif j == 0:
|
|
|
|
|
+ print('top_1 错误')
|
|
|
|
|
|
|
|
# top3中有一个正确的标记为正确
|
|
# top3中有一个正确的标记为正确
|
|
|
- if source_num == result_num and source_card_series == result_card_series:
|
|
|
|
|
|
|
+ if source_code == result_code and source_year == result_year and source_series == result_series:
|
|
|
top3_flag = True
|
|
top3_flag = True
|
|
|
|
|
|
|
|
- # 日志
|
|
|
|
|
- if j == 0 and source_num == result_num and source_card_series == result_card_series:
|
|
|
|
|
- print(top1_num)
|
|
|
|
|
- elif j == 0:
|
|
|
|
|
- print('top_1 错误')
|
|
|
|
|
- print("series: {}, num: {} === result - series: {}, num: {}".format(
|
|
|
|
|
- source_card_series, source_num, result_card_series, result_num
|
|
|
|
|
|
|
+ print("series: {}, year: {},code: {} === result - series: {}, year: {}, code: {}".format(
|
|
|
|
|
+ source_series, source_year, source_code, result_series, result_year, result_code,
|
|
|
))
|
|
))
|
|
|
|
|
|
|
|
if top3_flag:
|
|
if top3_flag:
|
|
@@ -258,22 +267,3 @@ if __name__ == '__main__':
|
|
|
print("top3 准确率:{} % \n top1 准确率: {} %".
|
|
print("top3 准确率:{} % \n top1 准确率: {} %".
|
|
|
format(top3_accuracy, top1_accuracy))
|
|
format(top3_accuracy, top1_accuracy))
|
|
|
|
|
|
|
|
-'''
|
|
|
|
|
- 测试图片共: 168
|
|
|
|
|
- 自定义resnet50_freeze_out421 + yolo + normalize
|
|
|
|
|
-top3 准确率:96.42857142857143 %
|
|
|
|
|
- top1 准确率: 95.23809523809523 %
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-测试图片: 773, 数据库图片: 5848
|
|
|
|
|
-自定义resnet50_freeze_out421 + yolo + normalize
|
|
|
|
|
-测试图片共: 773
|
|
|
|
|
-top3 准确率:96.63648124191462 %
|
|
|
|
|
- top1 准确率: 95.60155239327295 %
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
- 测试图片: 773, 数据库图片: 5848
|
|
|
|
|
- 自定义resnet50_out764_freeze + yolo + normalize
|
|
|
|
|
-top3 准确率:96.76584734799482 %
|
|
|
|
|
- top1 准确率: 96.50711513583441 %
|
|
|
|
|
-'''
|
|
|