users.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. import base64
  2. import hashlib
  3. import hmac
  4. import json
  5. import secrets
  6. from datetime import datetime, timedelta, timezone
  7. from typing import Optional, List
  8. from fastapi import APIRouter, Depends, HTTPException, Query, status
  9. from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
  10. from mysql.connector.pooling import PooledMySQLConnection
  11. from pydantic import BaseModel, ConfigDict, Field
  12. from app.core.config import settings
  13. from app.core.database_loader import get_db_connection
  14. from app.core.logger import get_logger
  15. logger = get_logger(__name__)
  16. router = APIRouter()
  17. db_dependency = Depends(get_db_connection)
  18. oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/users/login")
  19. # Hard-coded by design: used only for promoting the current user to admin.
  20. ADMIN_GRANT_KEY = settings.ADMIN_GRANT_KEY
  21. TOKEN_SECRET_KEY = settings.TOKEN_SECRET_KEY
  22. TOKEN_EXPIRE_MINUTES = 60 * 24 * 7
# Request body for POST /register: credentials and display name of the new account.
class UserRegisterRequest(BaseModel):
    # 6-20 ASCII letters/digits; register_user rejects names that already exist.
    username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
    # Display name, 1-20 characters, no character restriction.
    nickname: str = Field(..., min_length=1, max_length=20)
    # Plain-text password, 6-20 ASCII letters/digits; hashed by register_user before storage.
    password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
