import base64
import hashlib
import hmac
import json
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from mysql.connector import errorcode
from mysql.connector.errors import IntegrityError
from mysql.connector.pooling import PooledMySQLConnection
from pydantic import BaseModel, ConfigDict, Field

from app.core.config import settings
from app.core.database_loader import get_db_connection
from app.core.logger import get_logger

logger = get_logger(__name__)
router = APIRouter()
db_dependency = Depends(get_db_connection)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/users/login")

# Shared secret that authorizes POST /grant_admin. It is read from settings
# (not hard-coded); an empty/unset value makes the endpoint reject every key.
ADMIN_GRANT_KEY = settings.ADMIN_GRANT_KEY
# Key used to HMAC-SHA256-sign and verify access tokens.
TOKEN_SECRET_KEY = settings.TOKEN_SECRET_KEY
# Access-token lifetime: 7 days, expressed in minutes.
TOKEN_EXPIRE_MINUTES = 60 * 24 * 7

# Escape character used by _escape_like(); the SQL must declare it via ESCAPE '!'.
_LIKE_ESCAPE_CHAR = "!"


# Registration payload: username/password are 6-20 ASCII letters or digits,
# nickname is 1-20 characters of free text.
class UserRegisterRequest(BaseModel):
    username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
    nickname: str = Field(..., min_length=1, max_length=20)
    password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")


