|
|
@@ -0,0 +1,417 @@
|
|
|
+import base64
|
|
|
+import hashlib
|
|
|
+import hmac
|
|
|
+import json
|
|
|
+import secrets
|
|
|
+from datetime import datetime, timedelta, timezone
|
|
|
+from typing import Optional, List
|
|
|
+
|
|
|
+from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
|
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
+from mysql.connector.pooling import PooledMySQLConnection
|
|
|
+from pydantic import BaseModel, ConfigDict, Field
|
|
|
+
|
|
|
+from app.core.config import settings
|
|
|
+from app.core.database_loader import get_db_connection
|
|
|
+from app.core.logger import get_logger
|
|
|
+
|
|
|
+logger = get_logger(__name__)
|
|
|
+router = APIRouter()
|
|
|
+db_dependency = Depends(get_db_connection)
|
|
|
+
|
|
|
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/users/login")
|
|
|
+
|
|
|
+# Hard-coded by design: used only for promoting the current user to admin.
|
|
|
+ADMIN_GRANT_KEY = settings.ADMIN_GRANT_KEY
|
|
|
+TOKEN_SECRET_KEY = settings.TOKEN_SECRET_KEY
|
|
|
+TOKEN_EXPIRE_MINUTES = 60 * 24 * 7
|
|
|
+
|
|
|
+
|
|
|
+class UserRegisterRequest(BaseModel):
|
|
|
+ username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
+ nickname: str = Field(..., min_length=1, max_length=20)
|
|
|
+ password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
+
|
|
|
+
|
|
|
+class UserUpdateRequest(BaseModel):
|
|
|
+ model_config = ConfigDict(extra="forbid")
|
|
|
+
|
|
|
+ nickname: Optional[str] = Field(None, min_length=1, max_length=20)
|
|
|
+ password: Optional[str] = Field(None, min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
+
|
|
|
+
|
|
|
+class UserLoginRequest(BaseModel):
|
|
|
+ username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
+ password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
+
|
|
|
+
|
|
|
+class AdminGrantRequest(BaseModel):
|
|
|
+ key: str
|
|
|
+ user_id: Optional[int] = None
|
|
|
+
|
|
|
+
|
|
|
+class BindCardRequest(BaseModel):
|
|
|
+ user_id: int
|
|
|
+ card_id: List[int] = Field(..., min_length=1)
|
|
|
+
|
|
|
+
|
|
|
+class UserResponse(BaseModel):
|
|
|
+ id: int
|
|
|
+ username: str
|
|
|
+ nickname: str
|
|
|
+ is_admin: bool
|
|
|
+
|
|
|
+
|
|
|
+class UserListItem(BaseModel):
|
|
|
+ id: int
|
|
|
+ username: str
|
|
|
+ nickname: str
|
|
|
+ is_admin: bool
|
|
|
+ created_at: datetime
|
|
|
+ updated_at: datetime
|
|
|
+
|
|
|
+
|
|
|
+class UserListWithTotal(BaseModel):
|
|
|
+ total: int
|
|
|
+ list: List[UserListItem]
|
|
|
+
|
|
|
+
|
|
|
+class UserListResponseWrapper(BaseModel):
|
|
|
+ data: UserListWithTotal
|
|
|
+
|
|
|
+
|
|
|
+class TokenResponse(BaseModel):
|
|
|
+ access_token: str
|
|
|
+ token_type: str = "bearer"
|
|
|
+
|
|
|
+
|
|
|
+def _auth_exception() -> HTTPException:
|
|
|
+ return HTTPException(
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
+ detail="Invalid or expired token",
|
|
|
+ headers={"WWW-Authenticate": "Bearer"},
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def _hash_password(password: str, salt: Optional[str] = None) -> str:
|
|
|
+ salt = salt or secrets.token_hex(16)
|
|
|
+ password_hash = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
|
|
|
+ return f"{salt}${password_hash}"
|
|
|
+
|
|
|
+
|
|
|
+def _verify_password(password: str, stored_password: str) -> bool:
|
|
|
+ try:
|
|
|
+ salt, old_hash = stored_password.split("$", 1)
|
|
|
+ except ValueError:
|
|
|
+ return False
|
|
|
+ return hmac.compare_digest(_hash_password(password, salt), f"{salt}${old_hash}")
|
|
|
+
|
|
|
+
|
|
|
+def _b64encode(data: bytes) -> str:
|
|
|
+ return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
|
|
|
+
|
|
|
+
|
|
|
+def _b64decode(data: str) -> bytes:
|
|
|
+ padding = "=" * (-len(data) % 4)
|
|
|
+ return base64.urlsafe_b64decode(f"{data}{padding}".encode("utf-8"))
|
|
|
+
|
|
|
+
|
|
|
+def _create_access_token(user_id: int) -> str:
|
|
|
+ expires_at = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRE_MINUTES)
|
|
|
+ payload = {"user_id": user_id, "exp": int(expires_at.timestamp())}
|
|
|
+ body = _b64encode(json.dumps(payload, separators=(",", ":")).encode("utf-8"))
|
|
|
+ signature = hmac.new(TOKEN_SECRET_KEY.encode("utf-8"), body.encode("utf-8"), hashlib.sha256).hexdigest()
|
|
|
+ return f"{body}.{signature}"
|
|
|
+
|
|
|
+
|
|
|
+def _decode_access_token(token: str) -> int:
|
|
|
+ try:
|
|
|
+ body, signature = token.split(".", 1)
|
|
|
+ expected_signature = hmac.new(
|
|
|
+ TOKEN_SECRET_KEY.encode("utf-8"),
|
|
|
+ body.encode("utf-8"),
|
|
|
+ hashlib.sha256
|
|
|
+ ).hexdigest()
|
|
|
+ if not hmac.compare_digest(signature, expected_signature):
|
|
|
+ raise ValueError("bad signature")
|
|
|
+
|
|
|
+ payload = json.loads(_b64decode(body))
|
|
|
+ if int(payload.get("exp", 0)) < int(datetime.now(timezone.utc).timestamp()):
|
|
|
+ raise ValueError("expired")
|
|
|
+ return int(payload["user_id"])
|
|
|
+ except Exception:
|
|
|
+ raise _auth_exception()
|
|
|
+
|
|
|
+
|
|
|
+def _normalize_user(user: dict) -> dict:
|
|
|
+ user["is_admin"] = bool(user.get("is_admin"))
|
|
|
+ return user
|
|
|
+
|
|
|
+
|
|
|
+def _get_user_by_id(db_conn: PooledMySQLConnection, user_id: int) -> Optional[dict]:
|
|
|
+ with db_conn.cursor(dictionary=True) as cursor:
|
|
|
+ cursor.execute(
|
|
|
+ f"SELECT id, username, nickname, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s",
|
|
|
+ (user_id,)
|
|
|
+ )
|
|
|
+ user = cursor.fetchone()
|
|
|
+ return _normalize_user(user) if user else None
|
|
|
+
|
|
|
+
|
|
|
+def _get_user_by_username(db_conn: PooledMySQLConnection, username: str) -> Optional[dict]:
|
|
|
+ with db_conn.cursor(dictionary=True) as cursor:
|
|
|
+ cursor.execute(
|
|
|
+ f"SELECT id, username, nickname, password, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE username = %s",
|
|
|
+ (username,)
|
|
|
+ )
|
|
|
+ user = cursor.fetchone()
|
|
|
+ return _normalize_user(user) if user else None
|
|
|
+
|
|
|
+
|
|
|
+def get_current_user(
|
|
|
+ token: str = Depends(oauth2_scheme),
|
|
|
+ db_conn: PooledMySQLConnection = db_dependency
|
|
|
+) -> dict:
|
|
|
+ user_id = _decode_access_token(token)
|
|
|
+ user = _get_user_by_id(db_conn, user_id)
|
|
|
+ if not user:
|
|
|
+ raise _auth_exception()
|
|
|
+ return user
|
|
|
+
|
|
|
+
|
|
|
+def require_admin_user(current_user: dict = Depends(get_current_user)) -> dict:
|
|
|
+ if not current_user.get("is_admin"):
|
|
|
+ raise HTTPException(status_code=403, detail="该请求需要管理员权限")
|
|
|
+ return current_user
|
|
|
+
|
|
|
+
|
|
|
+def check_card_permission(db_conn: PooledMySQLConnection, current_user: dict, card_id: int):
|
|
|
+ if current_user.get("is_admin"):
|
|
|
+ return
|
|
|
+
|
|
|
+ with db_conn.cursor() as cursor:
|
|
|
+ cursor.execute(
|
|
|
+ f"SELECT 1 FROM `{settings.DB_USER_CARD_TABLE_NAME}` WHERE user_id = %s AND card_id = %s LIMIT 1",
|
|
|
+ (current_user["id"], card_id)
|
|
|
+ )
|
|
|
+ if cursor.fetchone():
|
|
|
+ return
|
|
|
+
|
|
|
+ raise HTTPException(status_code=403, detail="没有该卡片权限")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/register", response_model=UserResponse, summary="注册用户")
|
|
|
+def register_user(data: UserRegisterRequest, db_conn: PooledMySQLConnection = db_dependency):
|
|
|
+ try:
|
|
|
+ if _get_user_by_username(db_conn, data.username):
|
|
|
+ raise HTTPException(status_code=400, detail="该用户名已经存在")
|
|
|
+
|
|
|
+ with db_conn.cursor(dictionary=True) as cursor:
|
|
|
+ cursor.execute(
|
|
|
+ f"INSERT INTO `{settings.DB_USER_TABLE_NAME}` (username, nickname, password) VALUES (%s, %s, %s)",
|
|
|
+ (data.username, data.nickname, _hash_password(data.password))
|
|
|
+ )
|
|
|
+ db_conn.commit()
|
|
|
+
|
|
|
+ user_id = cursor.lastrowid
|
|
|
+ logger.info(f"User created, id: {user_id}, username: {data.username}")
|
|
|
+ user = _get_user_by_id(db_conn, user_id)
|
|
|
+ return UserResponse.model_validate(user)
|
|
|
+
|
|
|
+ except HTTPException:
|
|
|
+ db_conn.rollback()
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ db_conn.rollback()
|
|
|
+ logger.error(f"Register user failed: {e}")
|
|
|
+ raise HTTPException(status_code=500, detail="注册用户失败")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/login", response_model=TokenResponse, summary="用户登录")
|
|
|
+def login_for_access_token(
|
|
|
+ form_data: OAuth2PasswordRequestForm = Depends(),
|
|
|
+ db_conn: PooledMySQLConnection = db_dependency
|
|
|
+):
|
|
|
+ username = form_data.username
|
|
|
+ password = form_data.password
|
|
|
+
|
|
|
+ user = _get_user_by_username(db_conn, username)
|
|
|
+ if not user or not _verify_password(password, user["password"]):
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
+ detail="用户名或密码错误",
|
|
|
+ headers={"WWW-Authenticate": "Bearer"},
|
|
|
+ )
|
|
|
+
|
|
|
+ return TokenResponse(access_token=_create_access_token(user["id"]))
|
|
|
+
|
|
|
+
|
|
|
+@router.put("/me", response_model=UserResponse, summary="修改当前的昵称或密码")
|
|
|
+def update_current_user(
|
|
|
+ data: UserUpdateRequest,
|
|
|
+ current_user: dict = Depends(get_current_user),
|
|
|
+ db_conn: PooledMySQLConnection = db_dependency
|
|
|
+):
|
|
|
+ update_fields = []
|
|
|
+ params = []
|
|
|
+
|
|
|
+ if data.nickname is not None:
|
|
|
+ update_fields.append("nickname = %s")
|
|
|
+ params.append(data.nickname)
|
|
|
+ if data.password is not None:
|
|
|
+ update_fields.append("password = %s")
|
|
|
+ params.append(_hash_password(data.password))
|
|
|
+
|
|
|
+ if not update_fields:
|
|
|
+ raise HTTPException(status_code=400, detail="No fields to update")
|
|
|
+
|
|
|
+ try:
|
|
|
+ params.append(current_user["id"])
|
|
|
+ with db_conn.cursor() as cursor:
|
|
|
+ query = (
|
|
|
+ f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET "
|
|
|
+ f"{', '.join(update_fields)} WHERE id = %s"
|
|
|
+ )
|
|
|
+ cursor.execute(query, tuple(params))
|
|
|
+ db_conn.commit()
|
|
|
+
|
|
|
+ return UserResponse.model_validate(_get_user_by_id(db_conn, current_user["id"]))
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ db_conn.rollback()
|
|
|
+ logger.error(f"Update user failed: {e}")
|
|
|
+ raise HTTPException(status_code=500, detail="修改用户失败")
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/list", response_model=UserListResponseWrapper, summary="查询用户列表")
|
|
|
+def list_users(
|
|
|
+ nickname: Optional[str] = Query(None, description="按昵称模糊搜索"),
|
|
|
+ sort_order: str = Query("desc", pattern="^(asc|desc)$", description="按创建时间排序: asc 或 desc"),
|
|
|
+ skip: int = Query(0, ge=0),
|
|
|
+ page_num: Optional[int] = Query(None, ge=1),
|
|
|
+ limit: int = Query(100, ge=1, le=1000),
|
|
|
+ current_user: dict = Depends(require_admin_user),
|
|
|
+ db_conn: PooledMySQLConnection = db_dependency
|
|
|
+):
|
|
|
+ if page_num is not None:
|
|
|
+ skip = (page_num - 1) * limit
|
|
|
+
|
|
|
+ try:
|
|
|
+ with db_conn.cursor(dictionary=True) as cursor:
|
|
|
+ conditions = []
|
|
|
+ params = []
|
|
|
+
|
|
|
+ if nickname:
|
|
|
+ conditions.append("nickname LIKE %s")
|
|
|
+ params.append(f"%{nickname}%")
|
|
|
+
|
|
|
+ where_clause = ""
|
|
|
+ if conditions:
|
|
|
+ where_clause = " WHERE " + " AND ".join(conditions)
|
|
|
+
|
|
|
+ count_query = f"SELECT COUNT(*) as total FROM `{settings.DB_USER_TABLE_NAME}`" + where_clause
|
|
|
+ cursor.execute(count_query, tuple(params))
|
|
|
+ total_count = cursor.fetchone()["total"]
|
|
|
+
|
|
|
+ order_sql = "ASC" if sort_order.lower() == "asc" else "DESC"
|
|
|
+ data_query = (
|
|
|
+ f"SELECT id, username, nickname, is_admin, created_at, updated_at "
|
|
|
+ f"FROM `{settings.DB_USER_TABLE_NAME}`"
|
|
|
+ f"{where_clause} ORDER BY created_at {order_sql}, id DESC LIMIT %s OFFSET %s"
|
|
|
+ )
|
|
|
+ data_params = params.copy()
|
|
|
+ data_params.extend([limit, skip])
|
|
|
+ cursor.execute(data_query, tuple(data_params))
|
|
|
+ users = [_normalize_user(user) for user in cursor.fetchall()]
|
|
|
+
|
|
|
+ return {
|
|
|
+ "data": {
|
|
|
+ "total": total_count,
|
|
|
+ "list": users
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Query user list failed: {e}")
|
|
|
+ raise HTTPException(status_code=500, detail="查询用户列表失败")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/grant_admin", response_model=UserResponse, summary="给用户授予管理员权限")
|
|
|
+def grant_current_user_admin(
|
|
|
+ data: AdminGrantRequest,
|
|
|
+ current_user: dict = Depends(get_current_user),
|
|
|
+ db_conn: PooledMySQLConnection = db_dependency
|
|
|
+):
|
|
|
+ if not hmac.compare_digest(data.key, ADMIN_GRANT_KEY):
|
|
|
+ raise HTTPException(status_code=403, detail="Invalid admin key")
|
|
|
+
|
|
|
+ try:
|
|
|
+ target_user_id = data.user_id or current_user["id"]
|
|
|
+ if not _get_user_by_id(db_conn, target_user_id):
|
|
|
+ raise HTTPException(status_code=404, detail="未发现该用户")
|
|
|
+
|
|
|
+ with db_conn.cursor() as cursor:
|
|
|
+ cursor.execute(
|
|
|
+ f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET is_admin = TRUE WHERE id = %s",
|
|
|
+ (target_user_id,)
|
|
|
+ )
|
|
|
+ db_conn.commit()
|
|
|
+ return UserResponse.model_validate(_get_user_by_id(db_conn, target_user_id))
|
|
|
+
|
|
|
+ except HTTPException:
|
|
|
+ db_conn.rollback()
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ db_conn.rollback()
|
|
|
+ logger.error(f"Grant admin failed: {e}")
|
|
|
+ raise HTTPException(status_code=500, detail="Grant admin failed")
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/bind_card", status_code=200, summary="给用户绑定卡片ID")
|
|
|
+def bind_card_to_user(
|
|
|
+ data: BindCardRequest,
|
|
|
+ current_user: dict = Depends(require_admin_user),
|
|
|
+ db_conn: PooledMySQLConnection = db_dependency
|
|
|
+):
|
|
|
+ try:
|
|
|
+ # Keep the request field as card_id, but accept multiple card ids.
|
|
|
+ card_ids = list(dict.fromkeys(data.card_id))
|
|
|
+
|
|
|
+ with db_conn.cursor(dictionary=True) as cursor:
|
|
|
+ cursor.execute(f"SELECT id FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s", (data.user_id,))
|
|
|
+ if not cursor.fetchone():
|
|
|
+ raise HTTPException(status_code=404, detail="未发现该用户")
|
|
|
+
|
|
|
+ format_strings = ",".join(["%s"] * len(card_ids))
|
|
|
+ cursor.execute(
|
|
|
+ f"SELECT id FROM `{settings.DB_CARD_TABLE_NAME}` WHERE id IN ({format_strings})",
|
|
|
+ tuple(card_ids)
|
|
|
+ )
|
|
|
+ existing_card_ids = {row["id"] for row in cursor.fetchall()}
|
|
|
+ missing_card_ids = sorted(set(card_ids) - existing_card_ids)
|
|
|
+ if missing_card_ids:
|
|
|
+ raise HTTPException(status_code=404, detail=f"卡片未发现: {missing_card_ids}")
|
|
|
+
|
|
|
+ bind_params = [(data.user_id, card_id) for card_id in card_ids]
|
|
|
+ cursor.executemany(
|
|
|
+ f"INSERT IGNORE INTO `{settings.DB_USER_CARD_TABLE_NAME}` (user_id, card_id) VALUES (%s, %s)",
|
|
|
+ bind_params
|
|
|
+ )
|
|
|
+ inserted_count = cursor.rowcount
|
|
|
+ db_conn.commit()
|
|
|
+
|
|
|
+ logger.info(f"Admin {current_user['id']} bound cards {card_ids} to user {data.user_id}")
|
|
|
+ return {
|
|
|
+ "message": "卡片绑定成功",
|
|
|
+ "user_id": data.user_id,
|
|
|
+ "card_id": card_ids,
|
|
|
+ "inserted_count": inserted_count
|
|
|
+ }
|
|
|
+
|
|
|
+ except HTTPException:
|
|
|
+ db_conn.rollback()
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ db_conn.rollback()
|
|
|
+ logger.error(f"Bind card failed: {e}")
|
|
|
+ raise HTTPException(status_code=500, detail="卡片绑定失败")
|