||
- from typing import Optional, Dict, Any, Tuple
- from datetime import datetime, timezone, timedelta
- from sqlalchemy import select, update, delete, func, or_
- from sqlalchemy.ext.asyncio import AsyncSession
- from fastapi import HTTPException, status, BackgroundTasks
- import secrets
- import string
-
- from app.models.user import User
- from app.models.token import Token
- from ..schemas.auth import LoginRequest, TokenResponse
- from app.core.security import (
- verify_password,
- create_access_token,
- create_refresh_token,
- create_verification_token,
- create_reset_token,
- decode_token,
- generate_verification_code
- )
- from app.core.email import email_service
- from ..config import settings
-
class AuthService:
    """Authentication workflows backed by an async SQLAlchemy session.

    Covers registration, login with temporary lockout, email verification,
    password reset, refresh-token rotation and logout. Every public method
    commits its own changes on ``self.db``.
    """

    # Consecutive failed logins that trigger a temporary account lock.
    MAX_FAILED_LOGIN_ATTEMPTS = 5
    # How long an account stays locked once the threshold is reached.
    LOCKOUT_DURATION = timedelta(minutes=30)
    # Lifetime of an emailed verification code.
    VERIFICATION_CODE_TTL = timedelta(hours=24)

    def __init__(self, db: AsyncSession):
        """Bind the service to the session used for every query and commit.

        Args:
            db: The async session this service reads from and commits to.
        """
        self.db = db

    @staticmethod
    def _as_utc(value: Optional[datetime]) -> Optional[datetime]:
        """Return *value* as an aware UTC datetime (``None`` passes through).

        Aware values are returned unchanged; naive values get ``tzinfo=UTC``.
        Comparing a naive value against ``datetime.now(timezone.utc)`` would
        raise ``TypeError``.
        NOTE(review): presumably the DB columns may come back naive (no
        timezone support in the column type); this service only writes
        UTC-based values, so naive is treated as UTC — confirm against models.
        """
        if value is None or value.tzinfo is not None:
            return value
        return value.replace(tzinfo=timezone.utc)

    async def register_user(
        self,
        user_data: Dict[str, Any],
        background_tasks: BackgroundTasks,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> User:
        """Register a new, unverified user and email a verification link.

        Args:
            user_data: Must contain ``username``, ``email`` and ``password``;
                ``first_name`` and ``last_name`` are optional.
            background_tasks: Used to send the verification email after the
                response has been returned.
            ip_address: Accepted for interface compatibility; not used here.
            user_agent: Accepted for interface compatibility; not used here.

        Returns:
            The persisted, refreshed ``User`` (``is_active`` and
            ``is_verified`` both False until the email is verified).

        Raises:
            HTTPException: 400 if the username or email is already taken.
        """
        from sqlalchemy.exc import IntegrityError
        from ..core.security import get_password_hash

        username = user_data["username"]
        email = user_data["email"]

        # 1. Reject duplicates. The username and the email may each belong to
        #    a *different* existing row, so fetch every match rather than
        #    expecting at most one (scalar_one_or_none would raise).
        stmt = select(User).where(
            or_(User.username == username, User.email == email)
        )
        existing_users = (await self.db.execute(stmt)).scalars().all()

        if any(existing.username == username for existing in existing_users):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username already registered"
            )
        if existing_users:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered"
            )

        # 2. Build the user; it stays inactive until the email is verified.
        hashed_password = get_password_hash(user_data["password"])
        user = User(
            username=username,
            email=email,
            hashed_password=hashed_password,
            first_name=user_data.get("first_name"),
            last_name=user_data.get("last_name"),
            is_active=False,
            is_verified=False
        )

        # 3. Attach a time-limited verification code.
        verification_code = generate_verification_code()
        user.verification_code = verification_code
        user.verification_code_expires = (
            datetime.now(timezone.utc) + self.VERIFICATION_CODE_TTL
        )

        # 4. Persist. A concurrent registration can claim the same username or
        #    email between the check above and this commit; assuming unique
        #    constraints on those columns, that surfaces as IntegrityError.
        self.db.add(user)
        try:
            await self.db.commit()
        except IntegrityError as exc:
            await self.db.rollback()
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username or email already registered"
            ) from exc
        await self.db.refresh(user)

        # 5. Send the verification email in the background.
        verification_token = create_verification_token(user.email)
        verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
        background_tasks.add_task(
            email_service.send_verification_email,
            user.email,
            user.username,
            verification_url,
            verification_code
        )

        return user

    async def login(
        self,
        login_data: LoginRequest,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> Tuple[TokenResponse, User]:
        """Authenticate by username or email and issue access/refresh tokens.

        Failed password attempts are counted. Reaching
        ``MAX_FAILED_LOGIN_ATTEMPTS`` locks the account for
        ``LOCKOUT_DURATION``. Once a lock has expired, counting restarts
        from zero.

        Args:
            login_data: Credentials; ``login_data.username`` may hold either
                the username or the email address.
            ip_address: Stored on the refresh-token row.
            user_agent: Stored on the refresh-token row.

        Returns:
            ``(token_response, user)``, with ``user`` refreshed after commit.

        Raises:
            HTTPException: 401 for unknown user or wrong password, 423 while
                the account is locked, 403 if the email is unverified or the
                account is inactive.
        """
        now = datetime.now(timezone.utc)
        identifier = login_data.username

        # 1. Look the user up by username or email. One account's username can
        #    equal another account's email, so fetch all matches and prefer
        #    the exact username match instead of risking MultipleResultsFound.
        stmt = select(User).where(
            or_(User.username == identifier, User.email == identifier)
        )
        candidates = (await self.db.execute(stmt)).scalars().all()
        user = next(
            (candidate for candidate in candidates if candidate.username == identifier),
            candidates[0] if candidates else None,
        )

        if user is None:
            # NOTE(review): this returns before any password hashing, so
            # unknown usernames are rejected measurably faster than wrong
            # passwords; consider verifying against a dummy hash here.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password"
            )

        # 2. Refuse while a lock is still in force.
        locked_until = self._as_utc(user.locked_until)
        if user.is_locked and locked_until and locked_until > now:
            raise HTTPException(
                status_code=status.HTTP_423_LOCKED,
                detail=f"Account is locked until {user.locked_until}"
            )
        # Reaching here with is_locked still set means the lock has lapsed.
        stale_lock = bool(user.is_locked)

        # 3. Verify the password; on failure record the attempt and maybe lock.
        if not verify_password(login_data.password, user.hashed_password):
            # After an expired lock the old count must not carry over,
            # otherwise a single typo would immediately re-lock the account.
            previous_attempts = 0 if stale_lock else (user.failed_login_attempts or 0)
            new_attempts = previous_attempts + 1

            update_data = {
                "failed_login_attempts": new_attempts,
                "updated_at": now,
            }
            if new_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS:
                update_data.update(
                    is_locked=True,
                    locked_until=now + self.LOCKOUT_DURATION,
                )
            elif stale_lock:
                update_data.update(is_locked=False, locked_until=None)

            await self.db.execute(
                update(User).where(User.id == user.id).values(**update_data)
            )
            await self.db.commit()

            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password"
            )

        # 4. The email address must be verified.
        if not user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Email not verified"
            )

        # 5. The account must be active.
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Account is not active"
            )

        # 6. Successful login: clear failure/lock state, stamp last_login.
        await self.db.execute(
            update(User)
            .where(User.id == user.id)
            .values(
                failed_login_attempts=0,
                last_login=now,
                is_locked=False,
                locked_until=None,
            )
        )

        # 7. Issue tokens.
        claims = {"sub": user.username, "user_id": user.id}
        access_token = create_access_token(claims)
        refresh_token = create_refresh_token(claims)

        # 8. Record the refresh token so it can later be rotated or revoked.
        self.db.add(
            Token(
                token=refresh_token,
                token_type="refresh",
                expires_at=now + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
                user_id=user.id,
                ip_address=ip_address,
                user_agent=user_agent
            )
        )

        # 9. Commit, then reload the user so the response reflects the update.
        await self.db.commit()
        await self.db.refresh(user)

        # 10. Build the response.
        token_response = TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            user=user
        )

        return token_response, user

    async def verify_email(
        self,
        token: str,
        code: Optional[str] = None
    ) -> bool:
        """Mark a user's email as verified and activate the account.

        Args:
            token: Token whose decoded payload must have ``type == "verify"``
                and an ``email`` claim.
            code: Optional verification code. When given, it must match the
                stored, unexpired code.

        Returns:
            True on success. Sending the welcome email is best-effort: a
            failure is logged, not raised, because verification is already
            committed by then.

        Raises:
            HTTPException: 400 for a bad token/payload, an already-verified
                email, or a wrong/expired code; 404 if no user has that email.
        """
        # 1. Decode and validate the token.
        payload = decode_token(token)

        if not payload or payload.get("type") != "verify":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid verification token"
            )

        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload"
            )

        # 2. Find the user.
        result = await self.db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()

        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        if user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already verified"
            )

        # 3. Optional code check, using a constant-time comparison.
        #    Assumes verification codes are stored as strings.
        if code:
            now = datetime.now(timezone.utc)
            expires_at = self._as_utc(user.verification_code_expires)
            code_ok = (
                bool(user.verification_code)
                and expires_at is not None
                and expires_at >= now
                and secrets.compare_digest(
                    user.verification_code.encode("utf-8"),
                    code.encode("utf-8"),
                )
            )
            if not code_ok:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired verification code"
                )

        # Capture what the welcome email needs *before* committing. With
        # expire_on_commit enabled, reading ORM attributes after commit would
        # trigger implicit IO, which an AsyncSession cannot do.
        user_email = user.email
        username = user.username

        # 4. Activate the account and clear the one-time code.
        await self.db.execute(
            update(User)
            .where(User.id == user.id)
            .values(
                is_verified=True,
                is_active=True,
                verification_code=None,
                verification_code_expires=None,
                updated_at=datetime.now(timezone.utc),
            )
        )
        await self.db.commit()

        # 5. Welcome email (awaited directly). Verification is already
        #    committed, so a mail failure must not become an error response.
        try:
            await email_service.send_welcome_email(user_email, username)
        except Exception:
            import logging
            logging.getLogger(__name__).exception(
                "Failed to send welcome email to %s", user_email
            )

        return True

    async def resend_verification_email(
        self,
        email: str,
        background_tasks: BackgroundTasks
    ) -> bool:
        """Issue a fresh verification code and re-send the verification email.

        Args:
            email: Address to re-send to.
            background_tasks: Used to send the email after the response.

        Returns:
            True. This includes the case where no user has that email, so
            the endpoint does not reveal which addresses are registered.

        Raises:
            HTTPException: 400 if the email is already verified.
                NOTE(review): this response does reveal that the address
                exists, which undercuts the silent-success branch above.
        """
        # 1. Find the user; unknown addresses succeed silently.
        result = await self.db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()

        if not user:
            return True

        if user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already verified"
            )

        # 2. Generate a new code, replacing any previous one.
        verification_code = generate_verification_code()
        now = datetime.now(timezone.utc)

        # Read attributes before commit (see expire_on_commit note in verify_email).
        user_email = user.email
        username = user.username

        # 3. Store the new code and its expiry.
        await self.db.execute(
            update(User)
            .where(User.id == user.id)
            .values(
                verification_code=verification_code,
                verification_code_expires=now + self.VERIFICATION_CODE_TTL,
                updated_at=now,
            )
        )
        await self.db.commit()

        # 4. Send the verification email in the background.
        verification_token = create_verification_token(user_email)
        verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
        background_tasks.add_task(
            email_service.send_verification_email,
            user_email,
            username,
            verification_url,
            verification_code
        )

        return True

    async def request_password_reset(
        self,
        email: str,
        background_tasks: BackgroundTasks
    ) -> bool:
        """Create a reset token, record it, and email a reset link.

        Args:
            email: Address of the account to reset.
            background_tasks: Used to send the email after the response.

        Returns:
            True. This includes the case where no user has that email.

        Raises:
            HTTPException: 400 if the account is inactive.
                NOTE(review): this reveals that the address exists,
                contrary to the silent-success branch for unknown addresses.
        """
        # 1. Find the user; unknown addresses succeed silently.
        result = await self.db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()

        if not user:
            return True

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account is not active"
            )

        # Read attributes before commit (see expire_on_commit note in verify_email).
        user_email = user.email
        username = user.username

        # 2. Create the reset token and link.
        reset_token = create_reset_token(user_email)
        reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"

        # 3. Record the token so reset_password can enforce expiry and single use.
        self.db.add(
            Token(
                token=reset_token,
                token_type="reset",
                expires_at=datetime.now(timezone.utc)
                + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
                user_id=user.id
            )
        )
        await self.db.commit()

        # 4. Send the reset email in the background.
        background_tasks.add_task(
            email_service.send_password_reset_email,
            user_email,
            username,
            reset_url
        )

        return True

    async def reset_password(
        self,
        token: str,
        new_password: str
    ) -> bool:
        """Set a new password using a recorded, unexpired reset token.

        On success, every access, refresh and reset token of the user is
        revoked, including the one just used. Old sessions and any other
        reset links emailed earlier stop working.

        Args:
            token: Reset token whose decoded payload must have
                ``type == "reset"`` and an ``email`` claim.
            new_password: Plain-text password to hash and store.

        Returns:
            True on success.

        Raises:
            HTTPException: 400 for an invalid, expired, revoked or mismatched
                token, or an inactive account; 404 if no user has the email.
        """
        from ..core.security import get_password_hash

        # 1. Decode and validate the token.
        payload = decode_token(token)

        if not payload or payload.get("type") != "reset":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid reset token"
            )

        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload"
            )

        # 2. The token must also be on record, unrevoked and unexpired.
        #    Revoking it below is what makes a reset link single-use.
        token_stmt = select(Token).where(
            Token.token == token,
            Token.token_type == "reset",
            Token.is_revoked == False,  # noqa: E712 - SQL expression, not a Python bool test
            Token.expires_at > datetime.now(timezone.utc)
        )
        token_entry = (await self.db.execute(token_stmt)).scalar_one_or_none()

        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token"
            )

        # 3. Find the user named in the token.
        user_result = await self.db.execute(select(User).where(User.email == email))
        user = user_result.scalar_one_or_none()

        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        # Defence in depth: the recorded token must belong to that same user.
        if token_entry.user_id != user.id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token"
            )

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account is not active"
            )

        now = datetime.now(timezone.utc)

        # 4. Store the new password hash.
        await self.db.execute(
            update(User)
            .where(User.id == user.id)
            .values(
                hashed_password=get_password_hash(new_password),
                last_password_change=now,
                updated_at=now,
            )
        )

        # 5. Revoke all session tokens and all reset tokens (this one included).
        await self.db.execute(
            update(Token)
            .where(
                Token.user_id == user.id,
                Token.token_type.in_(["access", "refresh", "reset"]),
            )
            .values(is_revoked=True)
        )
        await self.db.commit()

        return True

    async def refresh_token(
        self,
        refresh_token: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> TokenResponse:
        """Exchange a valid refresh token for a new access/refresh token pair.

        The presented refresh token is revoked and replaced (rotation). The
        revoke is conditional on the row still being unrevoked, so when two
        requests present the same token, only one can rotate it.

        Args:
            refresh_token: The refresh token being exchanged.
            ip_address: Stored on the new refresh-token row.
            user_agent: Stored on the new refresh-token row.

        Returns:
            A ``TokenResponse`` with the new tokens and the refreshed user.

        Raises:
            HTTPException: 401 if the token is invalid, unknown, revoked,
                expired or already rotated, or the user is missing/inactive.
        """
        # 1. Decode and validate the token.
        payload = decode_token(refresh_token)

        if not payload or payload.get("type") != "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        # 2. The token must be on record, unrevoked and unexpired.
        token_stmt = select(Token).where(
            Token.token == refresh_token,
            Token.token_type == "refresh",
            Token.is_revoked == False,  # noqa: E712 - SQL expression
            Token.expires_at > datetime.now(timezone.utc)
        )
        token_entry = (await self.db.execute(token_stmt)).scalar_one_or_none()

        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        username = payload.get("sub")
        user_id = payload.get("user_id")

        if not username or not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token payload"
            )

        # 3. The user named in the token must still exist and be active.
        user_stmt = select(User).where(
            User.id == user_id,
            User.username == username,
            User.is_active == True  # noqa: E712 - SQL expression
        )
        user = (await self.db.execute(user_stmt)).scalar_one_or_none()

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive"
            )

        # 4. Revoke the presented token, but only if it is still unrevoked.
        #    If a concurrent request already rotated it, zero rows match and
        #    this request is rejected, so one refresh token yields at most
        #    one successor.
        revoke_result = await self.db.execute(
            update(Token)
            .where(
                Token.token == refresh_token,
                Token.is_revoked == False,  # noqa: E712 - SQL expression
            )
            .values(is_revoked=True)
        )
        if revoke_result.rowcount == 0:
            await self.db.rollback()
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        # 5. Issue the new pair and record the new refresh token.
        claims = {"sub": user.username, "user_id": user.id}
        access_token = create_access_token(claims)
        new_refresh_token = create_refresh_token(claims)

        self.db.add(
            Token(
                token=new_refresh_token,
                token_type="refresh",
                expires_at=datetime.now(timezone.utc)
                + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
                user_id=user.id,
                ip_address=ip_address,
                user_agent=user_agent
            )
        )
        await self.db.commit()
        # Reload after commit so building the response reads fresh state
        # instead of touching expired attributes (matches login()).
        await self.db.refresh(user)

        # 6. Build the response.
        return TokenResponse(
            access_token=access_token,
            refresh_token=new_refresh_token,
            expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            user=user
        )

    async def logout(
        self,
        refresh_token: str
    ) -> bool:
        """Revoke one refresh token.

        Args:
            refresh_token: The refresh token to revoke.

        Returns:
            Always True, whether or not the token was on record.
        """
        # A single conditional UPDATE: revoking a token that is not on record
        # simply matches zero rows.
        await self.db.execute(
            update(Token)
            .where(
                Token.token == refresh_token,
                Token.token_type == "refresh",
            )
            .values(is_revoked=True)
        )
        await self.db.commit()

        return True

    async def logout_all(
        self,
        user_id: int
    ) -> bool:
        """Revoke every unrevoked access and refresh token of a user.

        Args:
            user_id: ID of the user whose tokens are revoked.

        Returns:
            True if at least one token was revoked, False otherwise.
        """
        result = await self.db.execute(
            update(Token)
            .where(
                Token.user_id == user_id,
                Token.token_type.in_(["access", "refresh"]),
                Token.is_revoked == False  # noqa: E712 - SQL expression
            )
            .values(is_revoked=True)
        )
        await self.db.commit()

        return result.rowcount > 0
|