import logging
import secrets
import string
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple

from fastapi import BackgroundTasks, HTTPException, status
from sqlalchemy import delete, func, or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.email import email_service
from app.core.security import (
    create_access_token,
    create_refresh_token,
    create_reset_token,
    create_verification_token,
    decode_token,
    generate_verification_code,
    get_password_hash,
    verify_password,
)
from app.models.token import Token
from app.models.user import User

from ..config import settings
from ..schemas.auth import LoginRequest, TokenResponse

logger = logging.getLogger(__name__)


class AuthService:
    """Authentication workflows on an async SQLAlchemy 2.0 session.

    Covers registration, login, email verification, password reset, refresh
    token rotation and logout. Every public method commits its own changes and
    reports client errors by raising ``fastapi.HTTPException``.

    ORM attributes are never read after ``commit()`` without a ``refresh()``:
    with ``AsyncSession`` and the default ``expire_on_commit=True`` such a read
    would trigger an implicit lazy load, which async sessions cannot perform.
    Values needed after a commit are therefore copied into locals beforehand.
    """

    # Consecutive wrong passwords that trigger a temporary lockout.
    MAX_FAILED_LOGIN_ATTEMPTS = 5
    # How long an account stays locked once the limit is reached.
    LOCKOUT_DURATION = timedelta(minutes=30)
    # Lifetime of the email verification code.
    VERIFICATION_CODE_TTL = timedelta(hours=24)

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_utc(value: Optional[datetime]) -> Optional[datetime]:
        """Return *value* as an aware datetime, treating naive values as UTC.

        Columns declared without ``timezone=True`` (and SQLite in general) give
        back naive datetimes, and comparing those with an aware
        ``datetime.now(timezone.utc)`` raises ``TypeError``. This service only
        writes UTC timestamps, so a naive value is assumed to be UTC.
        """
        if value is None or value.tzinfo is not None:
            return value
        return value.replace(tzinfo=timezone.utc)

    async def _find_user_by_login(self, identifier: str) -> Optional[User]:
        """Return the user whose username or email equals *identifier*, if any.

        Both columns are searched, so two different rows can match (one user's
        username equal to another user's email); ``scalar_one_or_none`` would
        raise ``MultipleResultsFound`` there. An exact username match wins.
        """
        stmt = select(User).where(
            or_(User.username == identifier, User.email == identifier)
        )
        candidates = (await self.db.execute(stmt)).scalars().all()
        for candidate in candidates:
            if candidate.username == identifier:
                return candidate
        return candidates[0] if candidates else None

    def _issue_tokens(
        self,
        user: User,
        ip_address: Optional[str],
        user_agent: Optional[str],
    ) -> Tuple[str, str]:
        """Create an ``(access_token, refresh_token)`` pair for *user*.

        The refresh token is recorded as a ``Token`` row (added to the session,
        not committed) so it can later be validated, rotated and revoked.
        """
        # A fresh claims dict per call, in case a token factory mutates its input.
        access_token = create_access_token({"sub": user.username, "user_id": user.id})
        refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
        self.db.add(
            Token(
                token=refresh_token,
                token_type="refresh",
                expires_at=datetime.now(timezone.utc)
                + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
                user_id=user.id,
                ip_address=ip_address,
                user_agent=user_agent,
            )
        )
        return access_token, refresh_token

    async def _claim_token(self, token: str, token_type: str) -> bool:
        """Revoke a not-yet-revoked token; return True only if this call did it.

        The ``is_revoked == False`` guard turns the UPDATE into a
        compare-and-set. If two requests present the same single-use token
        concurrently, only one of them sees ``rowcount == 1``. (This relies on
        the database re-checking the WHERE clause after taking the row lock, as
        PostgreSQL and MySQL/InnoDB do.)
        """
        result = await self.db.execute(
            update(Token)
            .where(
                Token.token == token,
                Token.token_type == token_type,
                Token.is_revoked == False,  # noqa: E712 - SQL expression
            )
            .values(is_revoked=True)
        )
        return result.rowcount == 1

    async def _record_failed_login(self, user: User, now: datetime) -> None:
        """Count one wrong password for *user* and lock the account at the limit.

        Callers only reach this when the account is not currently locked. A set
        ``is_locked`` flag therefore means an earlier lockout has expired. The
        tally restarts from zero and the stale lock is cleared; otherwise a
        single typo after a lockout would immediately lock the account again.
        Commits its own changes.
        """
        values: Dict[str, Any] = {}
        if user.is_locked:
            previous_attempts = 0
            values.update(is_locked=False, locked_until=None)
        else:
            previous_attempts = user.failed_login_attempts or 0

        new_attempts = previous_attempts + 1
        values.update(failed_login_attempts=new_attempts, updated_at=now)
        if new_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS:
            values.update(is_locked=True, locked_until=now + self.LOCKOUT_DURATION)

        await self.db.execute(update(User).where(User.id == user.id).values(**values))
        await self.db.commit()

    # --------------------------------------------------------------- public API

    async def register_user(
        self,
        user_data: Dict[str, Any],
        background_tasks: BackgroundTasks,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> User:
        """Register a new user and queue the verification email.

        The account is created inactive and unverified; ``verify_email``
        activates it.

        Args:
            user_data: Must contain ``username``, ``email`` and ``password``;
                ``first_name`` / ``last_name`` are optional.
            background_tasks: FastAPI queue used to send the email after the
                response is returned.
            ip_address: Accepted for interface compatibility; currently unused.
            user_agent: Accepted for interface compatibility; currently unused.

        Returns:
            The persisted (and refreshed) ``User``.

        Raises:
            HTTPException(400): The username or email is already registered.
        """
        username = user_data["username"]
        email = user_data["email"]

        # 1. Reject duplicates. The username and the email may each match a
        #    *different* row, so fetch every match instead of expecting one.
        stmt = select(User).where(or_(User.username == username, User.email == email))
        existing = (await self.db.execute(stmt)).scalars().all()
        if any(candidate.username == username for candidate in existing):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username already registered",
            )
        if existing:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already registered",
            )

        # 2. Build the user; it stays inactive until the email is verified.
        user = User(
            username=username,
            email=email,
            hashed_password=get_password_hash(user_data["password"]),
            first_name=user_data.get("first_name"),
            last_name=user_data.get("last_name"),
            is_active=False,
            is_verified=False,
        )

        # 3. Attach a time-limited verification code.
        verification_code = generate_verification_code()
        user.verification_code = verification_code
        user.verification_code_expires = (
            datetime.now(timezone.utc) + self.VERIFICATION_CODE_TTL
        )

        # 4. Persist. A concurrent registration can slip past the check above,
        #    in which case the database rejects the insert with IntegrityError.
        #    NOTE(review): assumes username/email carry unique constraints.
        self.db.add(user)
        try:
            await self.db.commit()
        except IntegrityError:
            await self.db.rollback()
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username or email already registered",
            ) from None
        await self.db.refresh(user)

        # 5. Send the verification email after the response.
        verification_token = create_verification_token(user.email)
        verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
        background_tasks.add_task(
            email_service.send_verification_email,
            user.email,
            user.username,
            verification_url,
            verification_code,
        )
        return user

    async def login(
        self,
        login_data: LoginRequest,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> Tuple[TokenResponse, User]:
        """Authenticate by username or email and issue an access/refresh token pair.

        Args:
            login_data: Credentials; ``login_data.username`` may be a username
                or an email address.
            ip_address: Stored on the new refresh-token row.
            user_agent: Stored on the new refresh-token row.

        Returns:
            ``(TokenResponse, User)``.

        Raises:
            HTTPException(401): Unknown user or wrong password (same message
                for both, so callers cannot tell which).
            HTTPException(423): The account is temporarily locked.
            HTTPException(403): The email is not verified or the account is
                inactive.
        """
        now = datetime.now(timezone.utc)

        # 1. Find the user.
        user = await self._find_user_by_login(login_data.username)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
            )

        # 2. Refuse while a lockout is in effect.
        locked_until = self._as_utc(user.locked_until)
        if user.is_locked and locked_until and locked_until > now:
            raise HTTPException(
                status_code=status.HTTP_423_LOCKED,
                detail=f"Account is locked until {user.locked_until}",
            )

        # 3. Check the password; failures count toward a lockout.
        if not verify_password(login_data.password, user.hashed_password):
            await self._record_failed_login(user, now)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
            )

        # 4. Check the account state.
        if not user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Email not verified",
            )
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Account is not active",
            )

        # 5. Successful login: clear the failure counter and any expired lock.
        await self.db.execute(
            update(User)
            .where(User.id == user.id)
            .values(
                failed_login_attempts=0,
                last_login=now,
                is_locked=False,
                locked_until=None,
            )
        )

        # 6. Issue tokens, then commit everything in one transaction.
        access_token, refresh_token = self._issue_tokens(user, ip_address, user_agent)
        await self.db.commit()
        # Reload the expired attributes before the response reads them.
        await self.db.refresh(user)

        token_response = TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            user=user,
        )
        return token_response, user

    async def verify_email(
        self,
        token: str,
        code: Optional[str] = None,
    ) -> bool:
        """Mark a user's email as verified and activate the account.

        Args:
            token: Verification token (``type == "verify"``) carrying ``email``.
            code: Optional verification code. When given, it must match the
                stored, unexpired code.

        Returns:
            ``True`` on success.

        Raises:
            HTTPException(400): Invalid token or payload, email already
                verified, or a wrong/expired code.
            HTTPException(404): No user has the token's email.
        """
        # 1. Decode the token.
        payload = decode_token(token)
        if not payload or payload.get("type") != "verify":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid verification token",
            )
        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload",
            )

        # 2. Find the user.
        result = await self.db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found",
            )
        if user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already verified",
            )

        # 3. Check the optional code. The comparison is constant-time so that
        #    response timing does not leak how many leading characters match.
        if code:
            now = datetime.now(timezone.utc)
            expires = self._as_utc(user.verification_code_expires)
            if (
                not user.verification_code
                or not expires
                or expires < now
                or not secrets.compare_digest(
                    str(user.verification_code).encode("utf-8"),
                    code.encode("utf-8"),
                )
            ):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired verification code",
                )

        # Copy what is needed after the commit (attributes expire on commit).
        user_id, user_email, username = user.id, user.email, user.username

        # 4. Activate the account.
        await self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(
                is_verified=True,
                is_active=True,
                verification_code=None,
                verification_code_expires=None,
                updated_at=datetime.now(timezone.utc),
            )
        )
        await self.db.commit()

        # 5. The welcome email is best-effort. Verification is already
        #    committed, so a mail failure must not turn this call into an error.
        try:
            await email_service.send_welcome_email(user_email, username)
        except Exception:
            logger.exception("Failed to send welcome email to user id=%s", user_id)

        return True

    async def resend_verification_email(
        self,
        email: str,
        background_tasks: BackgroundTasks,
    ) -> bool:
        """Generate a new verification code and queue another verification email.

        Returns ``True`` even when *email* is unknown, so the endpoint does not
        reveal which addresses are registered.
        NOTE(review): an already-verified address still yields a distinct 400.

        Raises:
            HTTPException(400): The email is already verified.
        """
        # 1. Find the user.
        result = await self.db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if not user:
            return True
        if user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already verified",
            )

        # Copy what is needed after the commit (attributes expire on commit).
        user_id, user_email, username = user.id, user.email, user.username

        # 2-3. Replace the stored code.
        verification_code = generate_verification_code()
        now = datetime.now(timezone.utc)
        await self.db.execute(
            update(User)
            .where(User.id == user_id)
            .values(
                verification_code=verification_code,
                verification_code_expires=now + self.VERIFICATION_CODE_TTL,
                updated_at=now,
            )
        )
        await self.db.commit()

        # 4. Send the email after the response.
        verification_token = create_verification_token(user_email)
        verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
        background_tasks.add_task(
            email_service.send_verification_email,
            user_email,
            username,
            verification_url,
            verification_code,
        )
        return True

    async def request_password_reset(
        self,
        email: str,
        background_tasks: BackgroundTasks,
    ) -> bool:
        """Issue a password-reset token, store it and queue the reset email.

        Returns ``True`` even when *email* is unknown, so the endpoint does not
        reveal which addresses are registered.

        Raises:
            HTTPException(400): The account is not active.
        """
        # 1. Find the user.
        result = await self.db.execute(select(User).where(User.email == email))
        user = result.scalar_one_or_none()
        if not user:
            return True
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account is not active",
            )

        # Copy what is needed after the commit (attributes expire on commit).
        user_id, user_email, username = user.id, user.email, user.username

        # 2. Create the token and link.
        reset_token = create_reset_token(user_email)
        reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"

        # 3. Store the token so reset_password can check and revoke it.
        self.db.add(
            Token(
                token=reset_token,
                token_type="reset",
                expires_at=datetime.now(timezone.utc)
                + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
                user_id=user_id,
            )
        )
        await self.db.commit()

        # 4. Send the email after the response.
        background_tasks.add_task(
            email_service.send_password_reset_email,
            user_email,
            username,
            reset_url,
        )
        return True

    async def reset_password(
        self,
        token: str,
        new_password: str,
    ) -> bool:
        """Set a new password using a single-use reset token.

        On success, every outstanding access, refresh and reset token of the
        user is revoked: existing sessions end, and any other reset links that
        were emailed earlier stop working.

        Returns:
            ``True`` on success.

        Raises:
            HTTPException(400): Invalid, expired, foreign or already-used
                token, or an inactive account.
            HTTPException(404): No user has the token's email.
        """
        # 1. Decode the token.
        payload = decode_token(token)
        if not payload or payload.get("type") != "reset":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid reset token",
            )
        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload",
            )

        # 2. The token must be stored, unrevoked and unexpired.
        token_result = await self.db.execute(
            select(Token).where(
                Token.token == token,
                Token.token_type == "reset",
                Token.is_revoked == False,  # noqa: E712 - SQL expression
                Token.expires_at > datetime.now(timezone.utc),
            )
        )
        token_entry = token_result.scalar_one_or_none()
        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token",
            )

        # 3. Find the user.
        user_result = await self.db.execute(select(User).where(User.email == email))
        user = user_result.scalar_one_or_none()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found",
            )
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account is not active",
            )
        # The stored token must belong to the user named in the JWT.
        if token_entry.user_id != user.id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token",
            )

        # 4. Consume the token atomically, so two concurrent requests cannot
        #    both use it.
        if not await self._claim_token(token, "reset"):
            await self.db.rollback()
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token",
            )

        # 5. Store the new password hash.
        now = datetime.now(timezone.utc)
        await self.db.execute(
            update(User)
            .where(User.id == user.id)
            .values(
                hashed_password=get_password_hash(new_password),
                last_password_change=now,
                updated_at=now,
            )
        )

        # 6. Revoke every remaining session and any other outstanding reset link.
        await self.db.execute(
            update(Token)
            .where(
                Token.user_id == user.id,
                Token.token_type.in_(["access", "refresh", "reset"]),
                Token.is_revoked == False,  # noqa: E712 - SQL expression
            )
            .values(is_revoked=True)
        )
        await self.db.commit()
        return True

    async def refresh_token(
        self,
        refresh_token: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> TokenResponse:
        """Exchange a valid refresh token for a new access/refresh pair (rotation).

        The presented refresh token is revoked atomically. If it is presented
        twice, even concurrently, only the first request succeeds.

        Raises:
            HTTPException(401): The token is invalid, expired, revoked or
                already used, or its user is missing or inactive.
        """
        # 1. Decode the token.
        payload = decode_token(refresh_token)
        if not payload or payload.get("type") != "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )

        # 2. The token must be stored, unrevoked and unexpired.
        token_result = await self.db.execute(
            select(Token).where(
                Token.token == refresh_token,
                Token.token_type == "refresh",
                Token.is_revoked == False,  # noqa: E712 - SQL expression
                Token.expires_at > datetime.now(timezone.utc),
            )
        )
        token_entry = token_result.scalar_one_or_none()
        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )

        username = payload.get("sub")
        user_id = payload.get("user_id")
        if not username or not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token payload",
            )

        # 3. Find the active user the token was issued to.
        user_result = await self.db.execute(
            select(User).where(
                User.id == user_id,
                User.username == username,
                User.is_active == True,  # noqa: E712 - SQL expression
            )
        )
        user = user_result.scalar_one_or_none()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive",
            )
        if token_entry.user_id != user.id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )

        # 4. Revoke the old token atomically; a lost race means it was reused.
        if not await self._claim_token(refresh_token, "refresh"):
            await self.db.rollback()
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )

        # 5. Issue and persist the new pair.
        access_token, new_refresh_token = self._issue_tokens(user, ip_address, user_agent)
        await self.db.commit()
        # Reload the expired attributes before the response reads them.
        await self.db.refresh(user)

        return TokenResponse(
            access_token=access_token,
            refresh_token=new_refresh_token,
            expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            user=user,
        )

    async def logout(
        self,
        refresh_token: str,
    ) -> bool:
        """Revoke the given refresh token. Unknown tokens are ignored.

        Always returns ``True``.
        """
        # One UPDATE; it is a no-op if no matching token row exists.
        await self.db.execute(
            update(Token)
            .where(Token.token == refresh_token, Token.token_type == "refresh")
            .values(is_revoked=True)
        )
        await self.db.commit()
        return True

    async def logout_all(
        self,
        user_id: int,
    ) -> bool:
        """Revoke every unrevoked access and refresh token of *user_id*.

        Returns:
            ``True`` if at least one token was revoked.
        """
        result = await self.db.execute(
            update(Token)
            .where(
                Token.user_id == user_id,
                Token.token_type.in_(["access", "refresh"]),
                Token.is_revoked == False,  # noqa: E712 - SQL expression
            )
            .values(is_revoked=True)
        )
        await self.db.commit()
        return result.rowcount > 0