# app/services/user_service.py
from sqlalchemy import select, update, delete, func, or_
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional, Dict, Any
from datetime import datetime, timezone, timedelta

from app.models.user import User
from app.core.security import password_hasher, password_validator
from ..schemas.user import UserCreate, UserUpdate, UserResponse


class UserService:
    """User service: creation, lookup, authentication and maintenance of ``User`` rows.

    Every method runs on the injected ``AsyncSession`` and commits its own
    changes before returning.
    """

    # Consecutive failed password checks after which authenticate_user() locks
    # the account. Override on a subclass to change the policy.
    MAX_FAILED_LOGIN_ATTEMPTS = 5

    # Keys update_user() never writes: the primary key must not change, and the
    # password hash must only be set through change_password(), which verifies
    # the current password and enforces the strength rules.
    PROTECTED_UPDATE_FIELDS = frozenset({"id", "hashed_password"})

    def __init__(self, db: AsyncSession):
        # Async SQLAlchemy session shared by all methods of this service.
        self.db = db

    async def create_user(self, user_data: UserCreate) -> User:
        """Create and persist a new active, unverified user.

        Args:
            user_data: Registration payload (username, email, password, full_name).

        Returns:
            The newly created ``User``, refreshed from the database.

        Raises:
            ValueError: If the password fails the strength check, or if the
                username or the email is already registered.
        """
        is_valid, error_msg = password_validator.validate_password_strength(user_data.password)
        if not is_valid:
            raise ValueError(error_msg)

        # The username and the email can each collide with a *different* existing
        # user, so this query may match two rows; take the first one rather than
        # scalar_one_or_none(), which would raise MultipleResultsFound.
        stmt = select(User).where(
            or_(
                User.username == user_data.username,
                User.email == user_data.email,
            )
        )
        result = await self.db.execute(stmt)
        existing_user = result.scalars().first()
        if existing_user is not None:
            if existing_user.username == user_data.username:
                raise ValueError("用户名已存在")
            raise ValueError("邮箱已被注册")

        # NOTE(review): the check above and the insert below are not atomic; a
        # concurrent registration can still fail at commit on a unique constraint
        # (if the table defines one) — consider handling IntegrityError.
        user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=password_hasher.hash_password(user_data.password),
            full_name=user_data.full_name,
            is_active=True,
            is_verified=False,
        )
        self.db.add(user)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with primary key ``user_id``, or ``None``."""
        stmt = select(User).where(User.id == user_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user whose username equals ``username``, or ``None``."""
        stmt = select(User).where(User.username == username)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_by_email(self, email: str) -> Optional[User]:
        """Return the user whose email equals ``email``, or ``None``."""
        stmt = select(User).where(User.email == email)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def authenticate_user(self, username: str, password: str) -> Optional[User]:
        """Verify credentials and record the outcome of the login attempt.

        ``username`` may be the account's username or its email address; an
        exact username match takes precedence over an email match.

        Locked accounts are rejected without checking the password. On a wrong
        password the failed-attempt counter is incremented, and the account is
        locked once the counter reaches ``MAX_FAILED_LOGIN_ATTEMPTS``. On success
        the counter is reset, the lock flag cleared and ``last_login`` set to the
        current UTC time.

        Args:
            username: Username or email address.
            password: Plain-text password to verify.

        Returns:
            The authenticated ``User``, or ``None`` if no account matches, the
            account is locked, or the password is wrong.
        """
        # Two lookups instead of one OR query: a username equal to another user's
        # email would match two rows and make scalar_one_or_none() raise.
        user = await self.get_user_by_username(username)
        if user is None:
            user = await self.get_user_by_email(username)
        if user is None:
            return None

        # Without this check the lock would be ineffective: a correct password
        # would log in and the success path below would clear the flag.
        if user.is_locked:
            return None
        # NOTE(review): is_active is not checked here, so soft-deleted users
        # (see delete_user) can still authenticate unless the caller checks it.

        if not password_hasher.verify_password(password, user.hashed_password):
            # NOTE(review): read-modify-write; concurrent failures may lose an
            # increment.
            new_attempts = (user.failed_login_attempts or 0) + 1
            await self.db.execute(
                update(User)
                .where(User.id == user.id)
                .values(
                    failed_login_attempts=new_attempts,
                    is_locked=new_attempts >= self.MAX_FAILED_LOGIN_ATTEMPTS,
                )
            )
            await self.db.commit()
            return None

        await self.db.execute(
            update(User)
            .where(User.id == user.id)
            .values(
                failed_login_attempts=0,
                last_login=datetime.now(timezone.utc),
                is_locked=False,
            )
        )
        await self.db.commit()
        # Reload so the returned object carries the values just written.
        await self.db.refresh(user)
        return user

    # ==================== User operations ====================

    async def update_user(self, user_id: int, update_data: dict) -> Optional[User]:
        """Apply a partial update to a user and bump ``updated_at``.

        A key is applied only if it names an existing attribute of the ``User``
        instance and its value is not ``None`` (so fields cannot be cleared
        through this method). Keys in ``PROTECTED_UPDATE_FIELDS`` and keys
        starting with an underscore are ignored.

        Args:
            user_id: Primary key of the user to update.
            update_data: Mapping of attribute name to new value.

        Returns:
            The updated ``User``, or ``None`` if no user has ``user_id``.
        """
        user = await self.get_user_by_id(user_id)
        if user is None:
            return None

        for key, value in update_data.items():
            if value is None:
                continue
            # Block mass assignment of the primary key, the password hash and
            # private/dunder attributes (hasattr() is true for e.g. __class__).
            if key in self.PROTECTED_UPDATE_FIELDS or key.startswith("_"):
                continue
            if hasattr(user, key):
                setattr(user, key, value)

        user.updated_at = datetime.now(timezone.utc)
        await self.db.commit()
        await self.db.refresh(user)
        return user

    async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
        """Replace a user's password after verifying the current one.

        Args:
            user_id: Primary key of the user.
            current_password: The user's existing plain-text password.
            new_password: The desired plain-text password.

        Returns:
            ``True`` if the password was changed; ``False`` if the user does not
            exist or ``current_password`` is wrong.

        Raises:
            ValueError: If ``new_password`` fails the strength check.
        """
        user = await self.get_user_by_id(user_id)
        if user is None:
            return False

        if not password_hasher.verify_password(current_password, user.hashed_password):
            return False

        is_valid, error_msg = password_validator.validate_password_strength(new_password)
        if not is_valid:
            raise ValueError(error_msg)

        user.hashed_password = password_hasher.hash_password(new_password)
        user.last_password_change = datetime.now(timezone.utc)
        await self.db.commit()
        return True

    # ==================== User listing ====================

    async def list_users(self, skip: int = 0, limit: int = 100, active_only: bool = True) -> List[User]:
        """Return one page of users ordered by id.

        Args:
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).
            active_only: If true, only users with ``is_active`` set are returned.

        Returns:
            A list of ``User`` objects, possibly empty.
        """
        stmt = select(User)
        if active_only:
            stmt = stmt.where(User.is_active == True)  # noqa: E712 - SQL expression
        # OFFSET/LIMIT without ORDER BY gives no stable row order, so pages could
        # overlap or skip rows; order by primary key for deterministic paging.
        stmt = stmt.order_by(User.id).offset(skip).limit(limit)
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def delete_user(self, user_id: int) -> bool:
        """Soft-delete a user by clearing ``is_active`` (the row is kept).

        Args:
            user_id: Primary key of the user to deactivate.

        Returns:
            ``True`` if a row with ``user_id`` was matched, ``False`` otherwise.
        """
        stmt = (
            update(User)
            .where(User.id == user_id)
            .values(is_active=False, updated_at=datetime.now(timezone.utc))
        )
        result = await self.db.execute(stmt)
        await self.db.commit()
        return result.rowcount > 0

    async def count_users(self, active_only: bool = True) -> int:
        """Count users, optionally only the active ones.

        Args:
            active_only: If true, only users with ``is_active`` set are counted.

        Returns:
            The number of matching users.
        """
        # select_from() is required: a bare select(func.count()) has no FROM
        # clause and renders "SELECT count(*)", which yields 1, not the row count.
        stmt = select(func.count()).select_from(User)
        if active_only:
            stmt = stmt.where(User.is_active == True)  # noqa: E712 - SQL expression
        result = await self.db.execute(stmt)
        return result.scalar() or 0

    async def update_last_login(self, user_id: int) -> None:
        """Set ``last_login`` of the given user to the current UTC time."""
        stmt = (
            update(User)
            .where(User.id == user_id)
            .values(last_login=datetime.now(timezone.utc))
        )
        await self.db.execute(stmt)
        await self.db.commit()