test02.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import os
  2. import shutil
  3. from typing import List
  4. import uvicorn
  5. from fastapi import FastAPI, File, UploadFile, HTTPException
  6. from fastapi.responses import HTMLResponse
  7. # 创建一个目标文件夹来存放上传的文件
  8. UPLOAD_DIRECTORY = "./uploads"
  9. if not os.path.exists(UPLOAD_DIRECTORY):
  10. os.makedirs(UPLOAD_DIRECTORY)
  11. app = FastAPI()
  12. @app.post("/upload-folder/")
  13. async def upload_folder(files: List[UploadFile] = File(...)):
  14. """
  15. 接收通过 webkitdirectory 上传的整个文件夹
  16. """
  17. saved_files = []
  18. for file in files:
  19. # file.filename 会包含从选定目录开始的相对路径
  20. # 例如: "my_folder/data.csv" 或 "my_folder/images/pic.png"
  21. # 安全性检查:防止路径遍历攻击 (e.g., "my_folder/../../etc/passwd")
  22. if ".." in file.filename:
  23. raise HTTPException(status_code=400, detail=f"Invalid filename: {file.filename}. Contains '..'")
  24. # 在服务器上创建完整的目标路径
  25. # os.path.join 会正确处理不同操作系统的路径分隔符
  26. destination_path = os.path.join(UPLOAD_DIRECTORY, file.filename)
  27. # 获取目标文件的目录路径
  28. destination_dir = os.path.dirname(destination_path)
  29. # 如果目录不存在,则创建它
  30. if not os.path.exists(destination_dir):
  31. os.makedirs(destination_dir)
  32. try:
  33. # 异步地将文件内容写入目标路径
  34. with open(destination_path, "wb") as buffer:
  35. shutil.copyfileobj(file.file, buffer)
  36. saved_files.append(file.filename)
  37. finally:
  38. # 确保关闭文件
  39. await file.close()
  40. return {"message": f"Successfully uploaded {len(saved_files)} files", "filenames": saved_files}
  41. # 提供一个简单的 HTML 上传页面用于测试
  42. @app.get("/")
  43. async def main():
  44. content = """
  45. <body>
  46. <h2>上传整个文件夹</h2>
  47. <p>选择一个文件夹,其中的所有文件(包括子目录中的文件)都将被上传。</p>
  48. <form action="/upload-folder/" enctype="multipart/form-data" method="post">
  49. <input name="files" type="file" webkitdirectory directory multiple>
  50. <input type="submit">
  51. </form>
  52. </body>
  53. """
  54. return HTMLResponse(content=content)
  55. if __name__ == "__main__":
  56. uvicorn.run(app, host="0.0.0.0", port=8000)