||
# app/api/v1/auth.py
import logging
from typing import Optional

from fastapi import APIRouter, HTTPException, status, Request, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from app.schemas.user import UserCreate, UserLogin
from app.schemas.token import TokenResponse, RefreshTokenRequest
from app.core.security import (
    create_access_token,
    JWTManager,
    verify_access_token,
)
from app.config import settings
from app.database import get_db
from app.models.user import User
from app.services.user_service import UserService
-
router = APIRouter(prefix="/auth", tags=["认证"])

# Bearer-token extractor, used as a dependency by the /me and /logout routes.
security = HTTPBearer()

# Shared JWT manager: signs and verifies tokens with the configured secret key
# and algorithm. Used by the register/login/refresh endpoints.
jwt_manager = JWTManager(
    secret_key=settings.SECRET_KEY,
    algorithm=settings.ALGORITHM,
)

# `logging` is imported with the other module imports at the top of the file.
logger = logging.getLogger(__name__)
-
@router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
async def register(
    user_data: UserCreate,
    request: Request,
    db: Session = Depends(get_db),
):
    """Register a new user and return an access/refresh token pair.

    Args:
        user_data: Registration payload, passed to ``UserService.create_user``.
        request: Incoming request (currently unused).
        db: Database session injected by ``get_db``.

    Returns:
        A ``TokenResponse`` with a bearer access token, a refresh token and a
        summary of the newly created user.

    Raises:
        HTTPException: 400 when ``create_user`` raises ``ValueError``; any
            ``HTTPException`` raised by ``create_user`` is propagated
            unchanged; 500 for any other error raised by ``create_user``
            (logged server-side, not echoed to the client).
    """
    user_service = UserService(db)

    try:
        user = await user_service.create_user(user_data)
    except HTTPException:
        # Keep the status code chosen by the service layer instead of
        # collapsing it into a 500 via the generic handler below.
        raise
    except ValueError as e:
        # A ValueError from create_user is reported to the client as a 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        # Log the full traceback, but do not leak internal error text.
        logger.exception("注册失败")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="注册失败",
        ) from e

    # Issue the access token through the shared jwt_manager, the same way
    # /login and /refresh do, so all access tokens are signed identically.
    access_token = jwt_manager.create_access_token(
        {
            "sub": user.username,
            "user_id": user.id,
            "email": user.email,
            "type": "access",
        },
        expires_minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
    )

    refresh_token = jwt_manager.create_refresh_token(
        {
            "sub": user.username,
            "user_id": user.id,
            "type": "refresh",
        },
        expires_days=7,
    )

    user_response = {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "full_name": user.full_name,
        "is_active": user.is_active,
        "is_verified": user.is_verified,
        "is_superuser": user.is_superuser,
        "created_at": user.created_at.isoformat() if user.created_at else None,
    }

    return TokenResponse(
        access_token=access_token,
        token_type="bearer",
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        refresh_token=refresh_token,
        user=user_response,
    )
-
@router.post("/login", response_model=TokenResponse)
async def login(
    login_data: UserLogin,
    request: Request,
    db: Session = Depends(get_db),
):
    """Authenticate a user and return an access/refresh token pair.

    Args:
        login_data: Login payload providing ``username`` and ``password``.
        request: Incoming request (currently unused).
        db: Database session injected by ``get_db``.

    Returns:
        A ``TokenResponse`` containing a bearer access token, a refresh token
        and a summary of the authenticated user (including ``last_login``).

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy value,
            403 when the account is inactive, 423 when it is locked.
    """
    user_service = UserService(db)

    user = await user_service.authenticate_user(
        login_data.username,
        login_data.password,
    )

    if not user:
        # Record failed attempts; %r escapes control characters so a crafted
        # username cannot inject extra log lines.
        logger.warning("登录失败: %r", login_data.username)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码错误",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户账户已被禁用",
        )

    if user.is_locked:
        raise HTTPException(
            status_code=status.HTTP_423_LOCKED,
            detail="账户已被锁定,请联系管理员",
        )

    access_token = jwt_manager.create_access_token(
        {
            "sub": user.username,
            "user_id": user.id,
            "email": user.email,
            "type": "access",
        },
        expires_minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
    )

    refresh_token = jwt_manager.create_refresh_token(
        {
            "sub": user.username,
            "user_id": user.id,
            "type": "refresh",
        },
        expires_days=7,
    )

    # Emitted only once every check has passed, so the success marker
    # corresponds to an actual successful login.
    logger.info("✅ 用户登录: %r", user.username)

    user_response = {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "full_name": user.full_name,
        "is_active": user.is_active,
        "is_verified": user.is_verified,
        "is_superuser": user.is_superuser,
        "created_at": user.created_at.isoformat() if user.created_at else None,
        "last_login": user.last_login.isoformat() if user.last_login else None,
    }

    return TokenResponse(
        access_token=access_token,
        token_type="bearer",
        expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        refresh_token=refresh_token,
        user=user_response,
    )
