| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498 |
- from typing import Optional, Dict, Any, Tuple
- from datetime import datetime, timedelta
- from sqlalchemy.orm import Session
- from fastapi import HTTPException, status, BackgroundTasks
- import secrets
- import string
-
- from ..models.user import User
- from ..models.token import Token
- from ..schemas.auth import LoginRequest, TokenResponse
- from ..core.security import (
- verify_password,
- create_access_token,
- create_refresh_token,
- create_verification_token,
- create_reset_token,
- decode_token,
- generate_verification_code
- )
- from ..core.email import email_service
- from ..config import settings
-
class AuthService:
    """Account lifecycle operations executed against a SQLAlchemy session.

    Covers registration, login, e-mail verification, password reset and
    refresh-token management. Every failure that should reach the client is
    reported by raising ``HTTPException`` with an appropriate status code.
    """

    # Number of consecutive wrong passwords after which the account is locked.
    MAX_FAILED_LOGIN_ATTEMPTS = 5
    # How long the account stays locked once that limit is reached.
    LOCKOUT_DURATION = timedelta(minutes=30)
    # Lifetime of a verification code sent by e-mail.
    VERIFICATION_CODE_TTL = timedelta(hours=24)

    def __init__(self, db: Session):
        """Keep the session that every query and commit in this service uses."""
        self.db = db

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _issue_verification_code(self, user: User) -> str:
        """Put a fresh verification code and its expiry on *user*; caller commits."""
        code = generate_verification_code()
        user.verification_code = code
        user.verification_code_expires = datetime.utcnow() + self.VERIFICATION_CODE_TTL
        return code

    def _queue_verification_email(
        self,
        user: User,
        verification_code: str,
        background_tasks: BackgroundTasks
    ) -> None:
        """Schedule the verification e-mail (link plus code) as a background task."""
        verification_token = create_verification_token(user.email)
        verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"

        background_tasks.add_task(
            email_service.send_verification_email,
            user.email,
            user.username,
            verification_url,
            verification_code
        )

    def _stage_refresh_token(
        self,
        user: User,
        ip_address: Optional[str],
        user_agent: Optional[str]
    ) -> str:
        """Create a refresh token and add its ``Token`` row to the session; caller commits."""
        refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
        self.db.add(
            Token(
                token=refresh_token,
                token_type="refresh",
                expires_at=datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
                user_id=user.id,
                ip_address=ip_address,
                user_agent=user_agent
            )
        )
        return refresh_token

    @staticmethod
    def _build_token_response(user: User, access_token: str, refresh_token: str) -> TokenResponse:
        """Assemble the ``TokenResponse`` handed back by login and refresh."""
        return TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            user=user
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def register_user(
        self,
        user_data: Dict[str, Any],
        background_tasks: BackgroundTasks,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> User:
        """Register a new, not-yet-verified user and queue the verification e-mail.

        Args:
            user_data: Must contain ``username``, ``email`` and ``password``;
                ``first_name`` and ``last_name`` are optional.
            background_tasks: Receives the task that sends the verification e-mail.
            ip_address: Accepted for interface compatibility; currently unused.
            user_agent: Accepted for interface compatibility; currently unused.

        Returns:
            The persisted ``User`` (inactive and unverified).

        Raises:
            HTTPException: 400 when the username or e-mail is already taken.
        """
        from sqlalchemy.exc import IntegrityError
        from ..core.security import get_password_hash

        # Fast path: report which of the two fields collides.
        existing_user = self.db.query(User).filter(
            (User.username == user_data["username"]) |
            (User.email == user_data["email"])
        ).first()

        if existing_user:
            if existing_user.username == user_data["username"]:
                detail = "Username already registered"
            else:
                detail = "Email already registered"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=detail
            )

        user = User(
            username=user_data["username"],
            email=user_data["email"],
            hashed_password=get_password_hash(user_data["password"]),
            first_name=user_data.get("first_name"),
            last_name=user_data.get("last_name"),
            is_active=False,  # activated once the e-mail address is verified
            is_verified=False
        )
        verification_code = self._issue_verification_code(user)

        self.db.add(user)
        try:
            self.db.commit()
        except IntegrityError as exc:
            # The check above is not atomic with the insert: a concurrent
            # registration can win in between. Roll back so the session stays
            # usable and answer like the fast path instead of with a 500.
            # NOTE(review): assumes username/email carry unique constraints - confirm.
            self.db.rollback()
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Username or email already registered"
            ) from exc
        self.db.refresh(user)

        self._queue_verification_email(user, verification_code, background_tasks)

        return user

    async def login(
        self,
        login_data: LoginRequest,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> Tuple[TokenResponse, User]:
        """Authenticate a user by username or e-mail and issue tokens.

        Args:
            login_data: Carries ``username`` (matched against both the username
                and the e-mail column) and ``password``.
            ip_address: Stored on the refresh-token row.
            user_agent: Stored on the refresh-token row.

        Returns:
            ``(token_response, user)``.

        Raises:
            HTTPException: 401 for unknown user or wrong password, 423 while
                the account is locked, 403 when the e-mail is unverified or
                the account is inactive.
        """
        user = self.db.query(User).filter(
            (User.username == login_data.username) |
            (User.email == login_data.username)
        ).first()

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password"
            )

        # Refuse while an unexpired lock is in place.
        if user.is_locked and user.locked_until and user.locked_until > datetime.utcnow():
            raise HTTPException(
                status_code=status.HTTP_423_LOCKED,
                detail=f"Account is locked until {user.locked_until}"
            )

        if not verify_password(login_data.password, user.hashed_password):
            # `or 0` keeps the counter working when the column holds NULL.
            user.failed_login_attempts = (user.failed_login_attempts or 0) + 1

            if user.failed_login_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS:
                user.is_locked = True
                user.locked_until = datetime.utcnow() + self.LOCKOUT_DURATION

            self.db.commit()

            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password"
            )

        if not user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Email not verified"
            )

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Account is not active"
            )

        # Successful login: clear the failure counter and any lock.
        user.failed_login_attempts = 0
        user.last_login = datetime.utcnow()
        user.is_locked = False
        user.locked_until = None

        access_token = create_access_token({"sub": user.username, "user_id": user.id})
        refresh_token = self._stage_refresh_token(user, ip_address, user_agent)

        # One commit persists the login state and the refresh token together.
        self.db.commit()

        return self._build_token_response(user, access_token, refresh_token), user

    async def verify_email(
        self,
        token: str,
        code: Optional[str] = None
    ) -> bool:
        """Mark the user named in a verification token as verified and active.

        Args:
            token: Verification token; its payload must have ``type == "verify"``
                and an ``email`` entry.
            code: Optional verification code. When given (non-empty) it must
                match the stored code and that code must not be expired.

        Returns:
            ``True`` on success.

        Raises:
            HTTPException: 400 for an invalid token or payload, an already
                verified address, or a wrong or expired code; 404 when no user
                has that e-mail.
        """
        import logging

        payload = decode_token(token)

        if not payload or payload.get("type") != "verify":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid verification token"
            )

        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload"
            )

        user = self.db.query(User).filter(User.email == email).first()

        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        if user.is_verified:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email already verified"
            )

        if code:
            stored_code = user.verification_code
            expires_at = user.verification_code_expires
            # compare_digest avoids leaking, through timing, how many leading
            # characters of a guess were right. Both sides are encoded to bytes
            # because it rejects non-ASCII str arguments.
            code_ok = bool(
                stored_code
                and expires_at
                and expires_at >= datetime.utcnow()
                and secrets.compare_digest(
                    str(stored_code).encode("utf-8"),
                    str(code).encode("utf-8")
                )
            )
            if not code_ok:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid or expired verification code"
                )

        user.is_verified = True
        user.is_active = True
        user.verification_code = None
        user.verification_code_expires = None
        self.db.commit()

        # Read what the e-mail needs before the try so the handler below does
        # not have to touch the database.
        user_id, user_email, username = user.id, user.email, user.username
        try:
            await email_service.send_welcome_email(user_email, username)
        except Exception:
            # Verification is already committed. Failing the request here would
            # tell the client it failed, and a retry would then be rejected with
            # "Email already verified". Log the failure and report success.
            logging.getLogger(__name__).exception(
                "Welcome email could not be sent for user id=%s", user_id
            )

        return True

    async def resend_verification_email(
        self,
        email: str,
        background_tasks: BackgroundTasks
    ) -> bool:
        """Issue a new verification code and queue a new verification e-mail.

        Always returns ``True``. An unknown address and an already verified
        address are both answered with silent success, so the response does
        not reveal whether an account exists for *email*.

        Args:
            email: Address to resend the verification e-mail to.
            background_tasks: Receives the task that sends the e-mail.
        """
        user = self.db.query(User).filter(User.email == email).first()

        # Same answer for "no such user" and "nothing to verify": raising a
        # distinct error for verified accounts would defeat the silent success
        # for unknown addresses.
        if not user or user.is_verified:
            return True

        verification_code = self._issue_verification_code(user)
        self.db.commit()

        self._queue_verification_email(user, verification_code, background_tasks)

        return True

    async def request_password_reset(
        self,
        email: str,
        background_tasks: BackgroundTasks
    ) -> bool:
        """Create a password-reset token and queue the reset e-mail.

        Always returns ``True``. Unknown and inactive accounts are answered
        with silent success, so the response does not reveal whether an
        account exists for *email*.

        Args:
            email: Address of the account whose password should be reset.
            background_tasks: Receives the task that sends the e-mail.
        """
        user = self.db.query(User).filter(User.email == email).first()

        # Same answer for "no such user" and "inactive": a distinct error for
        # inactive accounts would leak that the address is registered.
        if not user or not user.is_active:
            return True

        reset_token = create_reset_token(user.email)
        reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"

        # Persist the token so reset_password can check it and revoke it.
        self.db.add(
            Token(
                token=reset_token,
                token_type="reset",
                expires_at=datetime.utcnow() + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
                user_id=user.id
            )
        )
        self.db.commit()

        background_tasks.add_task(
            email_service.send_password_reset_email,
            user.email,
            user.username,
            reset_url
        )

        return True

    async def reset_password(
        self,
        token: str,
        new_password: str
    ) -> bool:
        """Set a new password using a reset token and revoke the user's tokens.

        Args:
            token: Reset token; its payload must have ``type == "reset"`` and
                an ``email`` entry, and a matching unrevoked, unexpired
                ``Token`` row must exist.
            new_password: Plain-text password; stored hashed.

        Returns:
            ``True`` on success.

        Raises:
            HTTPException: 400 for an invalid, expired, revoked or foreign
                token or for an inactive account; 404 when no user has the
                e-mail from the token.
        """
        from ..core.security import get_password_hash

        payload = decode_token(token)

        if not payload or payload.get("type") != "reset":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid reset token"
            )

        email = payload.get("email")
        if not email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid token payload"
            )

        # The token must also be on record, unrevoked and unexpired.
        # `== False` is deliberate: it builds a SQL expression.
        token_entry = self.db.query(Token).filter(
            Token.token == token,
            Token.token_type == "reset",
            Token.is_revoked == False,  # noqa: E712
            Token.expires_at > datetime.utcnow()
        ).first()

        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token"
            )

        user = self.db.query(User).filter(User.email == email).first()

        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found"
            )

        # The e-mail claim may now resolve to a different account than the one
        # the token row was stored for (e.g. the address changed hands); never
        # reset a password the token was not issued for.
        if token_entry.user_id != user.id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired reset token"
            )

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Account is not active"
            )

        user.hashed_password = get_password_hash(new_password)
        user.last_password_change = datetime.utcnow()

        # Revoke the user's session tokens and every outstanding reset token,
        # so older reset links cannot be replayed after this one succeeds.
        self.db.query(Token).filter(
            Token.user_id == user.id,
            Token.token_type.in_(["access", "refresh", "reset"])
        ).update({"is_revoked": True})

        # Mark the token just used as consumed.
        token_entry.is_revoked = True

        self.db.commit()

        return True

    async def refresh_token(
        self,
        refresh_token: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> TokenResponse:
        """Exchange a valid refresh token for a new access/refresh pair.

        The presented refresh token is revoked and replaced (rotation).

        Args:
            refresh_token: Token whose payload has ``type == "refresh"`` plus
                ``sub`` and ``user_id``, and for which an unrevoked, unexpired
                ``Token`` row exists.
            ip_address: Stored on the new refresh-token row.
            user_agent: Stored on the new refresh-token row.

        Returns:
            A ``TokenResponse`` carrying the new tokens.

        Raises:
            HTTPException: 401 for an invalid token or payload, or when the
                user is missing or inactive.
        """
        payload = decode_token(refresh_token)

        if not payload or payload.get("type") != "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        # `== False` is deliberate: it builds a SQL expression.
        token_entry = self.db.query(Token).filter(
            Token.token == refresh_token,
            Token.token_type == "refresh",
            Token.is_revoked == False,  # noqa: E712
            Token.expires_at > datetime.utcnow()
        ).first()

        if not token_entry:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token"
            )

        username = payload.get("sub")
        user_id = payload.get("user_id")

        if not username or not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token payload"
            )

        user = self.db.query(User).filter(
            User.id == user_id,
            User.username == username,
            User.is_active == True  # noqa: E712
        ).first()

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found or inactive"
            )

        access_token = create_access_token({"sub": user.username, "user_id": user.id})
        new_refresh_token = self._stage_refresh_token(user, ip_address, user_agent)

        # Rotation: the presented token is revoked in the same commit that
        # stores its replacement.
        token_entry.is_revoked = True
        self.db.commit()

        return self._build_token_response(user, access_token, new_refresh_token)

    async def logout(
        self,
        refresh_token: str
    ) -> bool:
        """Revoke the given refresh token if it is on record; always returns ``True``."""
        token_entry = self.db.query(Token).filter(
            Token.token == refresh_token,
            Token.token_type == "refresh"
        ).first()

        if token_entry:
            token_entry.is_revoked = True
            self.db.commit()

        return True

    async def logout_all(
        self,
        user_id: int
    ) -> bool:
        """Revoke every unrevoked access and refresh token of *user_id*; returns ``True``."""
        self.db.query(Token).filter(
            Token.user_id == user_id,
            Token.token_type.in_(["access", "refresh"]),
            Token.is_revoked == False  # noqa: E712
        ).update({"is_revoked": True})

        self.db.commit()
        return True
|