# CaiYouHui 后端 FastAPI 实现 (CaiYouHui backend: FastAPI implementation)


import logging
import secrets
import string
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple

from fastapi import BackgroundTasks, HTTPException, status
from sqlalchemy import delete, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.email import email_service
from app.core.security import (
    create_access_token,
    create_refresh_token,
    create_reset_token,
    create_verification_token,
    decode_token,
    generate_verification_code,
    verify_password,
)
from app.models.token import Token
from app.models.user import User

from ..config import settings
from ..schemas.auth import LoginRequest, TokenResponse
  22. class AuthService:
  23. def __init__(self, db: AsyncSession):
  24. self.db = db
  25. async def register_user(
  26. self,
  27. user_data: Dict[str, Any],
  28. background_tasks: BackgroundTasks,
  29. ip_address: Optional[str] = None,
  30. user_agent: Optional[str] = None
  31. ) -> User:
  32. """注册新用户(SQLAlchemy 2.0异步版)"""
  33. # 1. 检查用户是否存在 - 使用异步查询
  34. stmt = select(User).where(
  35. or_(
  36. User.username == user_data["username"],
  37. User.email == user_data["email"]
  38. )
  39. )
  40. result = await self.db.execute(stmt)
  41. existing_user = result.scalar_one_or_none()
  42. if existing_user:
  43. if existing_user.username == user_data["username"]:
  44. raise HTTPException(
  45. status_code=status.HTTP_400_BAD_REQUEST,
  46. detail="Username already registered"
  47. )
  48. else:
  49. raise HTTPException(
  50. status_code=status.HTTP_400_BAD_REQUEST,
  51. detail="Email already registered"
  52. )
  53. # 2. 创建用户
  54. from ..core.security import get_password_hash
  55. hashed_password = get_password_hash(user_data["password"])
  56. user = User(
  57. username=user_data["username"],
  58. email=user_data["email"],
  59. hashed_password=hashed_password,
  60. first_name=user_data.get("first_name"),
  61. last_name=user_data.get("last_name"),
  62. is_active=False, # 需要邮箱验证
  63. is_verified=False
  64. )
  65. # 3. 生成验证码
  66. verification_code = generate_verification_code()
  67. user.verification_code = verification_code
  68. user.verification_code_expires = datetime.now(timezone.utc) + timedelta(hours=24)
  69. # 4. 异步保存用户
  70. self.db.add(user)
  71. await self.db.commit()
  72. await self.db.refresh(user)
  73. # 5. 发送验证邮件(后台任务)
  74. verification_token = create_verification_token(user.email)
  75. verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
  76. background_tasks.add_task(
  77. email_service.send_verification_email,
  78. user.email,
  79. user.username,
  80. verification_url,
  81. verification_code
  82. )
  83. return user
  84. async def login(
  85. self,
  86. login_data: LoginRequest,
  87. ip_address: Optional[str] = None,
  88. user_agent: Optional[str] = None
  89. ) -> Tuple[TokenResponse, User]:
  90. """用户登录(SQLAlchemy 2.0异步版)"""
  91. # 1. 查找用户 - 异步查询
  92. stmt = select(User).where(
  93. or_(
  94. User.username == login_data.username,
  95. User.email == login_data.username
  96. )
  97. )
  98. result = await self.db.execute(stmt)
  99. user = result.scalar_one_or_none()
  100. if not user:
  101. raise HTTPException(
  102. status_code=status.HTTP_401_UNAUTHORIZED,
  103. detail="Incorrect username or password"
  104. )
  105. # 2. 检查账户是否被锁定
  106. if user.is_locked and user.locked_until and user.locked_until > datetime.now(timezone.utc):
  107. raise HTTPException(
  108. status_code=status.HTTP_423_LOCKED,
  109. detail=f"Account is locked until {user.locked_until}"
  110. )
  111. # 3. 验证密码
  112. if not verify_password(login_data.password, user.hashed_password):
  113. # 记录失败尝试 - 异步更新
  114. new_attempts = (user.failed_login_attempts or 0) + 1
  115. update_data = {
  116. "failed_login_attempts": new_attempts,
  117. "updated_at": datetime.now(timezone.utc)
  118. }
  119. # 如果失败次数超过5次,锁定账户
  120. if new_attempts >= 5:
  121. update_data.update({
  122. "is_locked" : True,
  123. "locked_until": datetime.now(timezone.utc) + timedelta(minutes=30)
  124. })
  125. update_stmt = (
  126. update(User)
  127. .where(User.id == user.id)
  128. .values(**update_data)
  129. )
  130. await self.db.execute(update_stmt)
  131. await self.db.commit()
  132. raise HTTPException(
  133. status_code=status.HTTP_401_UNAUTHORIZED,
  134. detail="Incorrect username or password"
  135. )
  136. # 4. 检查邮箱是否已验证
  137. if not user.is_verified:
  138. raise HTTPException(
  139. status_code=status.HTTP_403_FORBIDDEN,
  140. detail="Email not verified"
  141. )
  142. # 5. 检查账户是否激活
  143. if not user.is_active:
  144. raise HTTPException(
  145. status_code=status.HTTP_403_FORBIDDEN,
  146. detail="Account is not active"
  147. )
  148. # 6. 重置失败尝试次数 - 异步更新
  149. reset_stmt = (
  150. update(User)
  151. .where(User.id == user.id)
  152. .values(
  153. failed_login_attempts = 0,
  154. last_login = datetime.now(timezone.utc),
  155. is_locked = False,
  156. locked_until = None,
  157. )
  158. )
  159. await self.db.execute(reset_stmt)
  160. # 7. 创建令牌
  161. access_token = create_access_token({"sub": user.username, "user_id": user.id})
  162. refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
  163. # 8. 保存刷新令牌到数据库 - 异步添加
  164. refresh_token_entry = Token(
  165. token=refresh_token,
  166. token_type="refresh",
  167. expires_at=datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
  168. user_id=user.id,
  169. ip_address=ip_address,
  170. user_agent=user_agent
  171. )
  172. self.db.add(refresh_token_entry)
  173. # 9. 提交所有更改
  174. await self.db.commit()
  175. # 10. 刷新用户对象以获取最新数据
  176. await self.db.refresh(user)
  177. # 11. 构建响应
  178. token_response = TokenResponse(
  179. access_token=access_token,
  180. refresh_token=refresh_token,
  181. expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
  182. user=user
  183. )
  184. return token_response, user
  185. async def verify_email(
  186. self,
  187. token: str,
  188. code: Optional[str] = None
  189. ) -> bool:
  190. """验证邮箱(SQLAlchemy 2.0异步版)"""
  191. # 1. 解码令牌
  192. payload = decode_token(token)
  193. if not payload or payload.get("type") != "verify":
  194. raise HTTPException(
  195. status_code=status.HTTP_400_BAD_REQUEST,
  196. detail="Invalid verification token"
  197. )
  198. email = payload.get("email")
  199. if not email:
  200. raise HTTPException(
  201. status_code=status.HTTP_400_BAD_REQUEST,
  202. detail="Invalid token payload"
  203. )
  204. # 2. 查找用户 - 异步查询
  205. stmt = select(User).where(User.email == email)
  206. result = await self.db.execute(stmt)
  207. user = result.scalar_one_or_none()
  208. if not user:
  209. raise HTTPException(
  210. status_code=status.HTTP_404_NOT_FOUND,
  211. detail="User not found"
  212. )
  213. if user.is_verified:
  214. raise HTTPException(
  215. status_code=status.HTTP_400_BAD_REQUEST,
  216. detail="Email already verified"
  217. )
  218. # 3. 验证码验证
  219. if code:
  220. now = datetime.now(timezone.utc)
  221. if (not user.verification_code or
  222. user.verification_code != code or
  223. not user.verification_code_expires or
  224. user.verification_code_expires < now):
  225. raise HTTPException(
  226. status_code=status.HTTP_400_BAD_REQUEST,
  227. detail="Invalid or expired verification code"
  228. )
  229. # 4. 更新用户状态 - 异步更新
  230. update_stmt = (
  231. update(User)
  232. .where(User.id == user.id)
  233. .values(
  234. is_verified = True,
  235. is_active = True,
  236. verification_code = None,
  237. verification_code_expires = None,
  238. updated_at = datetime.now(timezone.utc)
  239. )
  240. )
  241. await self.db.execute(update_stmt)
  242. await self.db.commit()
  243. # 5. 发送欢迎邮件(假设是异步的)
  244. await email_service.send_welcome_email(user.email, user.username)
  245. return True
  246. async def resend_verification_email(
  247. self,
  248. email: str,
  249. background_tasks: BackgroundTasks
  250. ) -> bool:
  251. """重新发送验证邮件(SQLAlchemy 2.0异步版)"""
  252. # 1. 查找用户 - 异步查询
  253. stmt = select(User).where(User.email == email)
  254. result = await self.db.execute(stmt)
  255. user = result.scalar_one_or_none()
  256. if not user:
  257. # 出于安全考虑,即使用户不存在也返回成功
  258. return True
  259. if user.is_verified:
  260. raise HTTPException(
  261. status_code=status.HTTP_400_BAD_REQUEST,
  262. detail="Email already verified"
  263. )
  264. # 2. 生成新的验证码
  265. verification_code = generate_verification_code()
  266. # 3. 更新用户验证信息 - 异步更新
  267. update_stmt = (
  268. update(User)
  269. .where(User.id == user.id)
  270. .values(
  271. verification_code = verification_code,
  272. verification_code_expires = datetime.now(timezone.utc) + timedelta(hours=24),
  273. updated_at=datetime.now(timezone.utc)
  274. )
  275. )
  276. await self.db.execute(update_stmt)
  277. await self.db.commit()
  278. # 4. 发送验证邮件
  279. verification_token = create_verification_token(user.email)
  280. verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
  281. background_tasks.add_task(
  282. email_service.send_verification_email,
  283. user.email,
  284. user.username,
  285. verification_url,
  286. verification_code
  287. )
  288. return True
  289. async def request_password_reset(
  290. self,
  291. email: str,
  292. background_tasks: BackgroundTasks
  293. ) -> bool:
  294. """请求密码重置(SQLAlchemy 2.0异步版)"""
  295. # 1. 查找用户 - 异步查询
  296. stmt = select(User).where(User.email == email)
  297. result = await self.db.execute(stmt)
  298. user = result.scalar_one_or_none()
  299. if not user:
  300. # 出于安全考虑,即使用户不存在也返回成功
  301. return True
  302. if not user.is_active:
  303. raise HTTPException(
  304. status_code=status.HTTP_400_BAD_REQUEST,
  305. detail="Account is not active"
  306. )
  307. # 2. 生成重置令牌
  308. reset_token = create_reset_token(user.email)
  309. reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"
  310. # 3. 保存重置令牌到数据库 - 异步添加
  311. reset_token_entry = Token(
  312. token=reset_token,
  313. token_type="reset",
  314. expires_at=datetime.now(timezone.utc) + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
  315. user_id=user.id
  316. )
  317. self.db.add(reset_token_entry)
  318. await self.db.commit()
  319. # 4. 发送重置邮件
  320. background_tasks.add_task(
  321. email_service.send_password_reset_email,
  322. user.email,
  323. user.username,
  324. reset_url
  325. )
  326. return True
  327. async def reset_password(
  328. self,
  329. token: str,
  330. new_password: str
  331. ) -> bool:
  332. """重置密码(SQLAlchemy 2.0异步版)"""
  333. # 1. 验证令牌
  334. payload = decode_token(token)
  335. if not payload or payload.get("type") != "reset":
  336. raise HTTPException(
  337. status_code=status.HTTP_400_BAD_REQUEST,
  338. detail="Invalid reset token"
  339. )
  340. email = payload.get("email")
  341. if not email:
  342. raise HTTPException(
  343. status_code=status.HTTP_400_BAD_REQUEST,
  344. detail="Invalid token payload"
  345. )
  346. # 2. 检查令牌是否在数据库中且未过期 - 异步查询
  347. stmt = select(Token).where(
  348. Token.token == token,
  349. Token.token_type == "reset",
  350. Token.is_revoked == False,
  351. Token.expires_at > datetime.now(timezone.utc)
  352. )
  353. result = await self.db.execute(stmt)
  354. token_entry = result.scalar_one_or_none()
  355. if not token_entry:
  356. raise HTTPException(
  357. status_code=status.HTTP_400_BAD_REQUEST,
  358. detail="Invalid or expired reset token"
  359. )
  360. # 3. 查找用户 - 异步查询
  361. user_stmt = select(User).where(User.email == email)
  362. user_result = await self.db.execute(user_stmt)
  363. user = user_result.scalar_one_or_none()
  364. if not user:
  365. raise HTTPException(
  366. status_code=status.HTTP_404_NOT_FOUND,
  367. detail="User not found"
  368. )
  369. if not user.is_active:
  370. raise HTTPException(
  371. status_code=status.HTTP_400_BAD_REQUEST,
  372. detail="Account is not active"
  373. )
  374. # 4. 更新密码 - 异步更新
  375. from ..core.security import get_password_hash
  376. update_user_stmt = (
  377. update(User)
  378. .where(User.id == user.id)
  379. .values(
  380. hashed_password = get_password_hash(new_password),
  381. last_password_change = datetime.now(timezone.utc),
  382. updated_at = datetime.now(timezone.utc)
  383. )
  384. )
  385. # 5. 撤销所有现有令牌 - 异步更新
  386. revoke_tokens_stmt = (
  387. update(Token)
  388. .where(
  389. Token.user_id == user.id,
  390. Token.token_type.in_(["access", "refresh"])
  391. )
  392. .values(is_revoked = True)
  393. )
  394. # 6. 标记重置令牌为已使用 - 异步更新
  395. revoke_reset_token_stmt = (
  396. update(Token)
  397. .where(Token.token == token)
  398. .values(is_revoked=True)
  399. )
  400. # 执行所有更新
  401. await self.db.execute(update_user_stmt)
  402. await self.db.execute(revoke_tokens_stmt)
  403. await self.db.execute(revoke_reset_token_stmt)
  404. await self.db.commit()
  405. return True
  406. async def refresh_token(
  407. self,
  408. refresh_token: str,
  409. ip_address: Optional[str] = None,
  410. user_agent: Optional[str] = None
  411. ) -> TokenResponse:
  412. """刷新访问令牌(SQLAlchemy 2.0异步版)"""
  413. # 1. 验证刷新令牌
  414. payload = decode_token(refresh_token)
  415. if not payload or payload.get("type") != "refresh":
  416. raise HTTPException(
  417. status_code=status.HTTP_401_UNAUTHORIZED,
  418. detail="Invalid refresh token"
  419. )
  420. # 2. 检查令牌是否在数据库中且有效 - 异步查询
  421. token_stmt = select(Token).where(
  422. Token.token == refresh_token,
  423. Token.token_type == "refresh",
  424. Token.is_revoked == False,
  425. Token.expires_at > datetime.now(timezone.utc)
  426. )
  427. token_result = await self.db.execute(token_stmt)
  428. token_entry = token_result.scalar_one_or_none()
  429. if not token_entry:
  430. raise HTTPException(
  431. status_code=status.HTTP_401_UNAUTHORIZED,
  432. detail="Invalid refresh token"
  433. )
  434. username = payload.get("sub")
  435. user_id = payload.get("user_id")
  436. if not username or not user_id:
  437. raise HTTPException(
  438. status_code=status.HTTP_401_UNAUTHORIZED,
  439. detail="Invalid token payload"
  440. )
  441. # 3. 查找用户 - 异步查询
  442. user_stmt = select(User).where(
  443. User.id == user_id,
  444. User.username == username,
  445. User.is_active == True
  446. )
  447. user_result = await self.db.execute(user_stmt)
  448. user = user_result.scalar_one_or_none()
  449. if not user:
  450. raise HTTPException(
  451. status_code=status.HTTP_401_UNAUTHORIZED,
  452. detail="User not found or inactive"
  453. )
  454. # 4. 创建新的访问令牌
  455. access_token = create_access_token({"sub": user.username, "user_id": user.id})
  456. # 5. 创建新的刷新令牌(刷新令牌轮换)
  457. new_refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
  458. # 6. 保存新的刷新令牌 - 异步添加
  459. new_refresh_token_entry = Token(
  460. token=new_refresh_token,
  461. token_type = "refresh",
  462. expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
  463. user_id = user.id,
  464. ip_address = ip_address,
  465. user_agent = user_agent
  466. )
  467. # 7. 标记旧令牌为已撤销 - 异步更新
  468. revoke_stmt = (
  469. update(Token)
  470. .where(Token.token == refresh_token)
  471. .values(is_revoked = True)
  472. )
  473. # 执行更新和添加操作
  474. await self.db.execute(revoke_stmt)
  475. self.db.add(new_refresh_token_entry)
  476. await self.db.commit()
  477. # 8. 构建响应
  478. return TokenResponse(
  479. access_token=access_token,
  480. refresh_token=new_refresh_token,
  481. expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
  482. user=user
  483. )
  484. async def logout(
  485. self,
  486. refresh_token: str
  487. ) -> bool:
  488. """用户登出(SQLAlchemy 2.0异步版)"""
  489. # 1. 查找刷新令牌 - 异步查询
  490. stmt = select(Token).where(
  491. Token.token == refresh_token,
  492. Token.token_type == "refresh"
  493. )
  494. result = await self.db.execute(stmt)
  495. token_entry = result.scalar_one_or_none()
  496. # 2. 如果找到令牌,则撤销它 - 异步更新
  497. if token_entry:
  498. update_stmt = (
  499. update(Token)
  500. .where(Token.token == refresh_token)
  501. .values(is_revoked = True)
  502. )
  503. await self.db.execute(update_stmt)
  504. await self.db.commit()
  505. return True
  506. async def logout_all(
  507. self,
  508. user_id: int
  509. ) -> bool:
  510. """撤销用户的所有令牌(SQLAlchemy 2.0异步版)"""
  511. # 撤销用户的所有访问和刷新令牌 - 异步更新
  512. stmt = (
  513. update(Token)
  514. .where(
  515. Token.user_id == user_id,
  516. Token.token_type.in_(["access", "refresh"]),
  517. Token.is_revoked == False
  518. )
  519. .values(is_revoked = True)
  520. )
  521. result = await self.db.execute(stmt)
  522. await self.db.commit()
  523. return result.rowcount > 0