CaiYouHui 后端 FastAPI 实现

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # app/services/user_service.py
  2. from sqlalchemy import select, update, delete, func, or_
  3. from sqlalchemy.orm import selectinload
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from typing import List, Optional, Dict, Any
  6. from datetime import datetime, timezone, timedelta
  7. from app.models.user import User
  8. from app.core.email import email_service
  9. from app.core.security import password_hasher, password_validator
  10. from ..schemas.user import UserCreate, UserUpdate, UserResponse
  class UserService:
      """User service.

      Wraps an async SQLAlchemy session (kept on ``self.db``) through which the
      service methods query and persist ``User`` rows.
      """

      def __init__(self, db: AsyncSession):
          # Async session used by every method of this service.
          self.db = db
  async def create_user(self, user_data: UserCreate) -> User:
      """Create and persist a new, active but unverified user.

      Args:
          user_data: Registration payload; ``username``, ``email`` and
              ``password`` are read from it.

      Returns:
          The newly inserted ``User``, refreshed from the database.

      Raises:
          ValueError: If the password fails the strength check, the username
              is already taken, or the email is already registered.
      """
      # Reject weak passwords before touching the database.
      is_valid, error_msg = password_validator.validate_password_strength(user_data.password)
      if not is_valid:
          raise ValueError(error_msg)

      # The username may belong to one existing user and the email to another,
      # so this query can return two rows. scalar_one_or_none() would then raise
      # MultipleResultsFound instead of a ValueError, so inspect all matches.
      stmt = select(User).where(
          or_(
              User.username == user_data.username,
              User.email == user_data.email,
          )
      )
      result = await self.db.execute(stmt)
      conflicts = result.scalars().all()
      # A username clash takes precedence over an email clash.
      if any(existing.username == user_data.username for existing in conflicts):
          raise ValueError("用户名已存在")
      if conflicts:
          raise ValueError("邮箱已被注册")

      user = User(
          username=user_data.username,
          email=user_data.email,
          hashed_password=password_hasher.hash_password(user_data.password),
          # Random display name: "CY_" followed by generate_random_string(12).
          full_name="CY_" + email_service.generate_random_string(12),
          is_active=True,
          is_verified=False,
      )

      self.db.add(user)
      try:
          await self.db.commit()
      except Exception:
          # Roll back so the caller's session stays usable (e.g. after a
          # constraint violation from a concurrent registration), then propagate.
          await self.db.rollback()
          raise
      await self.db.refresh(user)
      return user
  async def get_user_by_id(self, user_id: int) -> Optional[User]:
      """Return the user whose ``id`` equals ``user_id``, or ``None`` if absent."""
      return (
          await self.db.execute(select(User).where(User.id == user_id))
      ).scalar_one_or_none()
  async def get_user_by_username(self, username: str) -> Optional[User]:
      """Return the user with exactly this ``username``, or ``None`` if absent."""
      return (
          await self.db.execute(select(User).where(User.username == username))
      ).scalar_one_or_none()
  async def get_user_by_email(self, email: str) -> Optional[User]:
      """Return the user registered with ``email``, or ``None`` if absent."""
      return (
          await self.db.execute(select(User).where(User.email == email))
      ).scalar_one_or_none()
  async def authenticate_user(
      self,
      username: str,
      password: str,
      *,
      max_failed_attempts: int = 5,
  ) -> Optional[User]:
      """Check login credentials and maintain the failed-attempt lockout.

      Args:
          username: Login identifier, matched against both ``User.username``
              and ``User.email``.
          password: Plain-text password to verify.
          max_failed_attempts: Number of consecutive failures after which the
              account is marked ``is_locked`` (default 5, the previous
              hard-coded limit).

      Returns:
          The refreshed ``User`` on success. Returns ``None`` if no user
          matches, the account is locked, or the password is wrong.
      """
      stmt = select(User).where(
          or_(
              User.username == username,
              User.email == username,
          )
      )
      result = await self.db.execute(stmt)
      # The identifier may equal one user's username and another user's email;
      # scalar_one_or_none() would raise MultipleResultsFound in that case.
      # Prefer the exact username match.
      candidates = result.scalars().all()
      if not candidates:
          return None
      user = next((c for c in candidates if c.username == username), candidates[0])

      # A locked account must not sign in. Previously a correct password both
      # succeeded and cleared is_locked, so the lockout never blocked anyone.
      # NOTE(review): nothing visible here unlocks an account; confirm an
      # unlock path (admin reset / timed expiry) exists elsewhere.
      # NOTE(review): is_active is not checked; confirm callers reject
      # deactivated (soft-deleted) accounts.
      if user.is_locked:
          return None

      if not password_hasher.verify_password(password, user.hashed_password):
          # Record the failure and lock once the threshold is reached.
          new_attempts = (user.failed_login_attempts or 0) + 1
          await self.db.execute(
              update(User)
              .where(User.id == user.id)
              .values(
                  failed_login_attempts=new_attempts,
                  is_locked=(new_attempts >= max_failed_attempts),
              )
          )
          await self.db.commit()
          return None

      # Success: reset the failure counter and stamp the login time (UTC).
      await self.db.execute(
          update(User)
          .where(User.id == user.id)
          .values(
              failed_login_attempts=0,
              last_login=datetime.now(timezone.utc),
              is_locked=False,
          )
      )
      await self.db.commit()
      # Reload so the returned object reflects the values just written.
      await self.db.refresh(user)
      return user
  # ==================== User operations ====================
  async def update_user(self, user_id: int, update_data: dict) -> Optional[User]:
      """Apply a partial update to a user and bump ``updated_at``.

      Keys of ``update_data`` that are not attributes of the loaded ``User``,
      and keys whose value is ``None``, are skipped. This means ``None`` cannot
      be used to clear a field.

      Returns:
          The refreshed user, or ``None`` when ``user_id`` does not exist.
      """
      user = await self.get_user_by_id(user_id)
      if not user:
          return None

      # NOTE(review): any attribute present on the model can be written here
      # (e.g. hashed_password, is_verified); confirm callers only pass
      # whitelisted fields.
      for field, new_value in update_data.items():
          if not hasattr(user, field) or new_value is None:
              continue
          setattr(user, field, new_value)

      user.updated_at = datetime.now(timezone.utc)
      await self.db.commit()
      await self.db.refresh(user)
      return user
  async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
      """Replace a user's password after re-verifying the current one.

      Returns:
          ``False`` if the user does not exist or ``current_password`` is wrong.
          ``True`` once the new hash and change timestamp are committed.

      Raises:
          ValueError: If ``new_password`` fails the strength check. This is
              only checked after the current password has been verified.
      """
      user = await self.get_user_by_id(user_id)
      if not user or not password_hasher.verify_password(current_password, user.hashed_password):
          return False

      ok, reason = password_validator.validate_password_strength(new_password)
      if not ok:
          raise ValueError(reason)

      user.hashed_password = password_hasher.hash_password(new_password)
      user.last_password_change = datetime.now(timezone.utc)
      await self.db.commit()
      return True
  # ==================== User listing ====================
  async def list_users(self, skip: int = 0, limit: int = 100, active_only: bool = True) -> List[User]:
      """Return one page of users, ordered by ``id``.

      Args:
          skip: Number of rows to skip; negative values are treated as 0.
          limit: Maximum number of rows to return.
          active_only: When true, only users with ``is_active`` true are returned.

      Returns:
          A list of ``User`` objects (possibly empty).
      """
      stmt = select(User)
      if active_only:
          stmt = stmt.where(User.is_active == True)  # noqa: E712 - SQL expression, not a Python truth test
      # OFFSET/LIMIT without ORDER BY has no defined row order, so successive
      # pages could overlap or skip users. Order by id for stable paging.
      stmt = stmt.order_by(User.id).offset(max(skip, 0)).limit(limit)
      result = await self.db.execute(stmt)
      return list(result.scalars().all())
  async def delete_user(self, user_id: int) -> bool:
      """Soft-delete (disable) a user.

      The row is kept and only marked ``is_active=False``, with ``updated_at``
      set to the current UTC time.

      Returns:
          ``True`` if the UPDATE reported a row count above zero (the user
          exists), otherwise ``False``.
      """
      deactivate = (
          update(User)
          .where(User.id == user_id)
          .values(is_active=False, updated_at=datetime.now(timezone.utc))
      )
      outcome = await self.db.execute(deactivate)
      await self.db.commit()
      return outcome.rowcount > 0
  async def count_users(self, active_only: bool = True) -> int:
      """Return the number of users, optionally counting only active ones.

      Args:
          active_only: When true, count only users with ``is_active`` true.

      Returns:
          The row count (0 if the database returns no value).
      """
      # select(func.count()) has no FROM clause of its own. Without a WHERE that
      # references User it compiles to "SELECT count(*)" with no table, which
      # always yields 1, so name the table explicitly.
      stmt = select(func.count()).select_from(User)
      if active_only:
          stmt = stmt.where(User.is_active == True)  # noqa: E712 - SQL expression, not a Python truth test
      result = await self.db.execute(stmt)
      return result.scalar() or 0
  async def update_last_login(self, user_id: int) -> None:
      """Set ``last_login`` for ``user_id`` to the current UTC time and commit."""
      now_utc = datetime.now(timezone.utc)
      await self.db.execute(
          update(User).where(User.id == user_id).values(last_login=now_utc)
      )
      await self.db.commit()