import logging
import secrets
import string
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

from fastapi import BackgroundTasks, HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ..models.user import User
from ..models.token import Token
from ..schemas.auth import LoginRequest, TokenResponse
from ..core.security import (
    verify_password,
    get_password_hash,
    create_access_token,
    create_refresh_token,
    create_verification_token,
    create_reset_token,
    decode_token,
    generate_verification_code,
)
from ..core.email import email_service
from ..config import settings

logger = logging.getLogger(__name__)


class AuthService:
    """Authentication workflows backed by a SQLAlchemy session.

    Covers registration, login (with temporary lockout after repeated
    failures), e-mail verification, password reset and refresh-token
    rotation / revocation. Failures are raised as ``HTTPException``.
    """

    # Consecutive failed password checks that trigger a temporary lock.
    MAX_FAILED_LOGIN_ATTEMPTS = 5
    # How long an account stays locked once the threshold is reached.
    LOCKOUT_DURATION = timedelta(minutes=30)
    # Lifetime of the e-mail verification code stored on the user row.
    VERIFICATION_CODE_TTL = timedelta(hours=24)

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _issue_verification_code(self, user: User) -> str:
        """Store a fresh verification code (with expiry) on *user* and return it.

        Does not commit; the caller owns the transaction.
        """
        code = generate_verification_code()
        user.verification_code = code
        user.verification_code_expires = datetime.utcnow() + self.VERIFICATION_CODE_TTL
        return code

    @staticmethod
    def _queue_verification_email(
        user: User, background_tasks: BackgroundTasks, verification_code: str
    ) -> None:
        """Schedule the verification e-mail (link + code) as a background task."""
        verification_token = create_verification_token(user.email)
        verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
        background_tasks.add_task(
            email_service.send_verification_email,
            user.email,
            user.username,
            verification_url,
            verification_code,
        )

    def _issue_token_pair(
        self,
        user: User,
        ip_address: Optional[str],
        user_agent: Optional[str],
    ) -> Tuple[str, str]:
        """Create an access/refresh token pair and stage the refresh-token row.

        Returns ``(access_token, refresh_token)``. The refresh token is added
        to the session but not committed; the caller commits.
        """
        # Separate claim dicts in case the token factories mutate their input.
        access_token = create_access_token({"sub": user.username, "user_id": user.id})
        refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
        self.db.add(
            Token(
                token=refresh_token,
                token_type="refresh",
                expires_at=datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
                user_id=user.id,
                ip_address=ip_address,
                user_agent=user_agent,
            )
        )
        return access_token, refresh_token

    @staticmethod
    def _token_response(access_token: str, refresh_token: str, user: User) -> TokenResponse:
        """Build the API response carrying both tokens and the user."""
        return TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            user=user,
        )

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def register_user(
        self,
        user_data: Dict[str, Any],
        background_tasks: BackgroundTasks,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> User:
        """Register a new, inactive user and e-mail a verification link.

        Args:
            user_data: must contain ``"username"``, ``"email"`` and
                ``"password"``; ``"first_name"`` / ``"last_name"`` are optional.
            background_tasks: FastAPI task queue used to send the e-mail.
            ip_address: accepted for interface compatibility; unused here.
            user_agent: accepted for interface compatibility; unused here.

        Returns:
            The persisted ``User`` (``is_active`` and ``is_verified`` False).

        Raises:
            HTTPException(400): username or e-mail already registered.
        """
        existing_user = self.db.query(User).filter(
            (User.username == user_data["username"]) | (User.email == user_data["email"])
        ).first()
        if existing_user:
            if existing_user.username == user_data["username"]:
                detail = "Username already registered"
            else:
                detail = "Email already registered"
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)

        user = User(
            username=user_data["username"],
            email=user_data["email"],
            hashed_password=get_password_hash(user_data["password"]),
            first_name=user_data.get("first_name"),
            last_name=user_data.get("last_name"),
            is_active=False,  # activated once the e-mail is verified
            is_verified=False,
        )
        verification_code = self._issue_verification_code(user)

        self.db.add(user)
        try:
            self.db.commit()
        except IntegrityError as exc:
            # A concurrent request may have inserted the same username/e-mail
            # between the existence check above and this insert (assumes the
            # columns carry unique constraints — confirm against the model).
            self.db.rollback()
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username or email already registered",
            ) from exc
        self.db.refresh(user)

        self._queue_verification_email(user, background_tasks, verification_code)
        return user

    async def login(
        self,
        login_data: LoginRequest,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> Tuple[TokenResponse, User]:
        """Authenticate by username or e-mail and issue a token pair.

        Repeated wrong passwords lock the account for ``LOCKOUT_DURATION``
        after ``MAX_FAILED_LOGIN_ATTEMPTS`` failures.

        Returns:
            ``(TokenResponse, User)``.

        Raises:
            HTTPException(401): unknown user or wrong password.
            HTTPException(423): account is currently locked.
            HTTPException(403): e-mail not verified or account not active.
        """
        user = self.db.query(User).filter(
            (User.username == login_data.username) | (User.email == login_data.username)
        ).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
            )

        now = datetime.utcnow()
        if user.is_locked and user.locked_until:
            if user.locked_until > now:
                raise HTTPException(
                    status_code=status.HTTP_423_LOCKED,
                    detail=f"Account is locked until {user.locked_until}",
                )
            # The lock has expired: start a fresh attempt window. Otherwise the
            # counter stays at the threshold and one typo re-locks the account.
            user.is_locked = False
            user.locked_until = None
            user.failed_login_attempts = 0

        if not verify_password(login_data.password, user.hashed_password):
            user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
            if user.failed_login_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS:
                user.is_locked = True
                user.locked_until = now + self.LOCKOUT_DURATION
            self.db.commit()
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
            )

        if not user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Email not verified",
            )
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Account is not active",
            )

        # Successful login: clear failure/lock state.
        user.failed_login_attempts = 0
        user.last_login = now
        user.is_locked = False
        user.locked_until = None

        access_token, refresh_token = self._issue_token_pair(user, ip_address, user_agent)
        # Single commit: login bookkeeping and the refresh-token row succeed together.
        self.db.commit()

        return self._token_response(access_token, refresh_token, user), user

    async def verify_email(
        self,
        token: str,
        code: Optional[str] = None,
    ) -> bool:
        """Mark the user identified by a verification token as verified and active.

        Args:
            token: verification token; its payload must have ``type == "verify"``
                and an ``email`` claim.
            code: optional verification code; when given it must match the
                stored, unexpired code.

        Returns:
            True on success.

        Raises:
            HTTPException(400): bad token/payload, already verified, or bad code.
            HTTPException(404): no user with the token's e-mail.
        """
        payload = decode_token(token)
        if not payload or payload.get("type") != "verify":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid verification token",
            )

        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload",
            )

        user = self.db.query(User).filter(User.email == email).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found",
            )
        if user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already verified",
            )

        if code:
            stored_code = user.verification_code
            expires = user.verification_code_expires
            if (
                not stored_code
                or not expires
                or expires < datetime.utcnow()
                # Constant-time comparison; bytes so non-ASCII input can't raise.
                or not secrets.compare_digest(
                    str(stored_code).encode("utf-8"), code.encode("utf-8")
                )
            ):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired verification code",
                )

        user.is_verified = True
        user.is_active = True
        user.verification_code = None
        user.verification_code_expires = None
        self.db.commit()

        try:
            await email_service.send_welcome_email(user.email, user.username)
        except Exception:
            # Verification is already committed; a mail failure must not turn
            # this into an error response (a retry would then report
            # "Email already verified"). Log it for investigation instead.
            logger.exception("Failed to send welcome email to user id=%s", user.id)

        return True

    async def resend_verification_email(
        self,
        email: str,
        background_tasks: BackgroundTasks,
    ) -> bool:
        """Issue a new verification code and re-send the verification e-mail.

        Returns True even when no user has this e-mail address.

        Raises:
            HTTPException(400): the e-mail is already verified.
        """
        user = self.db.query(User).filter(User.email == email).first()
        if not user:
            # Do not reveal whether the address is registered.
            return True

        if user.is_verified:
            # NOTE(review): this response reveals that the address is registered,
            # which undercuts the rule above; kept for API compatibility.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already verified",
            )

        verification_code = self._issue_verification_code(user)
        self.db.commit()

        self._queue_verification_email(user, background_tasks, verification_code)
        return True

    async def request_password_reset(
        self,
        email: str,
        background_tasks: BackgroundTasks,
    ) -> bool:
        """Persist a reset token and e-mail a password-reset link.

        Returns True even when no user has this e-mail address.

        Raises:
            HTTPException(400): the account is not active.
        """
        user = self.db.query(User).filter(User.email == email).first()
        if not user:
            # Do not reveal whether the address is registered.
            return True

        if not user.is_active:
            # NOTE(review): like resend_verification_email, this reveals that the
            # address exists; kept for API compatibility.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account is not active",
            )

        reset_token = create_reset_token(user.email)
        reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"

        # Stored so reset_password can require the token to exist, be
        # unrevoked and unexpired (single use).
        self.db.add(
            Token(
                token=reset_token,
                token_type="reset",
                expires_at=datetime.utcnow()
                + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
                user_id=user.id,
            )
        )
        self.db.commit()

        background_tasks.add_task(
            email_service.send_password_reset_email,
            user.email,
            user.username,
            reset_url,
        )
        return True

    async def reset_password(
        self,
        token: str,
        new_password: str,
    ) -> bool:
        """Set a new password using a valid, unused reset token.

        On success every access/refresh token and every outstanding reset
        token of the user is revoked.

        Raises:
            HTTPException(400): invalid/expired/used token or inactive account.
            HTTPException(404): no user with the token's e-mail.
        """
        payload = decode_token(token)
        if not payload or payload.get("type") != "reset":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid reset token",
            )

        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload",
            )

        # The token must also exist in the DB, unrevoked and unexpired.
        token_entry = self.db.query(Token).filter(
            Token.token == token,
            Token.token_type == "reset",
            Token.is_revoked == False,  # noqa: E712 (SQL expression, not Python bool test)
            Token.expires_at > datetime.utcnow(),
        ).first()
        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token",
            )

        user = self.db.query(User).filter(User.email == email).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found",
            )
        if token_entry.user_id != user.id:
            # Token row was issued for a different account than the e-mail
            # claim resolves to; refuse rather than reset the wrong user.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token",
            )
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account is not active",
            )

        user.hashed_password = get_password_hash(new_password)
        user.last_password_change = datetime.utcnow()

        # Revoke existing sessions and every other outstanding reset link:
        # after the password changes, an older reset e-mail must not work.
        self.db.query(Token).filter(
            Token.user_id == user.id,
            Token.token_type.in_(["access", "refresh", "reset"]),
            Token.is_revoked == False,  # noqa: E712
        ).update({"is_revoked": True})

        # Mark the token used for this reset explicitly as well.
        token_entry.is_revoked = True
        self.db.commit()

        return True

    async def refresh_token(
        self,
        refresh_token: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> TokenResponse:
        """Exchange a valid refresh token for a new access/refresh pair.

        The presented refresh token is revoked (rotation).

        Raises:
            HTTPException(401): invalid/revoked/expired token, bad payload, or
                user missing/inactive.
        """
        payload = decode_token(refresh_token)
        if not payload or payload.get("type") != "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )

        token_entry = self.db.query(Token).filter(
            Token.token == refresh_token,
            Token.token_type == "refresh",
            Token.is_revoked == False,  # noqa: E712
            Token.expires_at > datetime.utcnow(),
        ).first()
        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )

        username = payload.get("sub")
        user_id = payload.get("user_id")
        if not username or not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token payload",
            )

        user = self.db.query(User).filter(
            User.id == user_id,
            User.username == username,
            User.is_active == True,  # noqa: E712
        ).first()
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive",
            )

        access_token, new_refresh_token = self._issue_token_pair(user, ip_address, user_agent)
        # Rotation: the presented refresh token cannot be used again.
        token_entry.is_revoked = True
        self.db.commit()

        return self._token_response(access_token, new_refresh_token, user)

    async def logout(
        self,
        refresh_token: str,
    ) -> bool:
        """Revoke the given refresh token if it exists; always returns True."""
        token_entry = self.db.query(Token).filter(
            Token.token == refresh_token,
            Token.token_type == "refresh",
        ).first()

        if token_entry:
            token_entry.is_revoked = True
            self.db.commit()

        return True

    async def logout_all(
        self,
        user_id: int,
    ) -> bool:
        """Revoke all unrevoked access and refresh tokens of *user_id*; returns True."""
        self.db.query(Token).filter(
            Token.user_id == user_id,
            Token.token_type.in_(["access", "refresh"]),
            Token.is_revoked == False,  # noqa: E712
        ).update({"is_revoked": True})
        self.db.commit()

        return True