| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417 |
- import base64
- import hashlib
- import hmac
- import json
- import secrets
- from datetime import datetime, timedelta, timezone
- from typing import Optional, List
- from fastapi import APIRouter, Depends, HTTPException, Query, status
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
- from mysql.connector.pooling import PooledMySQLConnection
- from pydantic import BaseModel, ConfigDict, Field
- from app.core.config import settings
- from app.core.database_loader import get_db_connection
- from app.core.logger import get_logger
logger = get_logger(__name__)
router = APIRouter()

# Shared FastAPI dependency that injects a pooled MySQL connection into handlers.
db_dependency = Depends(get_db_connection)

# Reads the bearer token from the Authorization header. tokenUrl presumably points
# at the /login handler below, assuming this router is mounted under "/users".
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/users/login")

# Shared secret checked by the /grant_admin endpoint. It is loaded from settings,
# not hard-coded. Anyone who knows it can promote themselves or, via `user_id`,
# any other user to admin, so treat it like a credential.
ADMIN_GRANT_KEY = settings.ADMIN_GRANT_KEY

# HMAC-SHA256 key used to sign and verify the access tokens issued by this module.
TOKEN_SECRET_KEY = settings.TOKEN_SECRET_KEY

# Access-token lifetime in minutes (7 days).
TOKEN_EXPIRE_MINUTES = 60 * 24 * 7
class UserRegisterRequest(BaseModel):
    # Request body for POST /register.

    # Login name: 6-20 ASCII letters or digits.
    username: str = Field(default=..., pattern=r"^[A-Za-z0-9]+$", min_length=6, max_length=20)
    # Display name: any 1-20 characters.
    nickname: str = Field(default=..., min_length=1, max_length=20)
    # Plain-text password: 6-20 ASCII letters or digits; hashed before it is stored.
    password: str = Field(default=..., pattern=r"^[A-Za-z0-9]+$", min_length=6, max_length=20)
class UserUpdateRequest(BaseModel):
    # Request body for PUT /me. Every field is optional; unknown keys are rejected.
    model_config = ConfigDict(extra="forbid")

    # New display name (1-20 characters); omit or send null to keep the current one.
    nickname: Optional[str] = Field(default=None, min_length=1, max_length=20)
    # New password (6-20 ASCII letters/digits); omit or send null to keep the current one.
    password: Optional[str] = Field(default=None, pattern=r"^[A-Za-z0-9]+$", min_length=6, max_length=20)
class UserLoginRequest(BaseModel):
    # JSON login credentials with the same constraints as registration.
    # NOTE(review): the /login handler in this module reads OAuth2PasswordRequestForm
    # instead, so this model may be unused here; confirm before removing.

    # 6-20 ASCII letters or digits.
    username: str = Field(default=..., pattern=r"^[A-Za-z0-9]+$", min_length=6, max_length=20)
    # 6-20 ASCII letters or digits.
    password: str = Field(default=..., pattern=r"^[A-Za-z0-9]+$", min_length=6, max_length=20)
class AdminGrantRequest(BaseModel):
    # Request body for POST /grant_admin.

    # Must match the configured ADMIN_GRANT_KEY.
    key: str
    # User to promote; when omitted, the authenticated caller is promoted.
    user_id: Optional[int] = Field(default=None)
class BindCardRequest(BaseModel):
    # Request body for POST /bind_card.

    # User who will receive the card bindings.
    user_id: int
    # One or more card ids. The singular field name is kept for API compatibility.
    card_id: List[int] = Field(default=..., min_length=1)
class UserResponse(BaseModel):
    # Public user profile returned by /register, /me and /grant_admin.
    # It has no password field, so the stored hash is never serialized.
    id: int
    username: str
    nickname: str
    # Admin flag; rows are passed through _normalize_user, which coerces it to bool.
    is_admin: bool
class UserListItem(BaseModel):
    # One row of the admin user listing (/list): the public profile plus timestamps.
    id: int
    username: str
    nickname: str
    is_admin: bool
    # Read from the user table's created_at / updated_at columns.
    created_at: datetime
    updated_at: datetime