# Partial update of the current user; unknown fields are rejected (extra="forbid").
class UserUpdateRequest(BaseModel):
    model_config = ConfigDict(extra="forbid")

    nickname: Optional[str] = Field(None, min_length=1, max_length=20)
    password: Optional[str] = Field(None, min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")


# JSON login payload. NOTE(review): not referenced in this module -- /login reads
# OAuth2PasswordRequestForm instead; confirm whether other modules still use it.
class UserLoginRequest(BaseModel):
    username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
    password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")


# Admin-grant payload: `key` must equal ADMIN_GRANT_KEY; `user_id` defaults to the caller.
class AdminGrantRequest(BaseModel):
    key: str
    user_id: Optional[int] = None


# Binds one or more cards to a user; the list keeps the singular field name `card_id`.
class BindCardRequest(BaseModel):
    user_id: int
    card_id: List[int] = Field(..., min_length=1)


# Public profile returned by register / update / grant endpoints.
class UserResponse(BaseModel):
    id: int
    username: str
    nickname: str
    is_admin: bool


# One row of the admin user list.
class UserListItem(BaseModel):
    id: int
    username: str
    nickname: str
    is_admin: bool
    created_at: datetime
    updated_at: datetime


# A page of users plus the total number of rows matching the filter.
class UserListWithTotal(BaseModel):
    total: int
    list: List[UserListItem]


# Envelope for the list response: {"data": {"total": ..., "list": [...]}}.
class UserListResponseWrapper(BaseModel):
    data: UserListWithTotal


# OAuth2-style token response returned by /login.
class TokenResponse(BaseModel):
    access_token: str
    token_type: str = "bearer"


def _auth_exception() -> HTTPException:
    """Build the 401 used for any missing, malformed, expired or unknown-user token."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired token",
        headers={"WWW-Authenticate": "Bearer"},
    )


def _hash_password(password: str, salt: Optional[str] = None) -> str:
    """Return ``"<salt>$<hex sha256 of 'salt:password'>"``.

    A random 32-hex-character salt is generated when *salt* is None; passing the
    salt of a stored value reproduces that value, which is how verification works.

    NOTE(review): a single SHA-256 round is cheap to brute-force. Consider a slow
    KDF (hashlib.pbkdf2_hmac / hashlib.scrypt) behind a versioned prefix, after
    confirming the password column can hold the longer string.
    """
    if salt is None:
        salt = secrets.token_hex(16)
    digest = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
    return f"{salt}${digest}"


def _verify_password(password: str, stored_password: Optional[str]) -> bool:
    """Return True if *password* matches a value produced by ``_hash_password``.

    Empty, NULL, non-string or ``$``-less stored values yield False rather than
    raising, so a bad row cannot turn a login attempt into a 500.
    """
    if not isinstance(stored_password, str) or "$" not in stored_password:
        return False
    salt = stored_password.split("$", 1)[0]
    candidate = _hash_password(password, salt)
    # Compare bytes: compare_digest raises TypeError on str args with non-ASCII chars.
    return hmac.compare_digest(candidate.encode("utf-8"), stored_password.encode("utf-8"))


def _b64encode(data: bytes) -> str:
    """URL-safe base64 encode *data* and strip trailing ``=`` padding."""
    return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")


def _b64decode(data: str) -> bytes:
    """Inverse of ``_b64encode``: restore the stripped padding, then decode."""
    padding = "=" * (-len(data) % 4)
    return base64.urlsafe_b64decode(f"{data}{padding}".encode("utf-8"))


def _create_access_token(user_id: int) -> str:
    """Issue a token of the form ``<base64url(json payload)>.<hex HMAC-SHA256(body)>``.

    The payload holds ``user_id`` and ``exp`` (Unix seconds, now + TOKEN_EXPIRE_MINUTES).
    """
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRE_MINUTES)
    payload = {"user_id": user_id, "exp": int(expires_at.timestamp())}
    body = _b64encode(json.dumps(payload, separators=(",", ":")).encode("utf-8"))
    signature = hmac.new(TOKEN_SECRET_KEY.encode("utf-8"), body.encode("utf-8"), hashlib.sha256).hexdigest()
    return f"{body}.{signature}"


def _decode_access_token(token: str) -> int:
    """Verify a token from ``_create_access_token`` and return its ``user_id``.

    Raises:
        HTTPException: 401 for any failure -- wrong shape, bad signature,
            undecodable payload, expired ``exp`` or missing ``user_id``.
    """
    try:
        body, signature = token.split(".", 1)
        expected_signature = hmac.new(
            TOKEN_SECRET_KEY.encode("utf-8"), body.encode("utf-8"), hashlib.sha256
        ).hexdigest()
        if not hmac.compare_digest(signature, expected_signature):
            raise ValueError("bad signature")
        payload = json.loads(_b64decode(body))
        if int(payload.get("exp", 0)) < int(datetime.now(timezone.utc).timestamp()):
            raise ValueError("expired")
        return int(payload["user_id"])
    except Exception:
        # Deliberately broad: every decoding/validation failure is an auth failure.
        raise _auth_exception() from None


def _normalize_user(user: dict) -> dict:
    """Coerce ``is_admin`` (presumably a 0/1 column value) to bool in place; return *user*."""
    user["is_admin"] = bool(user.get("is_admin"))
    return user


def _get_user_by_id(db_conn: PooledMySQLConnection, user_id: int) -> Optional[dict]:
    """Fetch id/username/nickname/is_admin for *user_id*, or None if absent."""
    with db_conn.cursor(dictionary=True) as cursor:
        cursor.execute(
            f"SELECT id, username, nickname, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s",
            (user_id,)
        )
        user = cursor.fetchone()
    return _normalize_user(user) if user else None


def _get_user_by_username(db_conn: PooledMySQLConnection, username: str) -> Optional[dict]:
    """Fetch a user row (including the stored password hash) by username, or None."""
    with db_conn.cursor(dictionary=True) as cursor:
        cursor.execute(
            f"SELECT id, username, nickname, password, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE username = %s",
            (username,)
        )
        user = cursor.fetchone()
    return _normalize_user(user) if user else None


def get_current_user(
    token: str = Depends(oauth2_scheme), db_conn: PooledMySQLConnection = db_dependency
) -> dict:
    """FastAPI dependency: resolve the bearer token to the current user row.

    Raises:
        HTTPException: 401 if the token is invalid or its user no longer exists.
    """
    user_id = _decode_access_token(token)
    user = _get_user_by_id(db_conn, user_id)
    if not user:
        raise _auth_exception()
    return user


def require_admin_user(current_user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency: like ``get_current_user`` but 403 unless the user is admin."""
    if not current_user.get("is_admin"):
        raise HTTPException(status_code=403, detail="该请求需要管理员权限")
    return current_user


def check_card_permission(db_conn: PooledMySQLConnection, current_user: dict, card_id: int):
    """Raise 403 unless *current_user* is admin or has a user-card binding for *card_id*."""
    if current_user.get("is_admin"):
        return
    with db_conn.cursor() as cursor:
        cursor.execute(
            f"SELECT 1 FROM `{settings.DB_USER_CARD_TABLE_NAME}` WHERE user_id = %s AND card_id = %s LIMIT 1",
            (current_user["id"], card_id)
        )
        if cursor.fetchone():
            return
    raise HTTPException(status_code=403, detail="没有该卡片权限")


def _escape_like(value: str, escape: str = _LIKE_ESCAPE_CHAR) -> str:
    """Escape LIKE wildcards (``%``, ``_``) and the escape char so *value* matches literally.

    The SQL must declare the same character with ``ESCAPE '<escape>'``.
    """
    return (
        value.replace(escape, escape * 2)
        .replace("%", f"{escape}%")
        .replace("_", f"{escape}_")
    )


# Create a user with a salted password hash and return its public profile.
# 400 if the username is already taken; 500 on any other database failure.
@router.post("/register", response_model=UserResponse, summary="注册用户")
def register_user(data: UserRegisterRequest, db_conn: PooledMySQLConnection = db_dependency):
    try:
        if _get_user_by_username(db_conn, data.username):
            raise HTTPException(status_code=400, detail="该用户名已经存在")
        with db_conn.cursor(dictionary=True) as cursor:
            cursor.execute(
                f"INSERT INTO `{settings.DB_USER_TABLE_NAME}` (username, nickname, password) VALUES (%s, %s, %s)",
                (data.username, data.nickname, _hash_password(data.password))
            )
            db_conn.commit()
            user_id = cursor.lastrowid
        logger.info(f"User created, id: {user_id}, username: {data.username}")
        user = _get_user_by_id(db_conn, user_id)
        return UserResponse.model_validate(user)
    except HTTPException:
        db_conn.rollback()
        raise
    except IntegrityError as e:
        db_conn.rollback()
        # A concurrent registration can pass the pre-check above; if username has a
        # unique index, the INSERT then fails with ER_DUP_ENTRY -- report it as a 400.
        if e.errno == errorcode.ER_DUP_ENTRY:
            raise HTTPException(status_code=400, detail="该用户名已经存在") from e
        logger.error(f"Register user failed: {e}")
        raise HTTPException(status_code=500, detail="注册用户失败") from e
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Register user failed: {e}")
        raise HTTPException(status_code=500, detail="注册用户失败") from e


# Exchange form-encoded username/password for a signed bearer token.
# 401 with the same message whether the user is unknown or the password is wrong.
@router.post("/login", response_model=TokenResponse, summary="用户登录")
def login_for_access_token(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db_conn: PooledMySQLConnection = db_dependency
):
    username = form_data.username
    password = form_data.password
    user = _get_user_by_username(db_conn, username)
    if not user or not _verify_password(password, user["password"]):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return TokenResponse(access_token=_create_access_token(user["id"]))


# Update the caller's nickname and/or password; only provided fields are written.
# 400 if nothing was provided; 500 on database failure.
@router.put("/me", response_model=UserResponse, summary="修改当前的昵称或密码")
def update_current_user(
    data: UserUpdateRequest,
    current_user: dict = Depends(get_current_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    update_fields = []
    params = []
    if data.nickname is not None:
        update_fields.append("nickname = %s")
        params.append(data.nickname)
    if data.password is not None:
        update_fields.append("password = %s")
        params.append(_hash_password(data.password))

    if not update_fields:
        raise HTTPException(status_code=400, detail="No fields to update")

    try:
        params.append(current_user["id"])
        with db_conn.cursor() as cursor:
            # Column names come only from the fixed fragments above, never from input.
            query = (
                f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET "
                f"{', '.join(update_fields)} WHERE id = %s"
            )
            cursor.execute(query, tuple(params))
            db_conn.commit()
        return UserResponse.model_validate(_get_user_by_id(db_conn, current_user["id"]))
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Update user failed: {e}")
        raise HTTPException(status_code=500, detail="修改用户失败") from e


# Admin-only paginated user list, optionally filtered by a literal nickname substring.
# `page_num`, when given, overrides `skip` as (page_num - 1) * limit.
@router.get("/list", response_model=UserListResponseWrapper, summary="查询用户列表")
def list_users(
    nickname: Optional[str] = Query(None, description="按昵称模糊搜索"),
    sort_order: str = Query("desc", pattern="^(asc|desc)$", description="按创建时间排序: asc 或 desc"),
    skip: int = Query(0, ge=0),
    page_num: Optional[int] = Query(None, ge=1),
    limit: int = Query(100, ge=1, le=1000),
    current_user: dict = Depends(require_admin_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    if page_num is not None:
        skip = (page_num - 1) * limit

    try:
        with db_conn.cursor(dictionary=True) as cursor:
            conditions = []
            params = []

            if nickname:
                # Escape user-supplied % and _ so the search is a literal substring match.
                conditions.append(f"nickname LIKE %s ESCAPE '{_LIKE_ESCAPE_CHAR}'")
                params.append(f"%{_escape_like(nickname)}%")

            where_clause = ""
            if conditions:
                where_clause = " WHERE " + " AND ".join(conditions)

            count_query = f"SELECT COUNT(*) as total FROM `{settings.DB_USER_TABLE_NAME}`" + where_clause
            cursor.execute(count_query, tuple(params))
            total_count = cursor.fetchone()["total"]

            order_sql = "ASC" if sort_order.lower() == "asc" else "DESC"
            # Tie-break on id in the same direction so equal created_at rows follow the requested order.
            data_query = (
                f"SELECT id, username, nickname, is_admin, created_at, updated_at "
                f"FROM `{settings.DB_USER_TABLE_NAME}`"
                f"{where_clause} ORDER BY created_at {order_sql}, id {order_sql} LIMIT %s OFFSET %s"
            )
            data_params = params.copy()
            data_params.extend([limit, skip])

            cursor.execute(data_query, tuple(data_params))
            users = [_normalize_user(user) for user in cursor.fetchall()]

        return {
            "data": {
                "total": total_count,
                "list": users
            }
        }
    except Exception as e:
        logger.error(f"Query user list failed: {e}")
        raise HTTPException(status_code=500, detail="查询用户列表失败") from e


# Promote a user (default: the caller) to admin when the shared grant key matches.
# 403 for a wrong key or when no key is configured; 404 if the target user is missing.
@router.post("/grant_admin", response_model=UserResponse, summary="给用户授予管理员权限")
def grant_current_user_admin(
    data: AdminGrantRequest,
    current_user: dict = Depends(get_current_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    # An unset/empty configured key must never match (otherwise an empty key grants admin).
    # Compare UTF-8 bytes: compare_digest raises TypeError on non-ASCII str arguments.
    if not ADMIN_GRANT_KEY or not hmac.compare_digest(
        data.key.encode("utf-8"), ADMIN_GRANT_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid admin key")

    try:
        # `is not None`, not truthiness: an explicit user_id must never fall back to the caller.
        target_user_id = data.user_id if data.user_id is not None else current_user["id"]
        if not _get_user_by_id(db_conn, target_user_id):
            raise HTTPException(status_code=404, detail="未发现该用户")

        with db_conn.cursor() as cursor:
            cursor.execute(
                f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET is_admin = TRUE WHERE id = %s",
                (target_user_id,)
            )
            db_conn.commit()

        return UserResponse.model_validate(_get_user_by_id(db_conn, target_user_id))
    except HTTPException:
        db_conn.rollback()
        raise
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Grant admin failed: {e}")
        raise HTTPException(status_code=500, detail="Grant admin failed") from e


# Admin-only: bind a de-duplicated list of existing cards to an existing user.
# Already-bound pairs are skipped via INSERT IGNORE; 404 lists any unknown card ids.
@router.post("/bind_card", status_code=200, summary="给用户绑定卡片ID")
def bind_card_to_user(
    data: BindCardRequest,
    current_user: dict = Depends(require_admin_user),
    db_conn: PooledMySQLConnection = db_dependency
):
    try:
        # Keep the request field as card_id, but accept multiple card ids.
        # dict.fromkeys de-duplicates while preserving request order.
        card_ids = list(dict.fromkeys(data.card_id))
        with db_conn.cursor(dictionary=True) as cursor:
            cursor.execute(f"SELECT id FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s", (data.user_id,))
            if not cursor.fetchone():
                raise HTTPException(status_code=404, detail="未发现该用户")

            format_strings = ",".join(["%s"] * len(card_ids))
            cursor.execute(
                f"SELECT id FROM `{settings.DB_CARD_TABLE_NAME}` WHERE id IN ({format_strings})",
                tuple(card_ids)
            )
            existing_card_ids = {row["id"] for row in cursor.fetchall()}
            missing_card_ids = sorted(set(card_ids) - existing_card_ids)
            if missing_card_ids:
                raise HTTPException(status_code=404, detail=f"卡片未发现: {missing_card_ids}")

            bind_params = [(data.user_id, card_id) for card_id in card_ids]
            cursor.executemany(
                f"INSERT IGNORE INTO `{settings.DB_USER_CARD_TABLE_NAME}` (user_id, card_id) VALUES (%s, %s)",
                bind_params
            )
            inserted_count = cursor.rowcount
            db_conn.commit()

        logger.info(f"Admin {current_user['id']} bound cards {card_ids} to user {data.user_id}")
        return {
            "message": "卡片绑定成功",
            "user_id": data.user_id,
            "card_id": card_ids,
            "inserted_count": inserted_count
        }
    except HTTPException:
        db_conn.rollback()
        raise
    except Exception as e:
        db_conn.rollback()
        logger.error(f"Bind card failed: {e}")
        raise HTTPException(status_code=500, detail="卡片绑定失败") from e