-
@router.get("/me")
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db),
):
    """Return the profile of the user identified by the bearer access token.

    Args:
        credentials: Bearer credentials extracted by the ``security`` dependency.
        db: Database session injected by ``get_db``.

    Returns:
        A dict describing the user (id, username, email, flags, timestamps,
        avatar).

    Raises:
        HTTPException: 401 if the token fails verification, is not an access
            token, or has no ``sub`` claim; 404 if the user does not exist;
            403 if the account is inactive.
    """
    token = credentials.credentials

    payload = verify_access_token(token, settings.SECRET_KEY)
    # Refresh tokens issued by this module carry "type": "refresh"; refuse
    # them here so a refresh token cannot be used as an access token. Tokens
    # without a "type" claim are still accepted for backward compatibility.
    if not payload or payload.get("type", "access") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的令牌",
        )

    username = payload.get("sub")
    if not username:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的令牌",
        )

    user_service = UserService(db)
    user = await user_service.get_user_by_username(username)

    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户账户已被禁用",
        )

    logger.info("✅ 获取当前用户信息")

    return {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "full_name": user.full_name,
        "is_active": user.is_active,
        "is_verified": user.is_verified,
        "is_superuser": user.is_superuser,
        "created_at": user.created_at.isoformat() if user.created_at else None,
        "last_login": user.last_login.isoformat() if user.last_login else None,
        "avatar": user.avatar,
    }
-
@router.post("/refresh")
async def refresh_token(
    refresh_data: RefreshTokenRequest,
    db: Session = Depends(get_db),
):
    """Exchange a valid refresh token for a new access token.

    Args:
        refresh_data: Request body carrying ``refresh_token``.
        db: Database session injected by ``get_db``.

    Returns:
        A dict with the new ``access_token``, ``token_type`` and
        ``expires_in`` (seconds). No new refresh token is issued.

    Raises:
        HTTPException: 400 if the refresh token is missing/empty; 401 if it
            fails verification, is not a refresh token, or has no ``sub``
            claim; 404 if the user does not exist; 403 if the account is
            inactive; 423 if the account is locked.
    """
    # Named `token` so it does not shadow this endpoint function's own name.
    token = refresh_data.refresh_token

    if not token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="缺少刷新令牌",
        )

    payload = jwt_manager.verify_token(token)

    if not payload or payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的刷新令牌",
        )

    username = payload.get("sub")
    if not username:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的刷新令牌",
        )

    user_service = UserService(db)
    user = await user_service.get_user_by_username(username)

    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="用户不存在",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户账户已被禁用",
        )

    # Mirror /login: a locked account must not keep minting access tokens
    # with a previously issued refresh token.
    if user.is_locked:
        raise HTTPException(
            status_code=status.HTTP_423_LOCKED,
            detail="账户已被锁定,请联系管理员",
        )

    access_token = jwt_manager.create_access_token(
        {
            "sub": user.username,
            "user_id": user.id,
            "email": user.email,
            "type": "access",
        },
        expires_minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES,
    )

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    }
-
@router.post("/logout")
async def logout(
    refresh_data: RefreshTokenRequest,
    db: Session = Depends(get_db),
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Log the caller out (currently a no-op acknowledgement).

    The route declares a bearer-token dependency and a refresh-token body,
    but this handler does not use either of them. Nothing here revokes the
    access or refresh token; it only returns a success message.

    TODO: record revoked tokens (e.g. in a Redis blacklist) and have token
    verification reject them.
    """
    message = "登出成功"
    return {"message": message}
-
@router.get("/test-db")
async def test_database(db: Session = Depends(get_db)):
    """Diagnostic endpoint: report the user count and a sample of users.

    Returns the total from ``count_users()`` and id/username/email/is_active
    for at most 10 users from ``list_users(limit=10)``. The "database" value
    is a fixed label, not read from the connection.

    NOTE(review): this route declares no authentication dependency and
    returns usernames and e-mail addresses; consider restricting it to
    non-production environments.
    """
    service = UserService(db)

    total = await service.count_users()
    sample = await service.list_users(limit=10)

    summaries = [
        {
            "id": u.id,
            "username": u.username,
            "email": u.email,
            "is_active": u.is_active,
        }
        for u in sample
    ]

    return {
        "database": "SQLite",
        "user_count": total,
        "users": summaries,
        "message": "数据库连接正常",
    }
-
@router.get("/test")
async def test_auth():
    """Health-check endpoint that lists the authentication routes.

    Route paths in the listing are relative to the router prefix ("/auth").

    SECURITY NOTE(review): the response advertises a hard-coded username and
    password under "test_credentials" to any caller (no auth dependency is
    declared). If that account exists outside local development, remove the
    field or gate it behind a debug-only setting.
    """
    endpoint_map = {
        "POST /register": "用户注册",
        "POST /login": "用户登录",
        "GET /me": "获取当前用户",
        "POST /refresh": "刷新令牌",
        "POST /logout": "用户登出",
        "GET /test-db": "测试数据库",
        "GET /test": "测试接口",
    }
    demo_login = {
        "username": "admin",
        "password": "Admin123!",
    }

    payload = {
        "message": "认证 API 正常运行(使用 SQLite 数据库)",
        "endpoints": endpoint_map,
        "test_credentials": demo_login,
    }
    return payload
|