CaiYouHui后端fastapi实现

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # app/services/user_service.py
  2. from sqlalchemy import select, update, delete, func, or_
  3. from sqlalchemy.orm import selectinload
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from typing import List, Optional, Dict, Any
  6. from datetime import datetime, timezone, timedelta
  7. from app.models.user import User
  8. from app.core.security import password_hasher, password_validator
  9. from ..schemas.user import UserCreate, UserUpdate, UserResponse
  10. class UserService:
  11. """用户服务"""
  12. def __init__(self, db: AsyncSession): # 异步Session
  13. self.db = db
  14. async def create_user(self, user_data: UserCreate) -> User:
  15. """创建用户(异步版本)"""
  16. # 验证密码强度
  17. is_valid, error_msg = password_validator.validate_password_strength(user_data.password)
  18. if not is_valid:
  19. raise ValueError(error_msg)
  20. # 检查用户是否已存在 - 使用异步查询
  21. stmt = select(User).where(
  22. or_(
  23. User.username == user_data.username,
  24. User.email == user_data.email
  25. )
  26. )
  27. result = await self.db.execute(stmt)
  28. existing_user = result.scalar_one_or_none()
  29. if existing_user:
  30. if existing_user.username == user_data.username:
  31. raise ValueError("用户名已存在")
  32. else:
  33. raise ValueError("邮箱已被注册")
  34. # 哈希密码
  35. hashed_password = password_hasher.hash_password(user_data.password)
  36. # 创建用户对象
  37. user = User(
  38. username = user_data.username,
  39. email = user_data.email,
  40. hashed_password = hashed_password,
  41. full_name = user_data.full_name,
  42. is_active = True,
  43. is_verified = False
  44. )
  45. # 异步保存
  46. self.db.add(user)
  47. await self.db.commit()
  48. await self.db.refresh(user)
  49. return user
  50. async def get_user_by_id(self, user_id: int) -> Optional[User]:
  51. """通过ID获取用户(异步)"""
  52. stmt = select(User).where(User.id == user_id)
  53. result = await self.db.execute(stmt)
  54. return result.scalar_one_or_none()
  55. async def get_user_by_username(self, username: str) -> Optional[User]:
  56. """通过用户名获取用户(异步)"""
  57. stmt = select(User).where(User.username == username)
  58. result = await self.db.execute(stmt)
  59. return result.scalar_one_or_none()
  60. async def get_user_by_email(self, email: str) -> Optional[User]:
  61. """通过邮箱获取用户(异步)"""
  62. stmt = select(User).where(User.email == email)
  63. result = await self.db.execute(stmt)
  64. return result.scalar_one_or_none()
  65. async def authenticate_user(self, username: str, password: str) -> Optional[User]:
  66. """验证用户(SQLAlchemy 2.0异步版)"""
  67. # 通过用户名或邮箱查找用户 - 使用异步select
  68. stmt = select(User).where(
  69. or_(
  70. User.username == username,
  71. User.email == username
  72. )
  73. )
  74. result = await self.db.execute(stmt)
  75. user = result.scalar_one_or_none()
  76. if not user:
  77. return None
  78. # 验证密码
  79. if not password_hasher.verify_password(password, user.hashed_password):
  80. # 记录失败尝试 - 使用异步更新
  81. new_attempts = (user.failed_login_attempts or 0) + 1
  82. update_stmt = (
  83. update(User)
  84. .where(User.id == user.id)
  85. .values(
  86. failed_login_attempts=new_attempts,
  87. is_locked=(new_attempts >= 5) # 5次失败后锁定
  88. )
  89. )
  90. await self.db.execute(update_stmt)
  91. await self.db.commit()
  92. return None
  93. # 登录成功 - 重置失败尝试并更新最后登录时间
  94. update_stmt = (
  95. update(User)
  96. .where(User.id == user.id)
  97. .values(
  98. failed_login_attempts = 0,
  99. last_login = datetime.now(timezone.utc),
  100. is_locked = False
  101. )
  102. )
  103. await self.db.execute(update_stmt)
  104. await self.db.commit()
  105. # 4. 刷新用户对象以获取最新数据
  106. await self.db.refresh(user)
  107. return user
  108. # ==================== 用户操作 ====================
  109. async def update_user(self, user_id: int, update_data: dict) -> Optional[User]:
  110. """更新用户信息(异步)"""
  111. user = await self.get_user_by_id(user_id)
  112. if not user:
  113. return None
  114. for key, value in update_data.items():
  115. if hasattr(user, key) and value is not None:
  116. setattr(user, key, value)
  117. user.updated_at = datetime.now(timezone.utc)
  118. await self.db.commit()
  119. await self.db.refresh(user)
  120. return user
  121. async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
  122. """修改用户密码(异步)"""
  123. # 获取用户
  124. user = await self.get_user_by_id(user_id)
  125. if not user:
  126. return False
  127. # 验证当前密码
  128. if not password_hasher.verify_password(current_password, user.hashed_password):
  129. return False
  130. # 验证新密码强度
  131. is_valid, error_msg = password_validator.validate_password_strength(new_password)
  132. if not is_valid:
  133. raise ValueError(error_msg)
  134. # 更新密码
  135. user.hashed_password = password_hasher.hash_password(new_password)
  136. user.last_password_change = datetime.now(timezone.utc)
  137. await self.db.commit()
  138. return True
  139. # ==================== 用户列表 ====================
  140. async def list_users(self, skip: int = 0, limit: int = 100, active_only: bool = True) -> List[User]:
  141. """获取用户列表(异步)"""
  142. stmt = select(User)
  143. if active_only:
  144. stmt = stmt.where(User.is_active == True)
  145. stmt = stmt.offset(skip).limit(limit)
  146. result = await self.db.execute(stmt)
  147. users = result.scalars().all()
  148. return users
  149. async def delete_user(self, user_id: int) -> bool:
  150. """删除用户(软删除)"""
  151. """删除/禁用用户(异步)"""
  152. # 通常我们不会真正删除,而是标记为禁用
  153. stmt = (
  154. update(User)
  155. .where(User.id == user_id)
  156. .values(is_active = False, updated_at = datetime.now(timezone.utc))
  157. )
  158. result = await self.db.execute(stmt)
  159. await self.db.commit()
  160. # 检查是否影响了一行
  161. return result.rowcount > 0
  162. async def count_users(self, active_only: bool = True) -> int:
  163. """统计用户数量(异步)"""
  164. if active_only:
  165. stmt = select(func.count()).where(User.is_active == True)
  166. else:
  167. stmt = select(func.count())
  168. result = await self.db.execute(stmt)
  169. return result.scalar() or 0
  170. async def update_last_login(self, user_id: int) -> None:
  171. """更新最后登录时间(异步)"""
  172. stmt = (
  173. update(User)
  174. .where(User.id == user_id)
  175. .values(last_login=datetime.now(timezone.utc))
  176. )
  177. await self.db.execute(stmt)
  178. await self.db.commit()