|
|
@@ -0,0 +1,140 @@
|
|
|
+import os
|
|
|
+from flask import Flask, render_template, request, jsonify
|
|
|
+from werkzeug.utils import secure_filename
|
|
|
+
|
|
|
+from pymilvus import connections, Collection
|
|
|
+from MyModel import MyModel
|
|
|
+import towhee
|
|
|
+
|
|
|
+import torch
|
|
|
+import time
|
|
|
+import json
|
|
|
+import numpy as np
|
|
|
+from datetime import timedelta
|
|
|
+
|
|
|
+# 设置允许的文件格式
|
|
|
+ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG', 'bmp'}
|
|
|
+
|
|
|
+
|
|
|
+def allowed_file(filename):
|
|
|
+ return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
|
|
|
+
|
|
|
+
|
|
|
+app = Flask(__name__)
|
|
|
+app.send_file_max_age_default = timedelta(seconds=1)
|
|
|
+
|
|
|
+
|
|
|
+# 用文件保存数字
|
|
|
+def set_num(img_num):
|
|
|
+ record_dict = {"img_num": img_num}
|
|
|
+ with open("static/record.json", "w") as f:
|
|
|
+ json.dump(record_dict, f)
|
|
|
+ print("加载入文件完成...")
|
|
|
+
|
|
|
+
|
|
|
+def get_num():
|
|
|
+ with open("static/record.json", 'r') as load_f:
|
|
|
+ load_dict = json.load(load_f)
|
|
|
+ img_num = load_dict['img_num']
|
|
|
+ return img_num
|
|
|
+
|
|
|
+
|
|
|
+def read_imgID(results):
|
|
|
+ imgIDs = []
|
|
|
+ for re in results:
|
|
|
+ # 输出结果图片信息
|
|
|
+ print('---------', re)
|
|
|
+ imgIDs.append(re.id)
|
|
|
+ return imgIDs
|
|
|
+
|
|
|
+
|
|
|
+def img2vec(img):
|
|
|
+ return myModel.predict(img)
|
|
|
+
|
|
|
+
|
|
|
+def query_by_imgID(collection, img_id, limit=1):
|
|
|
+ expr = 'img_id == ' + str(img_id)
|
|
|
+ res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=3)
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+def query_by_img(collection, img_path):
|
|
|
+ data = (towhee.glob['path'](img_path).
|
|
|
+ runas_op['path', 'vec'](func=img2vec).
|
|
|
+ ann_search.milvus['vec', 'result'](collection=collection, limit=3).
|
|
|
+ runas_op['result', 'result_imgID'](func=read_imgID).
|
|
|
+ select['path', 'result_imgID']()
|
|
|
+ )
|
|
|
+
|
|
|
+ res = query_by_imgID(collection, data[0].result_imgID[0])
|
|
|
+
|
|
|
+ return res[0]
|
|
|
+
|
|
|
+
|
|
|
+# 保存上传图片并返回保存路径,默认路径为 'static/images'
|
|
|
+def save_requestImg(f, save_name):
|
|
|
+ if not (f and allowed_file(f.filename)):
|
|
|
+ return jsonify({"error": 1001, "msg": "请检查上传的图片类型,仅限于png、PNG、jpg、JPG、bmp"})
|
|
|
+
|
|
|
+ basepath = os.path.dirname(__file__) # 当前文件所在路径
|
|
|
+
|
|
|
+ # print(f.filename)
|
|
|
+
|
|
|
+ upload_path = os.path.join(basepath, 'static/images', secure_filename(save_name)) # 注意:没有的文件夹一定要先创建,不然会提示没有该路径
|
|
|
+ # upload_path = os.path.join(basepath, 'static/images','test.jpg') #注意:没有的文件夹一定要先创建,不然会提示没有该路径
|
|
|
+ f.save(upload_path)
|
|
|
+
|
|
|
+ return upload_path
|
|
|
+
|
|
|
+
|
|
|
+@app.route('/upload', methods=['POST', 'GET']) # 添加路由
|
|
|
+def upload():
|
|
|
+ print(request)
|
|
|
+ if request.method == 'POST':
|
|
|
+ # 文件上传及保存
|
|
|
+ f = request.files['file']
|
|
|
+ user_input = request.form.get("name")
|
|
|
+
|
|
|
+ # 访问次数记录
|
|
|
+ set_num(get_num() + 1)
|
|
|
+ save_name = str(get_num()) + '.jpg'
|
|
|
+ upload_path = save_requestImg(f, save_name)
|
|
|
+
|
|
|
+ res = query_by_img(collection, upload_path)
|
|
|
+
|
|
|
+ return render_template('upload_ok.html', img_name=save_name, userinput=user_input, val1=time.time(),
|
|
|
+ result=str(res))
|
|
|
+
|
|
|
+ return render_template('upload.html')
|
|
|
+
|
|
|
+
|
|
|
+@app.route('/image_api', methods=['POST', 'GET']) # 添加路由
|
|
|
+def imahe_api():
|
|
|
+ if request.method == 'POST':
|
|
|
+ # 文件上传及保存
|
|
|
+ f = request.files['file']
|
|
|
+
|
|
|
+ # 访问次数记录
|
|
|
+ set_num(get_num() + 1)
|
|
|
+ save_name = str(get_num()) + '.jpg'
|
|
|
+ upload_path = save_requestImg(f, save_name)
|
|
|
+
|
|
|
+ res = query_by_img(collection, upload_path)
|
|
|
+
|
|
|
+ result = {'matched_url': res['path'],
|
|
|
+ 'tag': res['info']}
|
|
|
+ return json.dumps(result)
|
|
|
+
|
|
|
+ return json.dumps("")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ connections.connect(host='127.0.0.1', port='19530')
|
|
|
+
|
|
|
+ myModel = MyModel('')
|
|
|
+ collection = Collection(name="reverse_image_search_myModel")
|
|
|
+ collection.load()
|
|
|
+
|
|
|
+ app.debug = True
|
|
|
+ print("use GPU: ", torch.cuda.is_available())
|
|
|
+ app.run()
|