CaiYouHui后端FastAPI实现

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
from typing import Callable, Iterable, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session

from ..config import settings
from ..database import get_db
from ..models.user import User
from ..schemas.token import TokenData
  10. oauth2_scheme = OAuth2PasswordBearer(
  11. tokenUrl=f"{settings.API_V1_PREFIX}/auth/login",
  12. auto_error=False
  13. )
  14. async def get_current_user(
  15. token: Optional[str] = Depends(oauth2_scheme),
  16. db: Session = Depends(get_db)
  17. ) -> Optional[User]:
  18. """获取当前用户"""
  19. if not token:
  20. print("dfsfdsfdfdsfd")
  21. return None
  22. credentials_exception = HTTPException(
  23. status_code=status.HTTP_401_UNAUTHORIZED,
  24. detail="Could not validate credentials",
  25. headers={"WWW-Authenticate": "Bearer"},
  26. )
  27. try:
  28. payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
  29. username: str = payload.get("sub")
  30. token_type: str = payload.get("type")
  31. if username is None or token_type != "access":
  32. raise credentials_exception
  33. token_data = TokenData(username=username)
  34. except JWTError:
  35. raise credentials_exception
  36. user = db.query(User).filter(
  37. User.username == token_data.username,
  38. User.is_active == True
  39. ).first()
  40. if user is None:
  41. raise credentials_exception
  42. return user
  43. async def get_current_active_user(
  44. current_user: User = Depends(get_current_user)
  45. ) -> User:
  46. """获取当前活跃用户"""
  47. if not current_user:
  48. raise HTTPException(
  49. status_code=status.HTTP_401_UNAUTHORIZED,
  50. detail="Not authenticated"
  51. )
  52. if not current_user.is_active:
  53. raise HTTPException(
  54. status_code=status.HTTP_400_BAD_REQUEST,
  55. detail="Inactive user"
  56. )
  57. return current_user
  58. async def get_current_superuser(
  59. current_user: User = Depends(get_current_user)
  60. ) -> User:
  61. """获取超级用户"""
  62. if not current_user or not current_user.is_superuser:
  63. raise HTTPException(
  64. status_code=status.HTTP_403_FORBIDDEN,
  65. detail="Not enough permissions"
  66. )
  67. return current_user
  68. def require_auth(current_user: Optional[User] = Depends(get_current_user)) -> User:
  69. """要求认证的依赖"""
  70. if not current_user:
  71. raise HTTPException(
  72. status_code=status.HTTP_401_UNAUTHORIZED,
  73. detail="Not authenticated"
  74. )
  75. return current_user
  76. # 权限检查装饰器
  77. def require_permission(permission: str):
  78. """权限检查装饰器"""
  79. def permission_dependency(
  80. current_user: User = Depends(get_current_active_user)
  81. ) -> User:
  82. # 这里实现具体的权限检查逻辑
  83. # 可以从数据库或缓存中获取用户权限
  84. if not current_user.is_superuser:
  85. # 检查用户是否有特定权限
  86. user_permissions = [] # 从数据库获取
  87. if permission not in user_permissions:
  88. raise HTTPException(
  89. status_code=status.HTTP_403_FORBIDDEN,
  90. detail="Insufficient permissions"
  91. )
  92. return current_user
  93. return permission_dependency