zhihao.gu hace 2 años
commit
22c6381a68

+ 8 - 0
.idea/.gitignore

@@ -0,0 +1,8 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 11 - 0
.idea/card_reverse_search_API.iml

@@ -0,0 +1,11 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="renderExternalDocumentation" value="true" />
+  </component>
+</module>

+ 6 - 0
.idea/inspectionProfiles/profiles_settings.xml

@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

+ 4 - 0
.idea/misc.xml

@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="pytorch" project-jdk-type="Python SDK" />
+</project>

+ 8 - 0
.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/card_reverse_search_API.iml" filepath="$PROJECT_DIR$/.idea/card_reverse_search_API.iml" />
+    </modules>
+  </component>
+</project>

+ 7 - 0
.idea/other.xml

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PySciProjectComponent">
+    <option name="PY_SCI_VIEW" value="true" />
+    <option name="PY_SCI_VIEW_SUGGESTED" value="true" />
+  </component>
+</project>

+ 6 - 0
.idea/vcs.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>

+ 61 - 0
MyModel.py

@@ -0,0 +1,61 @@
+import torch
+import torchvision.models as models
+import torchvision.transforms as transforms
+from PIL import Image
+import timm
+
+
+class MyModel:
+    def __init__(self, model_dict_path):
+        self.norm_mean = [0.485, 0.456, 0.406]
+        self.norm_std = [0.229, 0.224, 0.225]
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # self.model = models.resnet50(pretrained=True)
+        self.model = timm.create_model('resnet50', num_classes=2048, pretrained=True)
+        self.model.eval()
+
+        # 自定义模型
+        # print(list(self.model.children()))
+        features = list(self.model.children())[:-1]  # 去掉全连接层
+        self.model = torch.nn.Sequential(*features).to(self.device)
+
+        # self.model.to(self.device)
+
+
+    def inference_transform(self):
+        inference_transform = transforms.Compose([
+            transforms.Resize((256, 256)),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(self.norm_mean, self.norm_std),
+        ])
+        return inference_transform
+
+    def img_transform(self, img_rgb, transform=None):
+        # 将数据转换为模型读取的形式
+        if transform is None:
+            raise ValueError("找不到transform!必须有transform对img进行处理")
+
+        img_t = transform(img_rgb)
+        return img_t
+
+    def get_model(self):
+        return self.model
+
+    # 输出图片路径或者cv2格式的图片数据
+    def predict(self, img):
+        if type(img) == type('path'):
+            img = Image.open(img).convert('RGB')
+
+        transform = self.inference_transform()
+
+        img_tensor = transform(img)
+        img_tensor.unsqueeze_(0)
+        img_tensor = img_tensor.to(self.device)
+        # print(img.shape)
+
+        with torch.no_grad():
+            outputs = self.model(img_tensor)
+        return outputs.reshape(2048).cpu().numpy()

+ 20 - 0
clean_local_upload.py

@@ -0,0 +1,20 @@
+import os
+import json
+
+def set_num(img_num):
+    record_dict = {"img_num": img_num}
+    with open("static/record.json", "w") as f:
+        json.dump(record_dict, f)
+        print("加载入文件完成...")
+
+# 清空图片文件,并且将record设置为零
+if __name__ == '__main__':
+    path = 'static/images'
+    img_list = os.listdir(path)
+    img_list.remove('index.txt')
+
+    for name in img_list:
+        os.remove(os.path.join(path, name))
+
+    set_num(0)
+    print('end')

+ 140 - 0
main.py

