| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- import os
- from flask import Flask, render_template, request, jsonify
- from werkzeug.utils import secure_filename
- from pymilvus import connections, Collection
- from MyModel import MyModel
- import towhee
- import torch
- import time
- import json
- import numpy as np
- from datetime import timedelta
# Upload file extensions accepted by allowed_file (compared case-insensitively).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG', 'bmp'}


def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The comparison ignores case, so ``photo.BMP`` and ``photo.Jpg`` are accepted
    just like ``photo.bmp`` / ``photo.jpg`` (the original set already listed
    both ``jpg``/``JPG`` and ``png``/``PNG``, showing case-insensitivity was
    intended). An empty or ``None`` filename, or one without a ``.``, is
    rejected instead of raising.
    """
    if not filename or '.' not in filename:
        return False
    ext = filename.rsplit('.', 1)[1].lower()
    return ext in {allowed.lower() for allowed in ALLOWED_EXTENSIONS}
app = Flask(__name__)
# Serve static files with a 1-second cache max-age, presumably so browsers
# re-fetch static content (e.g. newly saved images) rather than a stale copy.
# Set through the SEND_FILE_MAX_AGE_DEFAULT config key: the
# ``app.send_file_max_age_default`` attribute was deprecated in Flask 2.2 and
# removed in 2.3, where assigning it silently has no effect.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1)
# Persist a counter value to a JSON file.
def set_num(img_num):
    """Overwrite ``static/record.json`` with ``{"img_num": img_num}``.

    The path is relative to the current working directory. After writing,
    a confirmation message ("finished writing to file") is printed.
    """
    with open("static/record.json", "w") as record_file:
        json.dump({"img_num": img_num}, record_file)
    print("加载入文件完成...")
def get_num():
    """Return the counter stored under ``img_num`` in ``static/record.json``.

    The path is relative to the current working directory. If the file does
    not exist yet (e.g. on first run), the counter starts at 0 rather than
    every upload failing with FileNotFoundError. A malformed file or a
    missing ``img_num`` key still raises.
    """
    try:
        with open("static/record.json", 'r') as load_f:
            load_dict = json.load(load_f)
    except FileNotFoundError:
        return 0
    return load_dict['img_num']
def read_imgID(results):
    """Return the ``id`` attribute of every hit in *results*, preserving order.

    Each hit is printed (after a dashed marker) as it is processed, to log
    the matched image information.
    """
    def _log_and_take_id(hit):
        # Print the matched image's information before extracting its id.
        print('---------', hit)
        return hit.id

    return [_log_and_take_id(hit) for hit in results]
def img2vec(img):
    """Return the embedding that the global ``myModel`` predicts for *img*.

    query_by_img passes this as the towhee ``runas_op`` function over the
    'path' column, so *img* is presumably an image file path (TODO confirm
    against MyModel.predict). ``myModel`` is a module-level global assigned
    in the ``__main__`` block.
    """
    vector = myModel.predict(img)
    return vector
def query_by_imgID(collection, img_id, limit=1):
    """Fetch up to *limit* records whose ``img_id`` equals *img_id*.

    Returns whatever ``collection.query`` returns, requesting the ``path``
    and ``info`` output fields, starting at offset 0, with ``timeout=3``.
    """
    wanted_fields = ["path", "info"]
    filter_expr = f"img_id == {img_id!s}"
    return collection.query(
        filter_expr,
        output_fields=wanted_fields,
        offset=0,
        limit=limit,
        timeout=3,
    )
def query_by_img(collection, img_path):
    """Reverse-image-search *collection* for the image at *img_path*.

    Pipeline (towhee DataCollection API):
      1. ``towhee.glob['path'](img_path)`` - put *img_path* into a 'path'
         column. It goes through glob, so wildcard characters in the path
         would presumably be treated as a pattern - TODO confirm.
      2. ``runas_op['path', 'vec']`` - embed each path with img2vec.
      3. ``ann_search.milvus['vec', 'result']`` - ANN search in *collection*,
         asking for the top 3 hits.
      4. ``runas_op['result', 'result_imgID']`` - reduce the hits to their
         ids with read_imgID.
      5. ``select['path', 'result_imgID']`` - keep only those two columns.

    Only the first hit of the first matched file is used. Its id is looked
    up with query_by_imgID (which requests the "path" and "info" fields),
    and the first returned record is returned.

    NOTE(review): presumably raises IndexError if nothing matches the path,
    the search returns no hits, or the id lookup returns no record, because
    of the bare ``[0]`` indexing - verify.
    """
    data = (towhee.glob['path'](img_path).
            runas_op['path', 'vec'](func=img2vec).
            ann_search.milvus['vec', 'result'](collection=collection, limit=3).
            runas_op['result', 'result_imgID'](func=read_imgID).
            select['path', 'result_imgID']()
            )
    # Best (first) hit id of the first file in the collection above.
    res = query_by_imgID(collection, data[0].result_imgID[0])
    return res[0]
# Save an uploaded image and return its saved path (under 'static/images').
def save_requestImg(f, save_name):
    """Validate the uploaded file *f* and save it as *save_name*.

    The file is written to ``static/images`` inside this module's directory.
    That directory is created on demand, so a fresh checkout no longer fails
    when the folder is missing.

    Args:
        f: the uploaded file object (e.g. from ``request.files``).
        save_name: target file name; sanitized with ``secure_filename``.

    Returns:
        The path the image was written to. If *f* is missing or its
        extension is not allowed, returns a ``jsonify`` error response
        (error code 1001) instead of a path, so callers must check the
        return type.
    """
    if not (f and allowed_file(f.filename)):
        return jsonify({"error": 1001, "msg": "请检查上传的图片类型,仅限于png、PNG、jpg、JPG、bmp"})
    basepath = os.path.dirname(__file__)  # directory containing this module
    upload_dir = os.path.join(basepath, 'static/images')
    # f.save() fails if the target folder does not exist, so create it first.
    os.makedirs(upload_dir, exist_ok=True)
    upload_path = os.path.join(upload_dir, secure_filename(save_name))
    f.save(upload_path)
    return upload_path
@app.route('/upload', methods=['POST', 'GET'])  # register the route
def upload():
    """Render the upload page (GET) or search for an uploaded image (POST).

    POST: reads the ``file`` upload and the ``name`` form field, then
    increments the counter in record.json and saves the image as
    ``<counter>.jpg``. It then searches the collection and renders
    ``upload_ok.html`` with the result. If the upload is rejected,
    save_requestImg's error response is returned directly instead of being
    fed into the search.

    GET: renders ``upload.html``.
    """
    print(request)
    if request.method != 'POST':
        return render_template('upload.html')

    # Uploaded file and the optional user-supplied name.
    f = request.files['file']
    user_input = request.form.get("name")

    # Bump the visit counter once and reuse that value for the file name,
    # rather than re-reading record.json. This narrows, but does not
    # eliminate, the read-modify-write race between concurrent requests.
    img_num = get_num() + 1
    set_num(img_num)
    save_name = str(img_num) + '.jpg'

    upload_path = save_requestImg(f, save_name)
    if not isinstance(upload_path, str):
        # Not a path: save_requestImg returned an error response.
        return upload_path

    res = query_by_img(collection, upload_path)
    # val1 is presumably used by the template as a cache-busting token.
    return render_template('upload_ok.html', img_name=save_name, userinput=user_input, val1=time.time(),
                           result=str(res))
@app.route('/image_api', methods=['POST', 'GET'])  # register the route
def imahe_api():
    """JSON reverse-image-search endpoint.

    The function name (a typo of ``image_api``) is kept because it is the
    Flask endpoint name that ``url_for`` lookups would use.

    POST: saves the ``file`` upload as ``<counter>.jpg`` and searches the
    collection. It returns a JSON string
    ``{"matched_url": <record path>, "tag": <record info>}``. If the upload
    is rejected, save_requestImg's JSON error response is returned instead
    of crashing in the search.

    GET: returns the JSON string ``""``.
    """
    if request.method != 'POST':
        return json.dumps("")

    f = request.files['file']

    # Bump the visit counter once and reuse that value for the file name
    # (narrows the race with concurrent requests vs. reading the file twice).
    img_num = get_num() + 1
    set_num(img_num)
    save_name = str(img_num) + '.jpg'

    upload_path = save_requestImg(f, save_name)
    if not isinstance(upload_path, str):
        # Not a path: save_requestImg returned an error response.
        return upload_path

    res = query_by_img(collection, upload_path)
    result = {'matched_url': res['path'],
              'tag': res['info']}
    return json.dumps(result)
if __name__ == '__main__':
    # Connect to the Milvus server on localhost:19530.
    connections.connect(host='127.0.0.1', port='19530')
    # NOTE(review): myModel and collection are module globals used by img2vec
    # and the route handlers. They are only assigned here, so running the app
    # any other way (e.g. `flask run` or a WSGI server importing this module)
    # leaves them undefined.
    # The meaning of MyModel's '' argument is not visible here - TODO confirm.
    myModel = MyModel('')
    collection = Collection(name="reverse_image_search_myModel")
    # Presumably required by pymilvus before search/query on the collection.
    collection.load()
    # Debug mode enables the Werkzeug debugger. Per Flask's app.run docs it
    # also enables the reloader, which re-executes this block (and reloads
    # the model) in a child process. Do not use in production.
    app.debug = True
    print("use GPU: ", torch.cuda.is_available())
    app.run()
|