class UserListWithTotal(BaseModel):
    # One page of users plus the total number of rows matching the filter
    # (the total ignores skip/limit).
    total: int
    # The field name "list" is the JSON key clients read; keep it as-is.
    list: List[UserListItem]
class UserListResponseWrapper(BaseModel):
    # Response envelope for /list: {"data": {"total": ..., "list": [...]}}.
    data: UserListWithTotal
class TokenResponse(BaseModel):
    # Response body of /login.
    # Signed token produced by _create_access_token.
    access_token: str
    # Always "bearer", matching the OAuth2PasswordBearer scheme used to read tokens back.
    token_type: str = "bearer"
def _auth_exception() -> HTTPException:
    """Build the 401 error used for bad, expired or orphaned access tokens.

    The ``WWW-Authenticate: Bearer`` header tells the client which scheme to use.
    The caller is responsible for raising the returned exception.
    """
    bearer_challenge = {"WWW-Authenticate": "Bearer"}
    return HTTPException(
        headers=bearer_challenge,
        detail="Invalid or expired token",
        status_code=status.HTTP_401_UNAUTHORIZED,
    )
def _hash_password(password: str, salt: Optional[str] = None) -> str:
    """Hash *password* into the storage format ``"<salt>$<hex sha256>"``.

    The digest is SHA-256 over ``"<salt>:<password>"`` encoded as UTF-8. When
    *salt* is omitted or empty, a random 32-hex-character salt is generated.

    NOTE(review): a single SHA-256 round is cheap to brute-force. A slow KDF
    (``hashlib.scrypt`` / ``hashlib.pbkdf2_hmac``) would be stronger, but
    switching needs a stored-format migration and possibly a wider column.
    """
    salt = salt or secrets.token_hex(16)
    password_hash = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
    return f"{salt}${password_hash}"


def _verify_password(password: str, stored_password: Optional[str]) -> bool:
    """Return True iff *password* matches a value produced by ``_hash_password``.

    A missing (``None``/non-string) or malformed (no ``$``) stored value never
    matches; it no longer raises. The comparison is constant-time and runs on
    UTF-8 bytes, because ``hmac.compare_digest`` raises ``TypeError`` for str
    arguments containing non-ASCII characters.
    """
    if not isinstance(stored_password, str):
        return False
    salt, separator, _ = stored_password.partition("$")
    if not separator:
        return False
    # Re-hash with the stored salt; the result must equal the stored string exactly.
    candidate = _hash_password(password, salt)
    return hmac.compare_digest(candidate.encode("utf-8"), stored_password.encode("utf-8"))
def _b64encode(data: bytes) -> str:
    """Return *data* as URL-safe base64 text with trailing ``=`` padding stripped."""
    raw = base64.urlsafe_b64encode(data)
    # The alphabet is pure ASCII, so stripping before decoding is equivalent.
    return raw.rstrip(b"=").decode("ascii")
def _b64decode(data: str) -> bytes:
    """Decode unpadded URL-safe base64 text, as produced by ``_b64encode``.

    The ``=`` padding is restored to a multiple of 4 characters before decoding.
    """
    pad_len = (4 - len(data) % 4) % 4
    restored = data + "=" * pad_len
    return base64.urlsafe_b64decode(restored.encode("utf-8"))
def _create_access_token(user_id: int) -> str:
    """Issue a signed access token for *user_id*.

    Layout: ``<body>.<signature>``.

    - *body* is the unpadded URL-safe base64 of compact JSON
      ``{"user_id": ..., "exp": ...}``. ``exp`` is a UTC Unix timestamp
      TOKEN_EXPIRE_MINUTES from now.
    - *signature* is the hex HMAC-SHA256 of *body* keyed with TOKEN_SECRET_KEY.
    """
    lifetime = timedelta(minutes=TOKEN_EXPIRE_MINUTES)
    expiry_ts = int((datetime.now(timezone.utc) + lifetime).timestamp())
    claims_json = json.dumps({"user_id": user_id, "exp": expiry_ts}, separators=(",", ":"))
    body = _b64encode(claims_json.encode("utf-8"))
    mac = hmac.new(TOKEN_SECRET_KEY.encode("utf-8"), body.encode("utf-8"), hashlib.sha256)
    return ".".join((body, mac.hexdigest()))