@@ -0,0 +1,140 @@
+import os
+from flask import Flask, render_template, request, jsonify
+from werkzeug.utils import secure_filename
+
+from pymilvus import connections, Collection
+from MyModel import MyModel
+import towhee
+
+import torch
+import time
+import json
+import numpy as np
+from datetime import timedelta
+
+# 设置允许的文件格式
+ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG', 'bmp'}
+
+
+def allowed_file(filename):
+    return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
+
+
+app = Flask(__name__)
+app.send_file_max_age_default = timedelta(seconds=1)
+
+
+# 用文件保存数字
+def set_num(img_num):
+    record_dict = {"img_num": img_num}
+    with open("static/record.json", "w") as f:
+        json.dump(record_dict, f)
+        print("加载入文件完成...")
+
+
+def get_num():
+    with open("static/record.json", 'r') as load_f:
+        load_dict = json.load(load_f)
+        img_num = load_dict['img_num']
+    return img_num
+
+
+def read_imgID(results):
+    imgIDs = []
+    for re in results:
+        # 输出结果图片信息
+        print('---------', re)
+        imgIDs.append(re.id)
+    return imgIDs
+
+
+def img2vec(img):
+    return myModel.predict(img)
+
+
+def query_by_imgID(collection, img_id, limit=1):
+    expr = 'img_id == ' + str(img_id)
+    res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=3)
+    return res
+
+
+def query_by_img(collection, img_path):
+    data = (towhee.glob['path'](img_path).
+            runas_op['path', 'vec'](func=img2vec).
+            ann_search.milvus['vec', 'result'](collection=collection, limit=3).
+            runas_op['result', 'result_imgID'](func=read_imgID).
+            select['path', 'result_imgID']()
+            )
+
+    res = query_by_imgID(collection, data[0].result_imgID[0])
+
+    return res[0]
+
+
+# 保存上传图片并返回保存路径,默认路径为 'static/images'
+def save_requestImg(f, save_name):
+    if not (f and allowed_file(f.filename)):
+        return jsonify({"error": 1001, "msg": "请检查上传的图片类型,仅限于png、PNG、jpg、JPG、bmp"})
+
+    basepath = os.path.dirname(__file__)  # 当前文件所在路径
+
+    # print(f.filename)
+
+    upload_path = os.path.join(basepath, 'static/images', secure_filename(save_name))  # 注意:没有的文件夹一定要先创建,不然会提示没有该路径
+    # upload_path = os.path.join(basepath, 'static/images','test.jpg')  #注意:没有的文件夹一定要先创建,不然会提示没有该路径
+    f.save(upload_path)
+
+    return upload_path
+
+
+@app.route('/upload', methods=['POST', 'GET'])  # 添加路由
+def upload():
+    print(request)
+    if request.method == 'POST':
+        # 文件上传及保存
+        f = request.files['file']
+        user_input = request.form.get("name")
+
+        # 访问次数记录
+        set_num(get_num() + 1)
+        save_name = str(get_num()) + '.jpg'
+        upload_path = save_requestImg(f, save_name)
+
+        res = query_by_img(collection, upload_path)
+
+        return render_template('upload_ok.html', img_name=save_name, userinput=user_input, val1=time.time(),
+                               result=str(res))
+
+    return render_template('upload.html')
+
+
+@app.route('/image_api', methods=['POST', 'GET'])  # 添加路由
+def imahe_api():
+    if request.method == 'POST':
+        # 文件上传及保存
+        f = request.files['file']
+
+        # 访问次数记录
+        set_num(get_num() + 1)
+        save_name = str(get_num()) + '.jpg'
+        upload_path = save_requestImg(f, save_name)
+
+        res = query_by_img(collection, upload_path)
+
+        result = {'matched_url': res['path'],
+                  'tag': res['info']}
+        return json.dumps(result)
+
+    return json.dumps("")
+
+
+if __name__ == '__main__':
+    connections.connect(host='127.0.0.1', port='19530')
+
+    myModel = MyModel('')
+    collection = Collection(name="reverse_image_search_myModel")
+    collection.load()
+
+    app.debug = True
+    print("use GPU: ", torch.cuda.is_available())
+    app.run()

+ 0 - 0
static/images/index.txt


+ 1 - 0
static/record.json

@@ -0,0 +1 @@
+{"img_num": 0}

+ 17 - 0
templates/upload.html

@@ -0,0 +1,17 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>上传图片进行识别</title>
+</head>
+<body>
+    <h1>上传图片识别</h1>
+    <form action="" enctype='multipart/form-data' method='POST'>
+        <input type="file" name="file" style="margin-top:20px;"/>
+        <br>
+        <i>请输入(备用框):</i>
+        <input type="text" class="txt_input" name="name" style="margin-top:10px;"/>
+        <input type="submit" value="上传" class="button-new" style="margin-top:15px;"/>
+    </form>
+</body>
+</html>

+ 21 - 0
templates/upload_ok.html

@@ -0,0 +1,21 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>上传图片进行识别</title>
+</head>
+<body>
+    <h1>上传图片识别</h1>
+    <form action="" enctype='multipart/form-data' method='POST'>
+        <input type="file" name="file" style="margin-top:20px;"/>
+        <br>
+        <input type="text" class="txt_input" name="name" style="margin-top:10px;"/>
+        <input type="submit" value="上传" class="button-new" style="margin-top:15px;"/>
+    </form>
+    <h1>输入内容:{{userinput}}!</h1>
+    <br>
+    <h2>result: {{result}}</h2>
+    <br>
+    <img src="{{ url_for('static', filename= './images/' + img_name, _t=val1) }}" width="400" height="400" alt="你的图片被外星人劫持了~~"/>
+</body>
+</html>