main.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import os
  2. from flask import Flask, render_template, request, jsonify
  3. from werkzeug.utils import secure_filename
  4. from pymilvus import connections, Collection
  5. from MyModel import MyModel
  6. import towhee
  7. import torch
  8. import time
  9. import json
  10. import numpy as np
  11. from datetime import timedelta
  12. # 设置允许的文件格式
  13. ALLOWED_EXTENSIONS = {'png', 'jpg', 'JPG', 'PNG', 'bmp'}
  14. def allowed_file(filename):
  15. return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
  16. app = Flask(__name__)
  17. app.send_file_max_age_default = timedelta(seconds=1)
  18. # 用文件保存数字
  19. def set_num(img_num):
  20. record_dict = {"img_num": img_num}
  21. with open("static/record.json", "w") as f:
  22. json.dump(record_dict, f)
  23. print("加载入文件完成...")
  24. def get_num():
  25. with open("static/record.json", 'r') as load_f:
  26. load_dict = json.load(load_f)
  27. img_num = load_dict['img_num']
  28. return img_num
  29. def read_imgID(results):
  30. imgIDs = []
  31. for re in results:
  32. # 输出结果图片信息
  33. print('---------', re)
  34. imgIDs.append(re.id)
  35. return imgIDs
  36. def img2vec(img):
  37. return myModel.predict(img)
  38. def query_by_imgID(collection, img_id, limit=1):
  39. expr = 'img_id == ' + str(img_id)
  40. res = collection.query(expr, output_fields=["path", "info"], offset=0, limit=limit, timeout=3)
  41. return res
  42. def query_by_img(collection, img_path):
  43. data = (towhee.glob['path'](img_path).
  44. runas_op['path', 'vec'](func=img2vec).
  45. ann_search.milvus['vec', 'result'](collection=collection, limit=3).
  46. runas_op['result', 'result_imgID'](func=read_imgID).
  47. select['path', 'result_imgID']()
  48. )
  49. res = query_by_imgID(collection, data[0].result_imgID[0])
  50. return res[0]
  51. # 保存上传图片并返回保存路径,默认路径为 'static/images'
  52. def save_requestImg(f, save_name):
  53. if not (f and allowed_file(f.filename)):
  54. return jsonify({"error": 1001, "msg": "请检查上传的图片类型,仅限于png、PNG、jpg、JPG、bmp"})
  55. basepath = os.path.dirname(__file__) # 当前文件所在路径
  56. # print(f.filename)
  57. upload_path = os.path.join(basepath, 'static/images', secure_filename(save_name)) # 注意:没有的文件夹一定要先创建,不然会提示没有该路径
  58. # upload_path = os.path.join(basepath, 'static/images','test.jpg') #注意:没有的文件夹一定要先创建,不然会提示没有该路径
  59. f.save(upload_path)
  60. return upload_path
  61. @app.route('/upload', methods=['POST', 'GET']) # 添加路由
  62. def upload():
  63. print(request)
  64. if request.method == 'POST':
  65. # 文件上传及保存
  66. f = request.files['file']
  67. user_input = request.form.get("name")
  68. # 访问次数记录
  69. set_num(get_num() + 1)
  70. save_name = str(get_num()) + '.jpg'
  71. upload_path = save_requestImg(f, save_name)
  72. res = query_by_img(collection, upload_path)
  73. return render_template('upload_ok.html', img_name=save_name, userinput=user_input, val1=time.time(),
  74. result=str(res))
  75. return render_template('upload.html')
  76. @app.route('/image_api', methods=['POST', 'GET']) # 添加路由
  77. def imahe_api():
  78. if request.method == 'POST':
  79. # 文件上传及保存
  80. f = request.files['file']
  81. # 访问次数记录
  82. set_num(get_num() + 1)
  83. save_name = str(get_num()) + '.jpg'
  84. upload_path = save_requestImg(f, save_name)
  85. res = query_by_img(collection, upload_path)
  86. result = {'matched_url': res['path'],
  87. 'tag': res['info']}
  88. return json.dumps(result)
  89. return json.dumps("")
if __name__ == '__main__':
    # Startup order matters: connect to the local Milvus server before
    # opening the collection.
    connections.connect(host='127.0.0.1', port='19530')
    # myModel and collection are bound as module globals here; img2vec()
    # reads myModel and the route handlers read collection. They therefore
    # exist only when this file is run directly, not when it is imported
    # (e.g. by `flask run` or a WSGI server).
    myModel = MyModel('')
    collection = Collection(name="reverse_image_search_myModel")
    # Load the collection so it can be searched (presumably into Milvus
    # memory — see pymilvus Collection.load).
    collection.load()
    # NOTE(review): debug mode enables Flask's reloader and the interactive
    # Werkzeug debugger; do not expose this server beyond localhost with it on.
    app.debug = True
    # Report whether PyTorch can see a CUDA device.
    print("use GPU: ", torch.cuda.is_available())
    app.run()