# Request body for PUT /me; every field is optional and only non-None fields are written.
class UserUpdateRequest(BaseModel):
    # Reject unknown keys so that e.g. a stray "is_admin" cannot be smuggled in silently.
    model_config = ConfigDict(extra="forbid")
    # New display name, 1-20 characters.
    nickname: Optional[str] = Field(None, min_length=1, max_length=20)
    # New plain-text password, 6-20 ASCII letters/digits; hashed before storage.
    password: Optional[str] = Field(None, min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
# Login credentials schema, mirroring UserRegisterRequest's rules.
# NOTE(review): the /login endpoint in this module reads OAuth2PasswordRequestForm instead,
# so these constraints are not applied there; confirm whether anything else uses this model.
class UserLoginRequest(BaseModel):
    username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
    password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
# Request body for POST /grant_admin.
class AdminGrantRequest(BaseModel):
    # Secret compared against ADMIN_GRANT_KEY.
    key: str
    # User to promote; when omitted the endpoint promotes the caller.
    user_id: Optional[int] = None
# Request body for POST /bind_card.
class BindCardRequest(BaseModel):
    # User that receives the cards.
    user_id: int
    # One or more card ids; the singular name is kept for API compatibility even though it is a list.
    card_id: List[int] = Field(..., min_length=1)
# Public user profile returned by /register, /me and /grant_admin (never includes the password).
class UserResponse(BaseModel):
    id: int
    username: str
    nickname: str
    # Populated from the DB row after _normalize_user coerces it to bool.
    is_admin: bool
# One row of the GET /list result: the public profile plus audit timestamps.
class UserListItem(BaseModel):
    id: int
    username: str
    nickname: str
    is_admin: bool
    # Row timestamps as selected from the user table.
    created_at: datetime
    updated_at: datetime
# One page of users together with the total number of matching rows (for pagination).
class UserListWithTotal(BaseModel):
    # Count of all rows matching the filter, not just this page.
    total: int
    # Named `list` because that is the JSON field name; it only shadows the builtin inside this class body.
    list: List[UserListItem]
# Envelope for GET /list: {"data": {"total": ..., "list": [...]}}.
class UserListResponseWrapper(BaseModel):
    data: UserListWithTotal
# OAuth2-style token payload returned by /login.
class TokenResponse(BaseModel):
    # Signed token produced by _create_access_token.
    access_token: str
    token_type: str = "bearer"
  60. def _auth_exception() -> HTTPException:
  61. return HTTPException(
  62. status_code=status.HTTP_401_UNAUTHORIZED,
  63. detail="Invalid or expired token",
  64. headers={"WWW-Authenticate": "Bearer"},
  65. )
  66. def _hash_password(password: str, salt: Optional[str] = None) -> str:
  67. salt = salt or secrets.token_hex(16)
  68. password_hash = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
  69. return f"{salt}${password_hash}"
  70. def _verify_password(password: str, stored_password: str) -> bool:
  71. try:
  72. salt, old_hash = stored_password.split("$", 1)
  73. except ValueError:
  74. return False
  75. return hmac.compare_digest(_hash_password(password, salt), f"{salt}${old_hash}")
  76. def _b64encode(data: bytes) -> str:
  77. return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
  78. def _b64decode(data: str) -> bytes:
  79. padding = "=" * (-len(data) % 4)
  80. return base64.urlsafe_b64decode(f"{data}{padding}".encode("utf-8"))
  81. def _create_access_token(user_id: int) -> str:
  82. expires_at = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRE_MINUTES)
  83. payload = {"user_id": user_id, "exp": int(expires_at.timestamp())}
  84. body = _b64encode(json.dumps(payload, separators=(",", ":")).encode("utf-8"))
  85. signature = hmac.new(TOKEN_SECRET_KEY.encode("utf-8"), body.encode("utf-8"), hashlib.sha256).hexdigest()
  86. return f"{body}.{signature}"
  87. def _decode_access_token(token: str) -> int:
  88. try:
  89. body, signature = token.split(".", 1)
  90. expected_signature = hmac.new(
  91. TOKEN_SECRET_KEY.encode("utf-8"),
  92. body.encode("utf-8"),
  93. hashlib.sha256
  94. ).hexdigest()
  95. if not hmac.compare_digest(signature, expected_signature):
  96. raise ValueError("bad signature")
  97. payload = json.loads(_b64decode(body))
  98. if int(payload.get("exp", 0)) < int(datetime.now(timezone.utc).timestamp()):
  99. raise ValueError("expired")
  100. return int(payload["user_id"])
  101. except Exception:
  102. raise _auth_exception()
  103. def _normalize_user(user: dict) -> dict:
  104. user["is_admin"] = bool(user.get("is_admin"))
  105. return user
  106. def _get_user_by_id(db_conn: PooledMySQLConnection, user_id: int) -> Optional[dict]:
  107. with db_conn.cursor(dictionary=True) as cursor:
  108. cursor.execute(
  109. f"SELECT id, username, nickname, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s",
  110. (user_id,)
  111. )
  112. user = cursor.fetchone()
  113. return _normalize_user(user) if user else None
  114. def _get_user_by_username(db_conn: PooledMySQLConnection, username: str) -> Optional[dict]:
  115. with db_conn.cursor(dictionary=True) as cursor:
  116. cursor.execute(
  117. f"SELECT id, username, nickname, password, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE username = %s",
  118. (username,)
  119. )
  120. user = cursor.fetchone()
  121. return _normalize_user(user) if user else None
  122. def get_current_user(
  123. token: str = Depends(oauth2_scheme),
  124. db_conn: PooledMySQLConnection = db_dependency
  125. ) -> dict:
  126. user_id = _decode_access_token(token)
  127. user = _get_user_by_id(db_conn, user_id)
  128. if not user:
  129. raise _auth_exception()
  130. return user
  131. def require_admin_user(current_user: dict = Depends(get_current_user)) -> dict:
  132. if not current_user.get("is_admin"):
  133. raise HTTPException(status_code=403, detail="该请求需要管理员权限")
  134. return current_user
  135. def check_card_permission(db_conn: PooledMySQLConnection, current_user: dict, card_id: int):
  136. if current_user.get("is_admin"):
  137. return
  138. with db_conn.cursor() as cursor:
  139. cursor.execute(
  140. f"SELECT 1 FROM `{settings.DB_USER_CARD_TABLE_NAME}` WHERE user_id = %s AND card_id = %s LIMIT 1",
  141. (current_user["id"], card_id)
  142. )
  143. if cursor.fetchone():
  144. return
  145. raise HTTPException(status_code=403, detail="没有该卡片权限")
  146. @router.post("/register", response_model=UserResponse, summary="注册用户")
  147. def register_user(data: UserRegisterRequest, db_conn: PooledMySQLConnection = db_dependency):
  148. try:
  149. if _get_user_by_username(db_conn, data.username):
  150. raise HTTPException(status_code=400, detail="该用户名已经存在")
  151. with db_conn.cursor(dictionary=True) as cursor:
  152. cursor.execute(
  153. f"INSERT INTO `{settings.DB_USER_TABLE_NAME}` (username, nickname, password) VALUES (%s, %s, %s)",
  154. (data.username, data.nickname, _hash_password(data.password))
  155. )
  156. db_conn.commit()
  157. user_id = cursor.lastrowid
  158. logger.info(f"User created, id: {user_id}, username: {data.username}")
  159. user = _get_user_by_id(db_conn, user_id)
  160. return UserResponse.model_validate(user)
  161. except HTTPException:
  162. db_conn.rollback()
  163. raise
  164. except Exception as e:
  165. db_conn.rollback()
  166. logger.error(f"Register user failed: {e}")
  167. raise HTTPException(status_code=500, detail="注册用户失败")
  168. @router.post("/login", response_model=TokenResponse, summary="用户登录")
  169. def login_for_access_token(
  170. form_data: OAuth2PasswordRequestForm = Depends(),
  171. db_conn: PooledMySQLConnection = db_dependency
  172. ):
  173. username = form_data.username
  174. password = form_data.password
  175. user = _get_user_by_username(db_conn, username)
  176. if not user or not _verify_password(password, user["password"]):
  177. raise HTTPException(
  178. status_code=status.HTTP_401_UNAUTHORIZED,
  179. detail="用户名或密码错误",
  180. headers={"WWW-Authenticate": "Bearer"},
  181. )
  182. return TokenResponse(access_token=_create_access_token(user["id"]))
  183. @router.put("/me", response_model=UserResponse, summary="修改当前的昵称或密码")
  184. def update_current_user(
  185. data: UserUpdateRequest,
  186. current_user: dict = Depends(get_current_user),
  187. db_conn: PooledMySQLConnection = db_dependency
  188. ):
  189. update_fields = []
  190. params = []
  191. if data.nickname is not None:
  192. update_fields.append("nickname = %s")
  193. params.append(data.nickname)
  194. if data.password is not None:
  195. update_fields.append("password = %s")
  196. params.append(_hash_password(data.password))
  197. if not update_fields:
  198. raise HTTPException(status_code=400, detail="No fields to update")
  199. try:
  200. params.append(current_user["id"])
  201. with db_conn.cursor() as cursor:
  202. query = (
  203. f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET "
  204. f"{', '.join(update_fields)} WHERE id = %s"
  205. )
  206. cursor.execute(query, tuple(params))
  207. db_conn.commit()
  208. return UserResponse.model_validate(_get_user_by_id(db_conn, current_user["id"]))
  209. except Exception as e:
  210. db_conn.rollback()
  211. logger.error(f"Update user failed: {e}")
  212. raise HTTPException(status_code=500, detail="修改用户失败")
  213. @router.get("/list", response_model=UserListResponseWrapper, summary="查询用户列表")
  214. def list_users(
  215. nickname: Optional[str] = Query(None, description="按昵称模糊搜索"),
  216. sort_order: str = Query("desc", pattern="^(asc|desc)$", description="按创建时间排序: asc 或 desc"),
  217. skip: int = Query(0, ge=0),
  218. page_num: Optional[int] = Query(None, ge=1),
  219. limit: int = Query(100, ge=1, le=1000),
  220. current_user: dict = Depends(require_admin_user),
  221. db_conn: PooledMySQLConnection = db_dependency
  222. ):
  223. if page_num is not None:
  224. skip = (page_num - 1) * limit
  225. try:
  226. with db_conn.cursor(dictionary=True) as cursor:
  227. conditions = []
  228. params = []
  229. if nickname:
  230. conditions.append("nickname LIKE %s")
  231. params.append(f"%{nickname}%")
  232. where_clause = ""
  233. if conditions:
  234. where_clause = " WHERE " + " AND ".join(conditions)
  235. count_query = f"SELECT COUNT(*) as total FROM `{settings.DB_USER_TABLE_NAME}`" + where_clause
  236. cursor.execute(count_query, tuple(params))
  237. total_count = cursor.fetchone()["total"]
  238. order_sql = "ASC" if sort_order.lower() == "asc" else "DESC"
  239. data_query = (
  240. f"SELECT id, username, nickname, is_admin, created_at, updated_at "
  241. f"FROM `{settings.DB_USER_TABLE_NAME}`"
  242. f"{where_clause} ORDER BY created_at {order_sql}, id DESC LIMIT %s OFFSET %s"
  243. )
  244. data_params = params.copy()
  245. data_params.extend([limit, skip])
  246. cursor.execute(data_query, tuple(data_params))
  247. users = [_normalize_user(user) for user in cursor.fetchall()]
  248. return {
  249. "data": {
  250. "total": total_count,
  251. "list": users
  252. }
  253. }
  254. except Exception as e:
  255. logger.error(f"Query user list failed: {e}")
  256. raise HTTPException(status_code=500, detail="查询用户列表失败")
  257. @router.post("/grant_admin", response_model=UserResponse, summary="给用户授予管理员权限")
  258. def grant_current_user_admin(
  259. data: AdminGrantRequest,
  260. current_user: dict = Depends(get_current_user),
  261. db_conn: PooledMySQLConnection = db_dependency
  262. ):
  263. if not hmac.compare_digest(data.key, ADMIN_GRANT_KEY):
  264. raise HTTPException(status_code=403, detail="Invalid admin key")
  265. try:
  266. target_user_id = data.user_id or current_user["id"]
  267. if not _get_user_by_id(db_conn, target_user_id):
  268. raise HTTPException(status_code=404, detail="未发现该用户")
  269. with db_conn.cursor() as cursor:
  270. cursor.execute(
  271. f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET is_admin = TRUE WHERE id = %s",
  272. (target_user_id,)
  273. )
  274. db_conn.commit()
  275. return UserResponse.model_validate(_get_user_by_id(db_conn, target_user_id))
  276. except HTTPException:
  277. db_conn.rollback()
  278. raise
  279. except Exception as e:
  280. db_conn.rollback()
  281. logger.error(f"Grant admin failed: {e}")
  282. raise HTTPException(status_code=500, detail="Grant admin failed")
  283. @router.post("/bind_card", status_code=200, summary="给用户绑定卡片ID")
  284. def bind_card_to_user(
  285. data: BindCardRequest,
  286. current_user: dict = Depends(require_admin_user),
  287. db_conn: PooledMySQLConnection = db_dependency
  288. ):
  289. try:
  290. # Keep the request field as card_id, but accept multiple card ids.
  291. card_ids = list(dict.fromkeys(data.card_id))
  292. with db_conn.cursor(dictionary=True) as cursor:
  293. cursor.execute(f"SELECT id FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s", (data.user_id,))
  294. if not cursor.fetchone():
  295. raise HTTPException(status_code=404, detail="未发现该用户")
  296. format_strings = ",".join(["%s"] * len(card_ids))
  297. cursor.execute(
  298. f"SELECT id FROM `{settings.DB_CARD_TABLE_NAME}` WHERE id IN ({format_strings})",
  299. tuple(card_ids)
  300. )
  301. existing_card_ids = {row["id"] for row in cursor.fetchall()}
  302. missing_card_ids = sorted(set(card_ids) - existing_card_ids)
  303. if missing_card_ids:
  304. raise HTTPException(status_code=404, detail=f"卡片未发现: {missing_card_ids}")
  305. bind_params = [(data.user_id, card_id) for card_id in card_ids]
  306. cursor.executemany(
  307. f"INSERT IGNORE INTO `{settings.DB_USER_CARD_TABLE_NAME}` (user_id, card_id) VALUES (%s, %s)",
  308. bind_params
  309. )
  310. inserted_count = cursor.rowcount
  311. db_conn.commit()
  312. logger.info(f"Admin {current_user['id']} bound cards {card_ids} to user {data.user_id}")
  313. return {
  314. "message": "卡片绑定成功",
  315. "user_id": data.user_id,
  316. "card_id": card_ids,
  317. "inserted_count": inserted_count
  318. }
  319. except HTTPException:
  320. db_conn.rollback()
  321. raise
  322. except Exception as e:
  323. db_conn.rollback()
  324. logger.error(f"Bind card failed: {e}")
  325. raise HTTPException(status_code=500, detail="卡片绑定失败")