CaiYouHui 后端 FastAPI 实现

auth.py 9.5KB


  1. # app/api/v1/auth.py
  2. from fastapi import APIRouter, HTTPException, status, Request, Depends
  3. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from app.schemas.user import UserCreate, UserLogin
  6. from app.schemas.token import TokenResponse, RefreshTokenRequest
  7. from app.core.security import (
  8. create_access_token,
  9. JWTManager,
  10. verify_access_token
  11. )
  12. from app.config import settings
  13. from app.database import get_async_db
  14. from app.models.user import User
  15. from app.services.user_service import UserService
  16. router = APIRouter(prefix="/auth", tags=["认证"])
  17. security = HTTPBearer()
  18. # 初始化 JWT 管理器
  19. jwt_manager = JWTManager(
  20. secret_key=settings.SECRET_KEY,
  21. algorithm=settings.ALGORITHM
  22. )
  23. import logging
  24. logger = logging.getLogger(__name__)
  25. @router.post("/register", response_model=TokenResponse, status_code=status.HTTP_201_CREATED)
  26. async def register(
  27. user_data: UserCreate,
  28. request: Request,
  29. db: AsyncSession = Depends(get_async_db)
  30. ):
  31. """用户注册"""
  32. user_service = UserService(db)
  33. try:
  34. # 创建用户
  35. user = await user_service.create_user(user_data)
  36. # 创建访问令牌
  37. access_token = create_access_token(
  38. data={
  39. "sub": user.username,
  40. "user_id": user.id,
  41. "email": user.email,
  42. "type": "access"
  43. },
  44. secret_key=settings.SECRET_KEY,
  45. expires_minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
  46. )
  47. # 创建刷新令牌
  48. refresh_token = jwt_manager.create_refresh_token(
  49. {
  50. "sub": user.username,
  51. "user_id": user.id,
  52. "type": "refresh"
  53. },
  54. expires_days=7
  55. )
  56. # 构建用户响应
  57. user_response = {
  58. "id": user.id,
  59. "username": user.username,
  60. "email": user.email,
  61. "full_name": user.full_name,
  62. "is_active": user.is_active,
  63. "is_verified": user.is_verified,
  64. "is_superuser": user.is_superuser,
  65. "created_at": user.created_at.isoformat() if user.created_at else None
  66. }
  67. return TokenResponse(
  68. access_token=access_token,
  69. token_type="bearer",
  70. expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
  71. refresh_token=refresh_token,
  72. user=user_response
  73. )
  74. except ValueError as e:
  75. raise HTTPException(
  76. status_code=status.HTTP_400_BAD_REQUEST,
  77. detail=str(e)
  78. )
  79. except Exception as e:
  80. raise HTTPException(
  81. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  82. detail=f"注册失败: {str(e)}"
  83. )
  84. @router.post("/login", response_model=TokenResponse)
  85. async def login(
  86. login_data: UserLogin,
  87. request: Request,
  88. db: AsyncSession = Depends(get_async_db)
  89. ):
  90. """用户登录"""
  91. user_service = UserService(db)
  92. logger.info("✅ 用户登录")
  93. # 验证用户
  94. user = await user_service.authenticate_user(
  95. login_data.username,
  96. login_data.password
  97. )
  98. if not user:
  99. raise HTTPException(
  100. status_code=status.HTTP_401_UNAUTHORIZED,
  101. detail="用户名或密码错误"
  102. )
  103. if not user.is_active:
  104. raise HTTPException(
  105. status_code=status.HTTP_403_FORBIDDEN,
  106. detail="用户账户已被禁用"
  107. )
  108. if user.is_locked:
  109. raise HTTPException(
  110. status_code=status.HTTP_423_LOCKED,
  111. detail="账户已被锁定,请联系管理员"
  112. )
  113. # 创建访问令牌
  114. access_token = jwt_manager.create_access_token(
  115. {
  116. "sub": user.username,
  117. "user_id": user.id,
  118. "email": user.email,
  119. "type": "access"
  120. },
  121. expires_minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
  122. )
  123. # 创建刷新令牌
  124. refresh_token = jwt_manager.create_refresh_token(
  125. {
  126. "sub": user.username,
  127. "user_id": user.id,
  128. "type": "refresh"
  129. },
  130. expires_days=7
  131. )
  132. # 构建用户响应
  133. user_response = {
  134. "id": user.id,
  135. "username": user.username,
  136. "email": user.email,
  137. "full_name": user.full_name,
  138. "is_active": user.is_active,
  139. "is_verified": user.is_verified,
  140. "is_superuser": user.is_superuser,
  141. "created_at": user.created_at.isoformat() if user.created_at else None,
  142. "last_login": user.last_login.isoformat() if user.last_login else None
  143. }
  144. return TokenResponse(
  145. access_token=access_token,
  146. token_type="bearer",
  147. expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
  148. refresh_token=refresh_token,
  149. user=user_response
  150. )
  151. @router.get("/me")
  152. async def get_current_user(
  153. credentials: HTTPAuthorizationCredentials = Depends(security),
  154. db: AsyncSession = Depends(get_async_db)
  155. ):
  156. """获取当前用户信息"""
  157. token = credentials.credentials
  158. logger.info("✅ 获取当前用户信息")
  159. # 验证令牌
  160. payload = verify_access_token(token, settings.SECRET_KEY)
  161. if not payload:
  162. raise HTTPException(
  163. status_code=status.HTTP_401_UNAUTHORIZED,
  164. detail="无效的令牌"
  165. )
  166. username = payload.get("sub")
  167. user_service = UserService(db)
  168. user = await user_service.get_user_by_username(username)
  169. if not user:
  170. raise HTTPException(
  171. status_code=status.HTTP_404_NOT_FOUND,
  172. detail="用户不存在"
  173. )
  174. if not user.is_active:
  175. raise HTTPException(
  176. status_code=status.HTTP_403_FORBIDDEN,
  177. detail="用户账户已被禁用"
  178. )
  179. return {
  180. "id": user.id,
  181. "username": user.username,
  182. "email": user.email,
  183. "full_name": user.full_name,
  184. "is_active": user.is_active,
  185. "is_verified": user.is_verified,
  186. "is_superuser": user.is_superuser,
  187. "created_at": user.created_at.isoformat() if user.created_at else None,
  188. "last_login": user.last_login.isoformat() if user.last_login else None,
  189. "avatar": user.avatar
  190. }
  191. @router.post("/refresh")
  192. async def refresh_token(
  193. refresh_data: RefreshTokenRequest,
  194. db: AsyncSession = Depends(get_async_db)
  195. ):
  196. """刷新访问令牌"""
  197. refresh_token = refresh_data.refresh_token
  198. if not refresh_token:
  199. raise HTTPException(
  200. status_code=status.HTTP_400_BAD_REQUEST,
  201. detail="缺少刷新令牌"
  202. )
  203. # 验证刷新令牌
  204. payload = jwt_manager.verify_token(refresh_token)
  205. if not payload or payload.get("type") != "refresh":
  206. raise HTTPException(
  207. status_code=status.HTTP_401_UNAUTHORIZED,
  208. detail="无效的刷新令牌"
  209. )
  210. username = payload.get("sub")
  211. user_service = UserService(db)
  212. user = await user_service.get_user_by_username(username)
  213. if not user:
  214. raise HTTPException(
  215. status_code=status.HTTP_404_NOT_FOUND,
  216. detail="用户不存在"
  217. )
  218. if not user.is_active:
  219. raise HTTPException(
  220. status_code=status.HTTP_403_FORBIDDEN,
  221. detail="用户账户已被禁用"
  222. )
  223. # 创建新的访问令牌
  224. access_token = jwt_manager.create_access_token(
  225. {
  226. "sub": user.username,
  227. "user_id": user.id,
  228. "email": user.email,
  229. "type": "access"
  230. },
  231. expires_minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
  232. )
  233. return {
  234. "access_token": access_token,
  235. "token_type": "bearer",
  236. "expires_in": settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60
  237. }
  238. @router.post("/logout")
  239. async def logout(
  240. refresh_data: RefreshTokenRequest,
  241. db: AsyncSession = Depends(get_async_db),
  242. credentials: HTTPAuthorizationCredentials = Depends(security)
  243. ):
  244. """用户登出"""
  245. # 这里可以记录令牌到黑名单,或者简单返回成功
  246. # 在实际应用中,可能需要将令牌存储到Redis黑名单
  247. return {"message": "登出成功"}
  248. @router.get("/test-db")
  249. async def test_database(db: AsyncSession = Depends(get_async_db)):
  250. """测试数据库连接和用户统计"""
  251. user_service = UserService(db)
  252. user_count = await user_service.count_users()
  253. all_users = await user_service.list_users(limit=10)
  254. users_info = []
  255. for user in all_users:
  256. users_info.append({
  257. "id": user.id,
  258. "username": user.username,
  259. "email": user.email,
  260. "is_active": user.is_active
  261. })
  262. return {
  263. "database": "SQLite",
  264. "user_count": user_count,
  265. "users": users_info,
  266. "message": "数据库连接正常"
  267. }
  268. @router.get("/test")
  269. async def test_auth():
  270. """测试认证 API"""
  271. return {
  272. "message": "认证 API 正常运行(使用 SQLite 数据库)",
  273. "endpoints": {
  274. "POST /register": "用户注册",
  275. "POST /login": "用户登录",
  276. "GET /me": "获取当前用户",
  277. "POST /refresh": "刷新令牌",
  278. "POST /logout": "用户登出",
  279. "GET /test-db": "测试数据库",
  280. "GET /test": "测试接口"
  281. },
  282. "test_credentials": {
  283. "username": "admin",
  284. "password": "Admin123!"
  285. }
  286. }