|
|
@@ -1,15 +1,10 @@
|
|
|
import base64
|
|
|
-import hashlib
|
|
|
-import hmac
|
|
|
import json
|
|
|
-import secrets
|
|
|
-from datetime import datetime, timedelta, timezone
|
|
|
-from typing import Optional, List
|
|
|
+from typing import List, Optional
|
|
|
|
|
|
-from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
|
-from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
|
+from fastapi import APIRouter, Depends, Header, HTTPException, status
|
|
|
from mysql.connector.pooling import PooledMySQLConnection
|
|
|
-from pydantic import BaseModel, ConfigDict, Field
|
|
|
+from pydantic import BaseModel, Field
|
|
|
|
|
|
from app.core.config import settings
|
|
|
from app.core.database_loader import get_db_connection
|
|
|
@@ -19,164 +14,57 @@ logger = get_logger(__name__)
|
|
|
router = APIRouter()
|
|
|
db_dependency = Depends(get_db_connection)
|
|
|
|
|
|
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/users/login")
|
|
|
-
|
|
|
-# Hard-coded by design: used only for promoting the current user to admin.
|
|
|
-ADMIN_GRANT_KEY = settings.ADMIN_GRANT_KEY
|
|
|
-TOKEN_SECRET_KEY = settings.TOKEN_SECRET_KEY
|
|
|
-TOKEN_EXPIRE_MINUTES = 60 * 24 * 7
|
|
|
-
|
|
|
-
|
|
|
-class UserRegisterRequest(BaseModel):
|
|
|
- username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
- nickname: str = Field(..., min_length=1, max_length=20)
|
|
|
- password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
-
|
|
|
-
|
|
|
-class UserUpdateRequest(BaseModel):
|
|
|
- model_config = ConfigDict(extra="forbid")
|
|
|
-
|
|
|
- nickname: Optional[str] = Field(None, min_length=1, max_length=20)
|
|
|
- password: Optional[str] = Field(None, min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
-
|
|
|
-
|
|
|
-class UserLoginRequest(BaseModel):
|
|
|
- username: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
- password: str = Field(..., min_length=6, max_length=20, pattern=r"^[A-Za-z0-9]+$")
|
|
|
-
|
|
|
-
|
|
|
-class AdminGrantRequest(BaseModel):
|
|
|
- key: str
|
|
|
- user_id: Optional[int] = None
|
|
|
-
|
|
|
|
|
|
class BindCardRequest(BaseModel):
|
|
|
- user_id: int
|
|
|
+ user_id: int = Field(..., ge=0)
|
|
|
card_id: List[int] = Field(..., min_length=1)
|
|
|
|
|
|
|
|
|
-class UserResponse(BaseModel):
|
|
|
- id: int
|
|
|
- username: str
|
|
|
- nickname: str
|
|
|
- is_admin: bool
|
|
|
-
|
|
|
-
|
|
|
-class UserListItem(BaseModel):
|
|
|
- id: int
|
|
|
- username: str
|
|
|
- nickname: str
|
|
|
- is_admin: bool
|
|
|
- created_at: datetime
|
|
|
- updated_at: datetime
|
|
|
-
|
|
|
-
|
|
|
-class UserListWithTotal(BaseModel):
|
|
|
- total: int
|
|
|
- list: List[UserListItem]
|
|
|
-
|
|
|
-
|
|
|
-class UserListResponseWrapper(BaseModel):
|
|
|
- data: UserListWithTotal
|
|
|
-
|
|
|
-
|
|
|
-class TokenResponse(BaseModel):
|
|
|
- access_token: str
|
|
|
- token_type: str = "bearer"
|
|
|
-
|
|
|
-
|
|
|
-def _auth_exception() -> HTTPException:
|
|
|
+def _auth_exception(detail: str = "用户认证信息无效") -> HTTPException:
|
|
|
return HTTPException(
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="Invalid or expired token",
|
|
|
- headers={"WWW-Authenticate": "Bearer"},
|
|
|
+ detail=detail,
|
|
|
+ headers={"WWW-Authenticate": "X-USER-BASE64"},
|
|
|
)
|
|
|
|
|
|
|
|
|
-def _hash_password(password: str, salt: Optional[str] = None) -> str:
|
|
|
- salt = salt or secrets.token_hex(16)
|
|
|
- password_hash = hashlib.sha256(f"{salt}:{password}".encode("utf-8")).hexdigest()
|
|
|
- return f"{salt}${password_hash}"
|
|
|
-
|
|
|
-
|
|
|
-def _verify_password(password: str, stored_password: str) -> bool:
|
|
|
+def _decode_user_base64(user_base64: str) -> dict:
|
|
|
+ """解析外部认证系统传入的 X-USER-BASE64 用户信息。"""
|
|
|
try:
|
|
|
- salt, old_hash = stored_password.split("$", 1)
|
|
|
- except ValueError:
|
|
|
- return False
|
|
|
- return hmac.compare_digest(_hash_password(password, salt), f"{salt}${old_hash}")
|
|
|
-
|
|
|
-
|
|
|
-def _b64encode(data: bytes) -> str:
|
|
|
- return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
|
|
|
-
|
|
|
-
|
|
|
-def _b64decode(data: str) -> bytes:
|
|
|
- padding = "=" * (-len(data) % 4)
|
|
|
- return base64.urlsafe_b64decode(f"{data}{padding}".encode("utf-8"))
|
|
|
-
|
|
|
-
|
|
|
-def _create_access_token(user_id: int) -> str:
|
|
|
- expires_at = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRE_MINUTES)
|
|
|
- payload = {"user_id": user_id, "exp": int(expires_at.timestamp())}
|
|
|
- body = _b64encode(json.dumps(payload, separators=(",", ":")).encode("utf-8"))
|
|
|
- signature = hmac.new(TOKEN_SECRET_KEY.encode("utf-8"), body.encode("utf-8"), hashlib.sha256).hexdigest()
|
|
|
- return f"{body}.{signature}"
|
|
|
-
|
|
|
-
|
|
|
-def _decode_access_token(token: str) -> int:
|
|
|
- try:
|
|
|
- body, signature = token.split(".", 1)
|
|
|
- expected_signature = hmac.new(
|
|
|
- TOKEN_SECRET_KEY.encode("utf-8"),
|
|
|
- body.encode("utf-8"),
|
|
|
- hashlib.sha256
|
|
|
- ).hexdigest()
|
|
|
- if not hmac.compare_digest(signature, expected_signature):
|
|
|
- raise ValueError("bad signature")
|
|
|
-
|
|
|
- payload = json.loads(_b64decode(body))
|
|
|
- if int(payload.get("exp", 0)) < int(datetime.now(timezone.utc).timestamp()):
|
|
|
- raise ValueError("expired")
|
|
|
- return int(payload["user_id"])
|
|
|
+ padding = "=" * (-len(user_base64) % 4)
|
|
|
+ decoded_bytes = base64.urlsafe_b64decode(f"{user_base64}{padding}".encode("utf-8"))
|
|
|
+ payload = json.loads(decoded_bytes.decode("utf-8"))
|
|
|
+ return payload.get("user", payload)
|
|
|
except Exception:
|
|
|
- raise _auth_exception()
|
|
|
-
|
|
|
-
|
|
|
-def _normalize_user(user: dict) -> dict:
|
|
|
- user["is_admin"] = bool(user.get("is_admin"))
|
|
|
- return user
|
|
|
+ raise _auth_exception("X-USER-BASE64 解析失败")
|
|
|
|
|
|
|
|
|
-def _get_user_by_id(db_conn: PooledMySQLConnection, user_id: int) -> Optional[dict]:
|
|
|
- with db_conn.cursor(dictionary=True) as cursor:
|
|
|
- cursor.execute(
|
|
|
- f"SELECT id, username, nickname, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s",
|
|
|
- (user_id,)
|
|
|
- )
|
|
|
- user = cursor.fetchone()
|
|
|
- return _normalize_user(user) if user else None
|
|
|
+def get_current_user(x_user_base64: Optional[str] = Header(None, alias="X-USER-BASE64")) -> dict:
|
|
|
+ if not x_user_base64:
|
|
|
+ raise _auth_exception("缺少 X-USER-BASE64 请求头")
|
|
|
|
|
|
+ user_data = _decode_user_base64(x_user_base64)
|
|
|
+ user_id = user_data.get("id")
|
|
|
+ role_code_list = user_data.get("roleCodeList") or []
|
|
|
|
|
|
-def _get_user_by_username(db_conn: PooledMySQLConnection, username: str) -> Optional[dict]:
|
|
|
- with db_conn.cursor(dictionary=True) as cursor:
|
|
|
- cursor.execute(
|
|
|
- f"SELECT id, username, nickname, password, is_admin FROM `{settings.DB_USER_TABLE_NAME}` WHERE username = %s",
|
|
|
- (username,)
|
|
|
- )
|
|
|
- user = cursor.fetchone()
|
|
|
- return _normalize_user(user) if user else None
|
|
|
+ if user_id is None:
|
|
|
+ raise _auth_exception("X-USER-BASE64 缺少用户 id")
|
|
|
+ if not isinstance(role_code_list, list):
|
|
|
+ raise _auth_exception("X-USER-BASE64 的 roleCodeList 格式错误")
|
|
|
|
|
|
+ try:
|
|
|
+ user_id = int(user_id)
|
|
|
+ except (TypeError, ValueError):
|
|
|
+ raise _auth_exception("X-USER-BASE64 的用户 id 格式错误")
|
|
|
|
|
|
-def get_current_user(
|
|
|
- token: str = Depends(oauth2_scheme),
|
|
|
- db_conn: PooledMySQLConnection = db_dependency
|
|
|
-) -> dict:
|
|
|
- user_id = _decode_access_token(token)
|
|
|
- user = _get_user_by_id(db_conn, user_id)
|
|
|
- if not user:
|
|
|
- raise _auth_exception()
|
|
|
- return user
|
|
|
+ return {
|
|
|
+ "id": user_id,
|
|
|
+ "is_admin": "admin" in role_code_list,
|
|
|
+ "roleCodeList": role_code_list,
|
|
|
+ "nickname": user_data.get("nickname"),
|
|
|
+ "account": user_data.get("account"),
|
|
|
+ "raw": user_data
|
|
|
+ }
|
|
|
|
|
|
|
|
|
def require_admin_user(current_user: dict = Depends(get_current_user)) -> dict:
|
|
|
@@ -186,6 +74,7 @@ def require_admin_user(current_user: dict = Depends(get_current_user)) -> dict:
|
|
|
|
|
|
|
|
|
def check_card_permission(db_conn: PooledMySQLConnection, current_user: dict, card_id: int):
|
|
|
+ """管理员直接放行;普通用户需要在用户-卡片绑定表中存在对应关系。"""
|
|
|
if current_user.get("is_admin"):
|
|
|
return
|
|
|
|
|
|
@@ -200,188 +89,17 @@ def check_card_permission(db_conn: PooledMySQLConnection, current_user: dict, ca
|
|
|
raise HTTPException(status_code=403, detail="没有该卡片权限")
|
|
|
|
|
|
|
|
|
-@router.post("/register", response_model=UserResponse, summary="注册用户")
|
|
|
-def register_user(data: UserRegisterRequest, db_conn: PooledMySQLConnection = db_dependency):
|
|
|
- try:
|
|
|
- if _get_user_by_username(db_conn, data.username):
|
|
|
- raise HTTPException(status_code=400, detail="该用户名已经存在")
|
|
|
-
|
|
|
- with db_conn.cursor(dictionary=True) as cursor:
|
|
|
- cursor.execute(
|
|
|
- f"INSERT INTO `{settings.DB_USER_TABLE_NAME}` (username, nickname, password) VALUES (%s, %s, %s)",
|
|
|
- (data.username, data.nickname, _hash_password(data.password))
|
|
|
- )
|
|
|
- db_conn.commit()
|
|
|
-
|
|
|
- user_id = cursor.lastrowid
|
|
|
- logger.info(f"User created, id: {user_id}, username: {data.username}")
|
|
|
- user = _get_user_by_id(db_conn, user_id)
|
|
|
- return UserResponse.model_validate(user)
|
|
|
-
|
|
|
- except HTTPException:
|
|
|
- db_conn.rollback()
|
|
|
- raise
|
|
|
- except Exception as e:
|
|
|
- db_conn.rollback()
|
|
|
- logger.error(f"Register user failed: {e}")
|
|
|
- raise HTTPException(status_code=500, detail="注册用户失败")
|
|
|
-
|
|
|
-
|
|
|
-@router.post("/login", response_model=TokenResponse, summary="用户登录")
|
|
|
-def login_for_access_token(
|
|
|
- form_data: OAuth2PasswordRequestForm = Depends(),
|
|
|
- db_conn: PooledMySQLConnection = db_dependency
|
|
|
-):
|
|
|
- username = form_data.username
|
|
|
- password = form_data.password
|
|
|
-
|
|
|
- user = _get_user_by_username(db_conn, username)
|
|
|
- if not user or not _verify_password(password, user["password"]):
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="用户名或密码错误",
|
|
|
- headers={"WWW-Authenticate": "Bearer"},
|
|
|
- )
|
|
|
-
|
|
|
- return TokenResponse(access_token=_create_access_token(user["id"]))
|
|
|
-
|
|
|
-
|
|
|
-@router.put("/me", response_model=UserResponse, summary="修改当前的昵称或密码")
|
|
|
-def update_current_user(
|
|
|
- data: UserUpdateRequest,
|
|
|
- current_user: dict = Depends(get_current_user),
|
|
|
- db_conn: PooledMySQLConnection = db_dependency
|
|
|
-):
|
|
|
- update_fields = []
|
|
|
- params = []
|
|
|
-
|
|
|
- if data.nickname is not None:
|
|
|
- update_fields.append("nickname = %s")
|
|
|
- params.append(data.nickname)
|
|
|
- if data.password is not None:
|
|
|
- update_fields.append("password = %s")
|
|
|
- params.append(_hash_password(data.password))
|
|
|
-
|
|
|
- if not update_fields:
|
|
|
- raise HTTPException(status_code=400, detail="No fields to update")
|
|
|
-
|
|
|
- try:
|
|
|
- params.append(current_user["id"])
|
|
|
- with db_conn.cursor() as cursor:
|
|
|
- query = (
|
|
|
- f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET "
|
|
|
- f"{', '.join(update_fields)} WHERE id = %s"
|
|
|
- )
|
|
|
- cursor.execute(query, tuple(params))
|
|
|
- db_conn.commit()
|
|
|
-
|
|
|
- return UserResponse.model_validate(_get_user_by_id(db_conn, current_user["id"]))
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- db_conn.rollback()
|
|
|
- logger.error(f"Update user failed: {e}")
|
|
|
- raise HTTPException(status_code=500, detail="修改用户失败")
|
|
|
-
|
|
|
-
|
|
|
-@router.get("/list", response_model=UserListResponseWrapper, summary="查询用户列表")
|
|
|
-def list_users(
|
|
|
- nickname: Optional[str] = Query(None, description="按昵称模糊搜索"),
|
|
|
- sort_order: str = Query("desc", pattern="^(asc|desc)$", description="按创建时间排序: asc 或 desc"),
|
|
|
- skip: int = Query(0, ge=0),
|
|
|
- page_num: Optional[int] = Query(None, ge=1),
|
|
|
- limit: int = Query(100, ge=1, le=1000),
|
|
|
- current_user: dict = Depends(require_admin_user),
|
|
|
- db_conn: PooledMySQLConnection = db_dependency
|
|
|
-):
|
|
|
- if page_num is not None:
|
|
|
- skip = (page_num - 1) * limit
|
|
|
-
|
|
|
- try:
|
|
|
- with db_conn.cursor(dictionary=True) as cursor:
|
|
|
- conditions = []
|
|
|
- params = []
|
|
|
-
|
|
|
- if nickname:
|
|
|
- conditions.append("nickname LIKE %s")
|
|
|
- params.append(f"%{nickname}%")
|
|
|
-
|
|
|
- where_clause = ""
|
|
|
- if conditions:
|
|
|
- where_clause = " WHERE " + " AND ".join(conditions)
|
|
|
-
|
|
|
- count_query = f"SELECT COUNT(*) as total FROM `{settings.DB_USER_TABLE_NAME}`" + where_clause
|
|
|
- cursor.execute(count_query, tuple(params))
|
|
|
- total_count = cursor.fetchone()["total"]
|
|
|
-
|
|
|
- order_sql = "ASC" if sort_order.lower() == "asc" else "DESC"
|
|
|
- data_query = (
|
|
|
- f"SELECT id, username, nickname, is_admin, created_at, updated_at "
|
|
|
- f"FROM `{settings.DB_USER_TABLE_NAME}`"
|
|
|
- f"{where_clause} ORDER BY created_at {order_sql}, id DESC LIMIT %s OFFSET %s"
|
|
|
- )
|
|
|
- data_params = params.copy()
|
|
|
- data_params.extend([limit, skip])
|
|
|
- cursor.execute(data_query, tuple(data_params))
|
|
|
- users = [_normalize_user(user) for user in cursor.fetchall()]
|
|
|
-
|
|
|
- return {
|
|
|
- "data": {
|
|
|
- "total": total_count,
|
|
|
- "list": users
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"Query user list failed: {e}")
|
|
|
- raise HTTPException(status_code=500, detail="查询用户列表失败")
|
|
|
-
|
|
|
-
|
|
|
-@router.post("/grant_admin", response_model=UserResponse, summary="给用户授予管理员权限")
|
|
|
-def grant_current_user_admin(
|
|
|
- data: AdminGrantRequest,
|
|
|
- current_user: dict = Depends(get_current_user),
|
|
|
- db_conn: PooledMySQLConnection = db_dependency
|
|
|
-):
|
|
|
- if not hmac.compare_digest(data.key, ADMIN_GRANT_KEY):
|
|
|
- raise HTTPException(status_code=403, detail="Invalid admin key")
|
|
|
-
|
|
|
- try:
|
|
|
- target_user_id = data.user_id or current_user["id"]
|
|
|
- if not _get_user_by_id(db_conn, target_user_id):
|
|
|
- raise HTTPException(status_code=404, detail="未发现该用户")
|
|
|
-
|
|
|
- with db_conn.cursor() as cursor:
|
|
|
- cursor.execute(
|
|
|
- f"UPDATE `{settings.DB_USER_TABLE_NAME}` SET is_admin = TRUE WHERE id = %s",
|
|
|
- (target_user_id,)
|
|
|
- )
|
|
|
- db_conn.commit()
|
|
|
- return UserResponse.model_validate(_get_user_by_id(db_conn, target_user_id))
|
|
|
-
|
|
|
- except HTTPException:
|
|
|
- db_conn.rollback()
|
|
|
- raise
|
|
|
- except Exception as e:
|
|
|
- db_conn.rollback()
|
|
|
- logger.error(f"Grant admin failed: {e}")
|
|
|
- raise HTTPException(status_code=500, detail="Grant admin failed")
|
|
|
-
|
|
|
-
|
|
|
-@router.post("/bind_card", status_code=200, summary="给用户绑定卡片ID")
|
|
|
+@router.post("/bind_card", status_code=200, summary="给外部用户绑定卡片ID [用户调用]")
|
|
|
def bind_card_to_user(
|
|
|
data: BindCardRequest,
|
|
|
current_user: dict = Depends(require_admin_user),
|
|
|
db_conn: PooledMySQLConnection = db_dependency
|
|
|
):
|
|
|
try:
|
|
|
- # Keep the request field as card_id, but accept multiple card ids.
|
|
|
+ # 请求字段保持 card_id,但支持一次绑定多个卡片,并自动去重。
|
|
|
card_ids = list(dict.fromkeys(data.card_id))
|
|
|
|
|
|
with db_conn.cursor(dictionary=True) as cursor:
|
|
|
- cursor.execute(f"SELECT id FROM `{settings.DB_USER_TABLE_NAME}` WHERE id = %s", (data.user_id,))
|
|
|
- if not cursor.fetchone():
|
|
|
- raise HTTPException(status_code=404, detail="未发现该用户")
|
|
|
-
|
|
|
format_strings = ",".join(["%s"] * len(card_ids))
|
|
|
cursor.execute(
|
|
|
f"SELECT id FROM `{settings.DB_CARD_TABLE_NAME}` WHERE id IN ({format_strings})",
|
|
|
@@ -400,7 +118,7 @@ def bind_card_to_user(
|
|
|
inserted_count = cursor.rowcount
|
|
|
db_conn.commit()
|
|
|
|
|
|
- logger.info(f"Admin {current_user['id']} bound cards {card_ids} to user {data.user_id}")
|
|
|
+ logger.info(f"Admin {current_user['id']} bound cards {card_ids} to external user {data.user_id}")
|
|
|
return {
|
|
|
"message": "卡片绑定成功",
|
|
|
"user_id": data.user_id,
|