CaiYouHui 后端 FastAPI 实现

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. from typing import Optional, Dict, Any, Tuple
  2. from datetime import datetime, timedelta
  3. from sqlalchemy.orm import Session
  4. from fastapi import HTTPException, status, BackgroundTasks
  5. import secrets
  6. import string
  7. from ..models.user import User
  8. from ..models.token import Token
  9. from ..schemas.auth import LoginRequest, TokenResponse
  10. from ..core.security import (
  11. verify_password,
  12. create_access_token,
  13. create_refresh_token,
  14. create_verification_token,
  15. create_reset_token,
  16. decode_token,
  17. generate_verification_code
  18. )
  19. from ..core.email import email_service
  20. from ..config import settings
  21. class AuthService:
  def __init__(self, db: Session) -> None:
      """Create the authentication service around an open database session.

      Args:
          db: SQLAlchemy session through which every query and commit made
              by this service is performed.
      """
      self.db = db
  async def register_user(
      self,
      user_data: Dict[str, Any],
      background_tasks: BackgroundTasks,
      ip_address: Optional[str] = None,
      user_agent: Optional[str] = None
  ) -> User:
      """Register a new user and queue the verification email.

      The account is created inactive and unverified; a verification code
      valid for 24 hours is stored on the user and sent by email together
      with a verification link.

      Args:
          user_data: Registration fields. ``username``, ``email`` and
              ``password`` are required keys; ``first_name`` and
              ``last_name`` are optional.
          background_tasks: Queue on which the verification email is
              scheduled.
          ip_address: Client IP address (accepted but not used here).
          user_agent: Client User-Agent (accepted but not used here).

      Returns:
          The newly persisted ``User``.

      Raises:
          HTTPException: 400 if the username or email is already registered,
              including when a concurrent request inserts it first.
      """
      from sqlalchemy.exc import IntegrityError

      from ..core.security import get_password_hash

      username = user_data["username"]
      email = user_data["email"]
      # Reusable SQL criterion matching any account that clashes with this one.
      clash = (User.username == username) | (User.email == email)

      def duplicate_error(existing: User) -> HTTPException:
          """Build the 400 error naming the field that is already taken."""
          detail = (
              "Username already registered"
              if existing.username == username
              else "Email already registered"
          )
          return HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail=detail
          )

      existing_user = self.db.query(User).filter(clash).first()
      if existing_user:
          raise duplicate_error(existing_user)

      user = User(
          username=username,
          email=email,
          hashed_password=get_password_hash(user_data["password"]),
          first_name=user_data.get("first_name"),
          last_name=user_data.get("last_name"),
          is_active=False,  # requires email verification first
          is_verified=False
      )
      verification_code = generate_verification_code()
      user.verification_code = verification_code
      user.verification_code_expires = datetime.utcnow() + timedelta(hours=24)

      self.db.add(user)
      try:
          self.db.commit()
      except IntegrityError:
          # The existence check above and this INSERT are not atomic: a
          # concurrent registration can take the username/email in between.
          # Roll back the failed transaction and, if that is what happened,
          # report it as the usual 400 instead of an unhandled 500.
          self.db.rollback()
          existing_user = self.db.query(User).filter(clash).first()
          if not existing_user:
              raise  # some other constraint failed; propagate it unchanged
          raise duplicate_error(existing_user) from None
      self.db.refresh(user)

      # Send the verification email after the response (background task).
      verification_token = create_verification_token(user.email)
      verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
      background_tasks.add_task(
          email_service.send_verification_email,
          user.email,
          user.username,
          verification_url,
          verification_code
      )
      return user
  async def login(
      self,
      login_data: LoginRequest,
      ip_address: Optional[str] = None,
      user_agent: Optional[str] = None
  ) -> Tuple[TokenResponse, User]:
      """Authenticate by username or email and issue an access/refresh token pair.

      Lockout policy: 5 consecutive wrong passwords lock the account for
      30 minutes. Once a lock has expired the failure counter restarts from
      zero, so a single mistyped password does not immediately re-lock it.

      Args:
          login_data: Credentials; ``login_data.username`` may hold either the
              username or the email address.
          ip_address: Client IP, stored on the new refresh-token row.
          user_agent: Client User-Agent, stored on the new refresh-token row.

      Returns:
          ``(token_response, user)``.

      Raises:
          HTTPException: 401 for an unknown user or wrong password, 423 while
              the account is locked, 403 if the email is unverified or the
              account is inactive.
      """
      max_failed_attempts = 5
      lockout_period = timedelta(minutes=30)
      now = datetime.utcnow()

      user = self.db.query(User).filter(
          (User.username == login_data.username) |
          (User.email == login_data.username)
      ).first()
      if not user:
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Incorrect username or password"
          )

      # Refuse while a lock is still in force.
      if user.is_locked and user.locked_until and user.locked_until > now:
          raise HTTPException(
              status_code=status.HTTP_423_LOCKED,
              detail=f"Account is locked until {user.locked_until}"
          )
      if user.is_locked:
          # The lock has expired (or never had an end time). Clear it and
          # restart the counter; otherwise failed_login_attempts is still at
          # or above the limit and the next wrong password re-locks at once.
          user.is_locked = False
          user.locked_until = None
          user.failed_login_attempts = 0

      if not verify_password(login_data.password, user.hashed_password):
          # `or 0` tolerates a NULL counter on rows created without a default.
          user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
          if user.failed_login_attempts >= max_failed_attempts:
              user.is_locked = True
              user.locked_until = now + lockout_period
          self.db.commit()
          raise HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail="Incorrect username or password"
          )

      # Account-state checks only after the password is proven correct.
      if not user.is_verified:
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Email not verified"
          )
      if not user.is_active:
          raise HTTPException(
              status_code=status.HTTP_403_FORBIDDEN,
              detail="Account is not active"
          )

      # Successful login: reset the lockout bookkeeping.
      user.failed_login_attempts = 0
      user.last_login = now
      user.is_locked = False
      user.locked_until = None

      # Separate dicts so neither token helper can affect the other's claims.
      claims = {"sub": user.username, "user_id": user.id}
      access_token = create_access_token(dict(claims))
      refresh_token = create_refresh_token(dict(claims))

      # Persist the refresh token so it can later be rotated or revoked.
      self.db.add(Token(
          token=refresh_token,
          token_type="refresh",
          expires_at=now + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
          user_id=user.id,
          ip_address=ip_address,
          user_agent=user_agent
      ))
      # One commit: the login bookkeeping and the new token land together.
      self.db.commit()

      token_response = TokenResponse(
          access_token=access_token,
          refresh_token=refresh_token,
          expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
          user=user
      )
      return token_response, user
  async def verify_email(
      self,
      token: str,
      code: Optional[str] = None
  ) -> bool:
      """Verify a user's email address and activate the account.

      Args:
          token: Verification JWT whose payload has ``type == "verify"`` and
              an ``email`` claim.
          code: Optional verification code; when given it must equal the
              stored code, and the stored code must not be expired.

      Returns:
          ``True`` once the user is marked verified and active. A failure to
          send the follow-up welcome email is logged, not raised.

      Raises:
          HTTPException: 400 for an invalid token or payload, an already
              verified email, or a wrong/expired code; 404 if no user has
              the token's email.
      """
      import logging

      payload = decode_token(token)
      if not payload or payload.get("type") != "verify":
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Invalid verification token"
          )
      email = payload.get("email")
      if not email:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Invalid token payload"
          )

      user = self.db.query(User).filter(User.email == email).first()
      if not user:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="User not found"
          )
      if user.is_verified:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Email already verified"
          )

      # Optional second factor: the code that was emailed alongside the link.
      if code:
          if (not user.verification_code or
                  user.verification_code != code or
                  not user.verification_code_expires or
                  user.verification_code_expires < datetime.utcnow()):
              raise HTTPException(
                  status_code=status.HTTP_400_BAD_REQUEST,
                  detail="Invalid or expired verification code"
              )

      user.is_verified = True
      user.is_active = True
      user.verification_code = None
      user.verification_code_expires = None
      self.db.commit()

      # Verification is already committed, so a welcome-email failure must
      # not become an error response (a client retry would then only get
      # "Email already verified"). Treat the email as best-effort.
      try:
          await email_service.send_welcome_email(user.email, user.username)
      except Exception:
          logging.getLogger(__name__).exception(
              "Failed to send welcome email to user id=%s", user.id
          )
      return True
  async def resend_verification_email(
      self,
      email: str,
      background_tasks: BackgroundTasks
  ) -> bool:
      """Issue a fresh verification code and link, and queue the email.

      To avoid revealing which addresses are registered, the same success
      result is returned, without sending anything, when the email is
      unknown or already verified.

      Args:
          email: Address to resend the verification email to.
          background_tasks: Queue on which the email is scheduled.

      Returns:
          Always ``True``.
      """
      user = self.db.query(User).filter(User.email == email).first()
      # Identical response for "no such account" and "already verified":
      # distinguishing them would let callers probe which emails exist
      # (presumably this endpoint is unauthenticated — confirm at the router).
      if not user or user.is_verified:
          return True

      # Replace any previous code with a new one valid for 24 hours.
      verification_code = generate_verification_code()
      user.verification_code = verification_code
      user.verification_code_expires = datetime.utcnow() + timedelta(hours=24)
      self.db.commit()

      verification_token = create_verification_token(user.email)
      verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
      background_tasks.add_task(
          email_service.send_verification_email,
          user.email,
          user.username,
          verification_url,
          verification_code
      )
      return True
  async def request_password_reset(
      self,
      email: str,
      background_tasks: BackgroundTasks
  ) -> bool:
      """Create a password-reset token and queue the reset email.

      Any reset tokens previously issued to the user and not yet revoked are
      revoked first, so only the newest reset link works. Unknown emails and
      inactive accounts get the same success result without an email, so
      the endpoint cannot be used to probe which accounts exist.

      Args:
          email: Address of the account requesting a reset.
          background_tasks: Queue on which the reset email is scheduled.

      Returns:
          Always ``True``.
      """
      user = self.db.query(User).filter(User.email == email).first()
      # Same response as for an unknown email; see docstring.
      if not user or not user.is_active:
          return True

      # Invalidate earlier reset links still outstanding for this user.
      self.db.query(Token).filter(
          Token.user_id == user.id,
          Token.token_type == "reset",
          Token.is_revoked == False  # noqa: E712 - SQL expression, not a truthiness test
      ).update({"is_revoked": True})

      reset_token = create_reset_token(user.email)
      reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"

      # Store the token so reset_password can check and revoke it.
      self.db.add(Token(
          token=reset_token,
          token_type="reset",
          expires_at=datetime.utcnow() + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
          user_id=user.id
      ))
      self.db.commit()

      background_tasks.add_task(
          email_service.send_password_reset_email,
          user.email,
          user.username,
          reset_url
      )
      return True
  async def reset_password(
      self,
      token: str,
      new_password: str
  ) -> bool:
      """Set a new password using a password-reset token.

      The token must decode as a ``reset`` JWT with an ``email`` claim, be
      stored in the token table unrevoked and unexpired, and belong to the
      active user with that email. On success every access, refresh and
      reset token of the user is revoked. Existing sessions and any other
      outstanding reset links therefore stop working.

      Args:
          token: Reset JWT from the emailed link.
          new_password: Plain-text password to hash and store.

      Returns:
          ``True`` on success.

      Raises:
          HTTPException: 400 for an invalid, expired or revoked token, a
              token row owned by a different user, or an inactive account;
              404 if no user has the token's email.
      """
      from ..core.security import get_password_hash

      payload = decode_token(token)
      if not payload or payload.get("type") != "reset":
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Invalid reset token"
          )
      email = payload.get("email")
      if not email:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Invalid token payload"
          )

      # The JWT alone is not enough: it must also be a live row in the DB.
      token_entry = self.db.query(Token).filter(
          Token.token == token,
          Token.token_type == "reset",
          Token.is_revoked == False,  # noqa: E712 - SQL expression
          Token.expires_at > datetime.utcnow()
      ).first()
      if not token_entry:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Invalid or expired reset token"
          )

      user = self.db.query(User).filter(User.email == email).first()
      if not user:
          raise HTTPException(
              status_code=status.HTTP_404_NOT_FOUND,
              detail="User not found"
          )
      if token_entry.user_id != user.id:
          # The stored row was issued to a different account than the one the
          # token's email resolves to now; never reset the wrong account.
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Invalid or expired reset token"
          )
      if not user.is_active:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="Account is not active"
          )

      user.hashed_password = get_password_hash(new_password)
      user.last_password_change = datetime.utcnow()

      # Revoke sessions and every outstanding reset link, not just this one.
      self.db.query(Token).filter(
          Token.user_id == user.id,
          Token.token_type.in_(["access", "refresh", "reset"]),
          Token.is_revoked == False  # noqa: E712 - SQL expression
      ).update({"is_revoked": True})
      # Explicitly mark the loaded row too, keeping the in-memory object in sync.
      token_entry.is_revoked = True
      self.db.commit()
      return True
  async def refresh_token(
      self,
      refresh_token: str,
      ip_address: Optional[str] = None,
      user_agent: Optional[str] = None
  ) -> TokenResponse:
      """Exchange a valid refresh token for a new access/refresh token pair.

      The presented refresh token is rotated: its database row is marked
      revoked and a freshly issued refresh token is stored in its place.

      Args:
          refresh_token: Refresh JWT previously issued by ``login`` or by an
              earlier call to this method.
          ip_address: Client IP stored on the new refresh-token row.
          user_agent: Client User-Agent stored on the new refresh-token row.

      Returns:
          A ``TokenResponse`` with the new tokens and the user.

      Raises:
          HTTPException: 401 if the token is undecodable or of the wrong
              type, missing/revoked/expired in the database, lacks ``sub`` or
              ``user_id``, or refers to a missing or inactive user.
      """
      def unauthorized(reason: str) -> HTTPException:
          """Build a 401 error with the given detail message."""
          return HTTPException(
              status_code=status.HTTP_401_UNAUTHORIZED,
              detail=reason
          )

      claims = decode_token(refresh_token)
      if not claims or claims.get("type") != "refresh":
          raise unauthorized("Invalid refresh token")

      # The token must still be live in the database (not rotated/revoked).
      stored = (
          self.db.query(Token)
          .filter(
              Token.token == refresh_token,
              Token.token_type == "refresh",
              Token.is_revoked == False,  # noqa: E712 - SQL expression
              Token.expires_at > datetime.utcnow()
          )
          .first()
      )
      if not stored:
          raise unauthorized("Invalid refresh token")

      subject = claims.get("sub")
      owner_id = claims.get("user_id")
      if not (subject and owner_id):
          raise unauthorized("Invalid token payload")

      owner = (
          self.db.query(User)
          .filter(
              User.id == owner_id,
              User.username == subject,
              User.is_active == True  # noqa: E712 - SQL expression
          )
          .first()
      )
      if not owner:
          raise unauthorized("User not found or inactive")

      identity = {"sub": owner.username, "user_id": owner.id}
      new_access = create_access_token(dict(identity))
      new_refresh = create_refresh_token(dict(identity))

      replacement = Token(
          token=new_refresh,
          token_type="refresh",
          expires_at=datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
          user_id=owner.id,
          ip_address=ip_address,
          user_agent=user_agent
      )
      # Rotation: retire the presented token and persist its successor.
      stored.is_revoked = True
      self.db.add(replacement)
      self.db.commit()

      return TokenResponse(
          access_token=new_access,
          refresh_token=new_refresh,
          expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
          user=owner
      )
  async def logout(
      self,
      refresh_token: str
  ) -> bool:
      """Log out one session by revoking its refresh token.

      An unknown token is silently ignored, so the call always succeeds.

      Args:
          refresh_token: Refresh JWT to revoke.

      Returns:
          Always ``True``.
      """
      entry = (
          self.db.query(Token)
          .filter(Token.token == refresh_token, Token.token_type == "refresh")
          .first()
      )
      if not entry:
          return True

      entry.is_revoked = True
      self.db.commit()
      return True
  async def logout_all(
      self,
      user_id: int
  ) -> bool:
      """Revoke every still-valid access and refresh token of a user.

      Args:
          user_id: Primary key of the user whose sessions are ended.

      Returns:
          Always ``True``.
      """
      session_token_types = ["access", "refresh"]
      live_tokens = self.db.query(Token).filter(
          Token.user_id == user_id,
          Token.token_type.in_(session_token_types),
          Token.is_revoked == False  # noqa: E712 - SQL expression
      )
      live_tokens.update({"is_revoked": True})
      self.db.commit()
      return True