def _decode_access_token(token: str) -> int:
    """Verify *token* and return the user id it was issued for.

    Checks, in order:

    1. The ``<body>.<signature>`` shape.
    2. The HMAC-SHA256 signature, compared in constant time.
    3. The ``exp`` claim; a missing ``exp`` counts as expired.

    Every failure, including undecodable base64/JSON or a missing ``user_id``,
    is turned into the same 401 from ``_auth_exception``.
    """
    try:
        body, signature = token.split(".", 1)
        signing_key = TOKEN_SECRET_KEY.encode("utf-8")
        digest = hmac.new(signing_key, body.encode("utf-8"), hashlib.sha256).hexdigest()
        if not hmac.compare_digest(signature, digest):
            raise ValueError("bad signature")
        claims = json.loads(_b64decode(body))
        expires_at = int(claims.get("exp", 0))
        if expires_at < int(datetime.now(timezone.utc).timestamp()):
            raise ValueError("expired")
        return int(claims["user_id"])
    except Exception:
        # Deliberately broad: any problem with the token is an auth failure, never a 500.
        raise _auth_exception()
def _normalize_user(user: dict) -> dict:
    """Coerce the row's ``is_admin`` value to a real ``bool``, in place.

    A missing or falsy ``is_admin`` becomes ``False``. The same dict object is
    returned so calls can be chained. (The DB driver presumably returns the flag
    as an int; verify against the schema.)
    """
    user.update(is_admin=bool(user.get("is_admin")))
    return user
def _get_user_by_id(db_conn: PooledMySQLConnection, user_id: int) -> Optional[dict]:
    """Fetch the public profile columns for *user_id* (the password is not selected).

    Returns the row as a dict with ``is_admin`` coerced to bool, or None when no
    user has that id.
    """
    sql = f"SELECT id, username, nickname, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s"
    with db_conn.cursor(dictionary=True) as cursor:
        cursor.execute(sql, (user_id,))
        row = cursor.fetchone()
    if not row:
        return None
    return _normalize_user(row)
def _get_user_by_username(db_conn: PooledMySQLConnection, username: str) -> Optional[dict]:
    """Look up a user by login name, *including* the stored password hash.

    Meant for login and uniqueness checks. Do not return this row to clients
    unfiltered. Returns the row dict with ``is_admin`` coerced to bool, or None.
    """
    sql = (
        "SELECT id, username, nickname, password, is_admin "
        f"FROM `{settings.DB_USER_TABLE_NAME}` WHERE username = %s"
    )
    with db_conn.cursor(dictionary=True) as cursor:
        cursor.execute(sql, (username,))
        row = cursor.fetchone()
    if not row:
        return None
    return _normalize_user(row)
def get_current_user(
    token: str = Depends(oauth2_scheme),
    db_conn: PooledMySQLConnection = db_dependency
) -> dict:
    """FastAPI dependency: resolve the bearer token to the authenticated user row.

    Raises the 401 from ``_auth_exception`` when the token is invalid or expired,
    or when its user no longer exists.
    """
    user = _get_user_by_id(db_conn, _decode_access_token(token))
    if user:
        return user
    raise _auth_exception()
