| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- # app/services/user_service.py
- from sqlalchemy import select, update, delete, func, or_
- from sqlalchemy.orm import selectinload
- from sqlalchemy.ext.asyncio import AsyncSession
- from typing import List, Optional, Dict, Any
- from datetime import datetime, timezone, timedelta
-
- from app.models.user import User
- from app.core.security import password_hasher, password_validator
- from ..schemas.user import UserCreate, UserUpdate, UserResponse
-
class UserService:
    """User account service backed by an async SQLAlchemy session.

    Every public method is a coroutine. Methods that modify data commit
    on the session that was handed to the constructor.
    """

    # Number of consecutive failed logins at which the account is flagged
    # as locked.
    MAX_FAILED_LOGIN_ATTEMPTS = 5

    def __init__(self, db: AsyncSession):
        """Store the session used for every query.

        Args:
            db: Async SQLAlchemy session; the service commits on it directly.
        """
        self.db = db

    async def create_user(self, user_data: UserCreate) -> User:
        """Create and persist a new user.

        Args:
            user_data: Creation payload; ``username``, ``email``,
                ``password`` and ``full_name`` are read from it.

        Returns:
            The newly stored user, refreshed from the database after commit.

        Raises:
            ValueError: If the password fails the strength validation, or
                the username or email is already taken.
        """
        is_valid, error_msg = password_validator.validate_password_strength(
            user_data.password
        )
        if not is_valid:
            raise ValueError(error_msg)

        stmt = select(User).where(
            or_(
                User.username == user_data.username,
                User.email == user_data.email,
            )
        )
        result = await self.db.execute(stmt)
        # The username and the email can belong to two *different* existing
        # rows, so fetch every match; scalar_one_or_none() would raise
        # MultipleResultsFound in that case instead of a clean ValueError.
        conflicts = result.scalars().all()
        if conflicts:
            if any(row.username == user_data.username for row in conflicts):
                raise ValueError("用户名已存在")
            raise ValueError("邮箱已被注册")

        hashed_password = password_hasher.hash_password(user_data.password)

        user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=hashed_password,
            full_name=user_data.full_name,
            is_active=True,
            is_verified=False,
        )

        self.db.add(user)
        await self.db.commit()
        # Reload so database-generated values are populated on the instance.
        await self.db.refresh(user)

        return user

    async def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with the given primary key, or ``None``."""
        stmt = select(User).where(User.id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user with the given username, or ``None``."""
        stmt = select(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the user with the given email address, or ``None``."""
        stmt = select(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def authenticate_user(self, username: str, password: str) -> Optional[User]:
        """Check login credentials and maintain the failed-attempt counter.

        Args:
            username: Login identifier; matched against both the
                ``username`` and the ``email`` column.
            password: Plain-text password to verify.

        Returns:
            The refreshed user on success. ``None`` when no user matches,
            the account is flagged as locked, or the password is wrong.

        A wrong password increments ``failed_login_attempts`` and sets
        ``is_locked`` once the counter reaches
        ``MAX_FAILED_LOGIN_ATTEMPTS``. A locked account is rejected without
        verifying the password, so the flag stays set until it is cleared
        elsewhere. A successful login resets the counter and stamps
        ``last_login``.
        """
        stmt = select(User).where(
            or_(
                User.username == username,
                User.email == username,
            )
        )
        result = await self.db.execute(stmt)
        # The identifier may equal one user's username and another user's
        # email; fetch all matches and prefer the exact username match
        # rather than failing with MultipleResultsFound.
        candidates = result.scalars().all()
        user = next((row for row in candidates if row.username == username), None)
        if user is None and candidates:
            user = candidates[0]
        if user is None:
            return None

        # Without this guard a correct password would log in to a locked
        # account and the success path below would silently clear the lock.
        if user.is_locked:
            return None

        if not password_hasher.verify_password(password, user.hashed_password):
            new_attempts = (user.failed_login_attempts or 0) + 1
            failure_stmt = (
                update(User)
                .where(User.id == user.id)
                .values(
                    failed_login_attempts=new_attempts,
                    is_locked=(new_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS),
                )
            )
            await self.db.execute(failure_stmt)
            await self.db.commit()
            return None

        success_stmt = (
            update(User)
            .where(User.id == user.id)
            .values(
                failed_login_attempts=0,
                last_login=datetime.now(timezone.utc),
                is_locked=False,
            )
        )
        await self.db.execute(success_stmt)
        await self.db.commit()

        # The UPDATE bypassed the ORM instance; reload it so the caller
        # sees the values that were just written.
        await self.db.refresh(user)

        return user

    # ==================== User operations ====================
    async def update_user(self, user_id: int, update_data: dict) -> Optional[User]:
        """Apply a dict of attribute changes to a user.

        Args:
            user_id: Primary key of the user to modify.
            update_data: Mapping of attribute name to new value. Entries
                whose value is ``None`` and keys the user object has no
                attribute for are skipped, so a field cannot be cleared to
                ``None`` through this method.

        Returns:
            The refreshed user, or ``None`` if no user has that id.
        """
        user = await self.get_user_by_id(user_id)
        if not user:
            return None

        # NOTE(review): any existing attribute can be overwritten here
        # (including ``id`` or ``hashed_password``); callers should pass
        # only vetted keys. Confirm whether an allow-list is wanted.
        for key, value in update_data.items():
            if hasattr(user, key) and value is not None:
                setattr(user, key, value)

        user.updated_at = datetime.now(timezone.utc)
        await self.db.commit()
        await self.db.refresh(user)

        return user

    async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
        """Replace a user's password after verifying the current one.

        Args:
            user_id: Primary key of the user.
            current_password: Plain-text password that must match the
                stored hash.
            new_password: Plain-text replacement password.

        Returns:
            ``True`` when the password was changed; ``False`` when the user
            does not exist or ``current_password`` does not verify.

        Raises:
            ValueError: If ``new_password`` fails the strength validation.
        """
        user = await self.get_user_by_id(user_id)
        if not user:
            return False

        if not password_hasher.verify_password(current_password, user.hashed_password):
            return False

        is_valid, error_msg = password_validator.validate_password_strength(new_password)
        if not is_valid:
            raise ValueError(error_msg)

        user.hashed_password = password_hasher.hash_password(new_password)
        user.last_password_change = datetime.now(timezone.utc)
        await self.db.commit()

        return True

    # ==================== User listing ====================
    async def list_users(self, skip: int = 0, limit: int = 100, active_only: bool = True) -> List[User]:
        """Return one page of users ordered by id.

        Args:
            skip: Number of rows to skip (SQL OFFSET).
            limit: Maximum number of rows to return (SQL LIMIT).
            active_only: When true, only rows with ``is_active`` true are
                returned.

        Returns:
            A list of users; empty when the page is past the end.
        """
        stmt = select(User)

        if active_only:
            stmt = stmt.where(User.is_active == True)  # noqa: E712 - SQL expression

        # OFFSET/LIMIT without ORDER BY gives no guaranteed row order, so
        # pages could overlap or skip rows between calls.
        stmt = stmt.order_by(User.id).offset(skip).limit(limit)

        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def delete_user(self, user_id: int) -> bool:
        """Soft-delete a user by marking it inactive.

        The row is not removed; ``is_active`` is set to false and
        ``updated_at`` is stamped with the current UTC time.

        Args:
            user_id: Primary key of the user to deactivate.

        Returns:
            ``True`` if the UPDATE matched at least one row.
        """
        stmt = (
            update(User)
            .where(User.id == user_id)
            .values(is_active=False, updated_at=datetime.now(timezone.utc))
        )

        result = await self.db.execute(stmt)
        await self.db.commit()

        return result.rowcount > 0

    async def count_users(self, active_only: bool = True) -> int:
        """Count users.

        Args:
            active_only: When true, count only rows with ``is_active`` true.

        Returns:
            The number of matching rows (0 when the query yields nothing).
        """
        # select_from() is required: with no WHERE clause a bare
        # select(func.count()) has no FROM and would not count the users
        # table at all.
        stmt = select(func.count()).select_from(User)
        if active_only:
            stmt = stmt.where(User.is_active == True)  # noqa: E712 - SQL expression

        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def update_last_login(self, user_id: int) -> None:
        """Set ``last_login`` of the given user to the current UTC time.

        Args:
            user_id: Primary key of the user to stamp.
        """
        stmt = (
            update(User)
            .where(User.id == user_id)
            .values(last_login=datetime.now(timezone.utc))
        )

        await self.db.execute(stmt)
        await self.db.commit()
|