def require_admin_user(current_user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency: pass the authenticated user through only if they are an admin.

    Raises:
        HTTPException: 403 ("this request requires administrator privileges")
        for non-admin users.
    """
    is_admin = current_user.get("is_admin")
    if is_admin:
        return current_user
    raise HTTPException(status_code=403, detail="该请求需要管理员权限")
def check_card_permission(db_conn: PooledMySQLConnection, current_user: dict, card_id: int):
    """Ensure *current_user* may access *card_id*; returns None when allowed.

    Admins always pass. Other users pass only when a (user_id, card_id) row
    exists in the user-card binding table.

    Raises:
        HTTPException: 403 ("no permission for this card") when no binding exists.
    """
    if not current_user.get("is_admin"):
        lookup = (
            f"SELECT 1 FROM `{settings.DB_USER_CARD_TABLE_NAME}` "
            "WHERE user_id = %s AND card_id = %s LIMIT 1"
        )
        with db_conn.cursor() as cursor:
            cursor.execute(lookup, (current_user["id"], card_id))
            has_binding = bool(cursor.fetchone())
        if not has_binding:
            raise HTTPException(status_code=403, detail="没有该卡片权限")
@router.post("/register", response_model=UserResponse, summary="注册用户")
def register_user(data: UserRegisterRequest, db_conn: PooledMySQLConnection = db_dependency):
    """Create an account and return its public profile.

    The password is stored via ``_hash_password``. The up-front username lookup
    gives a friendly 400 in the common case. A concurrent registration that slips
    between that lookup and the INSERT surfaces as a MySQL duplicate-key error.
    That error is now mapped to the same 400 instead of a generic 500.
    NOTE(review): this only triggers if the table has a UNIQUE index on username.

    Raises:
        HTTPException: 400 when the username already exists, 500 on any other
        database failure (the transaction is rolled back).
    """
    # Local import keeps this fix self-contained; mysql.connector is already a
    # dependency of this module (PooledMySQLConnection comes from it).
    from mysql.connector import errorcode
    from mysql.connector.errors import IntegrityError

    try:
        if _get_user_by_username(db_conn, data.username):
            raise HTTPException(status_code=400, detail="该用户名已经存在")
        with db_conn.cursor(dictionary=True) as cursor:
            cursor.execute(
                f"INSERT INTO `{settings.DB_USER_TABLE_NAME}` (username, nickname, password) VALUES (%s, %s, %s)",
                (data.username, data.nickname, _hash_password(data.password))
            )
            db_conn.commit()
            user_id = cursor.lastrowid
        logger.info(f"User created, id: {user_id}, username: {data.username}")
        user = _get_user_by_id(db_conn, user_id)
        return UserResponse.model_validate(user)
    except HTTPException:
        db_conn.rollback()
        raise
    except IntegrityError as e:
        db_conn.rollback()
        # ER_DUP_ENTRY (1062): lost the race against another registration.
        # NOTE(review): if other columns are also UNIQUE, this message may be imprecise.
        if e.errno == errorcode.ER_DUP_ENTRY:
            raise HTTPException(status_code=400, detail="该用户名已经存在") from None
        logger.error(f"Register user failed: {e}")
        raise HTTPException(status_code=500, detail="注册用户失败") from e
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Register user failed: {e}")
        raise HTTPException(status_code=500, detail="注册用户失败") from e
@router.post("/login", response_model=TokenResponse, summary="用户登录")
def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db_conn: PooledMySQLConnection = db_dependency
):
    """Exchange OAuth2 password-form credentials for a bearer access token.

    An unknown username and a wrong password produce the identical 401 response,
    so the body does not reveal which one was wrong.
    """
    account = _get_user_by_username(db_conn, form_data.username)
    credentials_ok = bool(account) and _verify_password(form_data.password, account["password"])
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )
    token = _create_access_token(account["id"])
    return TokenResponse(access_token=token)
@router.put("/me", response_model=UserResponse, summary="修改当前的昵称或密码")
def update_current_user(
    data: UserUpdateRequest,
    current_user: dict = Depends(get_current_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    """Change the caller's nickname and/or password and return the refreshed profile.

    Only fields sent with a non-null value are updated. A new password is hashed
    with ``_hash_password`` before storage.

    Raises:
        HTTPException: 400 when no field is supplied, 500 on a database error
        (the transaction is rolled back).
    """
    # Column -> new value. Column names are fixed literals, never user input,
    # so interpolating them into the SET clause below is safe.
    changes = {}
    if data.nickname is not None:
        changes["nickname"] = data.nickname
    if data.password is not None:
        changes["password"] = _hash_password(data.password)
    if not changes:
        raise HTTPException(status_code=400, detail="No fields to update")
    try:
        user_id = current_user["id"]
        set_clause = ", ".join(f"{column} = %s" for column in changes)
        sql = f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET {set_clause} WHERE id = %s"
        with db_conn.cursor() as cursor:
            cursor.execute(sql, (*changes.values(), user_id))
            db_conn.commit()
        return UserResponse.model_validate(_get_user_by_id(db_conn, user_id))
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Update user failed: {e}")
        raise HTTPException(status_code=500, detail="修改用户失败")
def _escape_like(value: str, escape_char: str = "!") -> str:
    """Escape *value* so it matches literally inside a SQL ``LIKE`` pattern.

    *escape_char* is doubled first, then ``%`` and ``_`` are prefixed with it.
    The query must declare the same character with ``ESCAPE '<char>'``.
    """
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", f"{escape_char}%")
        .replace("_", f"{escape_char}_")
    )


@router.get("/list", response_model=UserListResponseWrapper, summary="查询用户列表")
def list_users(
    nickname: Optional[str] = Query(None, description="按昵称模糊搜索"),
    sort_order: str = Query("desc", pattern="^(asc|desc)$", description="按创建时间排序: asc 或 desc"),
    skip: int = Query(0, ge=0),
    page_num: Optional[int] = Query(None, ge=1),
    limit: int = Query(100, ge=1, le=1000),
    current_user: dict = Depends(require_admin_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    """Admin-only paginated user listing with an optional nickname substring filter.

    ``current_user`` is unused in the body; the dependency exists only to enforce
    admin access. ``page_num`` (1-based) overrides ``skip`` as
    ``(page_num - 1) * limit`` when given. Rows are ordered by ``created_at`` in
    ``sort_order`` direction, with ties broken by ``id`` descending. The nickname
    filter is a literal substring match: ``%`` and ``_`` in the search text are
    escaped instead of acting as wildcards.

    Returns:
        ``{"data": {"total": <matching rows>, "list": [<user rows>]}}``.

    Raises:
        HTTPException: 500 if either query fails.
    """
    if page_num is not None:
        skip = (page_num - 1) * limit
    try:
        with db_conn.cursor(dictionary=True) as cursor:
            conditions = []
            params = []
            if nickname:
                conditions.append("nickname LIKE %s ESCAPE '!'")
                params.append(f"%{_escape_like(nickname, '!')}%")
            where_clause = ""
            if conditions:
                where_clause = " WHERE " + " AND ".join(conditions)

            count_query = f"SELECT COUNT(*) as total FROM `{settings.DB_USER_TABLE_NAME}`" + where_clause
            cursor.execute(count_query, tuple(params))
            total_count = cursor.fetchone()["total"]

            # Whitelist mapping: only these two literals ever reach the SQL text.
            order_sql = "ASC" if sort_order.lower() == "asc" else "DESC"
            data_query = (
                f"SELECT id, username, nickname, is_admin, created_at, updated_at "
                f"FROM `{settings.DB_USER_TABLE_NAME}`"
                f"{where_clause} ORDER BY created_at {order_sql}, id DESC LIMIT %s OFFSET %s"
            )
            cursor.execute(data_query, (*params, limit, skip))
            users = [_normalize_user(user) for user in cursor.fetchall()]
        return {
            "data": {
                "total": total_count,
                "list": users
            }
        }
    except Exception as e:
        logger.error(f"Query user list failed: {e}")
        raise HTTPException(status_code=500, detail="查询用户列表失败") from e
@router.post("/grant_admin", response_model=UserResponse, summary="给用户授予管理员权限")
def grant_current_user_admin(
    data: AdminGrantRequest,
    current_user: dict = Depends(get_current_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    """Set ``is_admin`` on a user, gated by the shared ADMIN_GRANT_KEY.

    Any authenticated caller presenting the correct ``key`` may promote
    ``data.user_id``, or themselves when ``user_id`` is omitted. The updated
    profile is returned.

    Raises:
        HTTPException: 403 for a wrong key or when no key is configured,
        404 if the target user does not exist, 500 on a database error.
    """
    # Fail closed when no key is configured (None/empty); otherwise an empty
    # request key would match an empty setting. Compare UTF-8 bytes because
    # hmac.compare_digest raises TypeError on non-ASCII str input, which a
    # client could use to force a 500 instead of a 403.
    key_ok = bool(ADMIN_GRANT_KEY) and hmac.compare_digest(
        data.key.encode("utf-8"), ADMIN_GRANT_KEY.encode("utf-8")
    )
    if not key_ok:
        raise HTTPException(status_code=403, detail="Invalid admin key")
    try:
        # `is None` rather than `or`: an explicit user_id of 0 must be looked up
        # (and 404) instead of silently promoting the caller.
        target_user_id = current_user["id"] if data.user_id is None else data.user_id
        if not _get_user_by_id(db_conn, target_user_id):
            raise HTTPException(status_code=404, detail="未发现该用户")
        with db_conn.cursor() as cursor:
            cursor.execute(
                f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET is_admin = TRUE WHERE id = %s",
                (target_user_id,)
            )
            db_conn.commit()
        return UserResponse.model_validate(_get_user_by_id(db_conn, target_user_id))
    except HTTPException:
        db_conn.rollback()
        raise
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Grant admin failed: {e}")
        raise HTTPException(status_code=500, detail="Grant admin failed") from e
@router.post("/bind_card", status_code=200, summary="给用户绑定卡片ID")
def bind_card_to_user(
    data: BindCardRequest,
    current_user: dict = Depends(require_admin_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    """Admin-only: bind one or more cards to a user.

    Duplicate ids in the request are collapsed, keeping first-seen order. The
    user and every requested card must exist; otherwise a 404 is raised and
    nothing is written. Already-existing bindings are skipped by
    ``INSERT IGNORE``, and ``inserted_count`` is the cursor's ``rowcount`` for
    the batch insert.

    Raises:
        HTTPException: 404 for an unknown user or card ids, 500 on a database
        error (the transaction is rolled back).
    """
    try:
        # The request field keeps its historical name `card_id` but carries a list.
        card_ids = list(dict.fromkeys(data.card_id))
        with db_conn.cursor(dictionary=True) as cursor:
            cursor.execute(f"SELECT id FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s", (data.user_id,))
            user_row = cursor.fetchone()
            if not user_row:
                raise HTTPException(status_code=404, detail="未发现该用户")

            placeholders = ",".join("%s" for _ in card_ids)
            cursor.execute(
                f"SELECT id FROM `{settings.DB_CARD_TABLE_NAME}` WHERE id IN ({placeholders})",
                tuple(card_ids)
            )
            found_ids = set()
            for row in cursor.fetchall():
                found_ids.add(row["id"])
            missing = sorted(cid for cid in card_ids if cid not in found_ids)
            if missing:
                raise HTTPException(status_code=404, detail=f"卡片未发现: {missing}")

            cursor.executemany(
                f"INSERT IGNORE INTO `{settings.DB_USER_CARD_TABLE_NAME}` (user_id, card_id) VALUES (%s, %s)",
                [(data.user_id, cid) for cid in card_ids]
            )
            affected_rows = cursor.rowcount
            db_conn.commit()
        logger.info(f"Admin {current_user['id']} bound cards {card_ids} to user {data.user_id}")
        return {
            "message": "卡片绑定成功",
            "user_id": data.user_id,
            "card_id": card_ids,
            "inserted_count": affected_rows
        }
    except HTTPException:
        db_conn.rollback()
        raise
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Bind card failed: {e}")
        raise HTTPException(status_code=500, detail="卡片绑定失败")
|