CaiYouHui后端fastapi实现

security.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. # app/core/security.py
  2. import hashlib
  3. import secrets
  4. import base64
  5. import hmac
  6. import time
  7. import json
  8. from typing import Optional, Dict, Any, Tuple
  9. from datetime import datetime, timezone, timedelta
  10. import re
  11. # 密码哈希工具
  12. class PasswordHasher:
  13. """密码哈希和验证工具(使用 PBKDF2)"""
  14. @staticmethod
  15. def hash_password(password: str) -> str:
  16. """哈希密码"""
  17. # 生成随机盐(16字节)
  18. salt = secrets.token_bytes(16)
  19. # 使用 PBKDF2-HMAC-SHA256
  20. dk = hashlib.pbkdf2_hmac(
  21. 'sha256',
  22. password.encode('utf-8'),
  23. salt,
  24. 100000 # 迭代次数
  25. )
  26. # 组合 salt + hash,然后 base64 编码
  27. combined = salt + dk
  28. return base64.b64encode(combined).decode('utf-8')
  29. @staticmethod
  30. def verify_password(password: str, hashed_password: str) -> bool:
  31. """验证密码"""
  32. try:
  33. # 解码 base64
  34. decoded = base64.b64decode(hashed_password.encode('utf-8'))
  35. # 提取 salt (前16字节) 和存储的 hash
  36. salt = decoded[:16]
  37. stored_hash = decoded[16:]
  38. # 用相同的盐计算输入密码的 hash
  39. dk = hashlib.pbkdf2_hmac(
  40. 'sha256',
  41. password.encode('utf-8'),
  42. salt,
  43. 100000
  44. )
  45. # 使用常量时间比较防止时序攻击
  46. return secrets.compare_digest(dk, stored_hash)
  47. except Exception:
  48. return False
  49. # JWT 管理器
  50. class JWTManager:
  51. """JWT 令牌管理工具"""
  52. def __init__(self, secret_key: str, algorithm: str = "HS256"):
  53. self.secret_key = secret_key
  54. self.algorithm = algorithm
  55. def create_access_token(
  56. self,
  57. data: Dict[str, Any],
  58. expires_delta: Optional[timedelta] = None,
  59. expires_minutes: Optional[int] = None
  60. ) -> str:
  61. """创建访问令牌
  62. Args:
  63. data: 要编码的数据
  64. expires_delta: 过期时间差
  65. expires_minutes: 过期分钟数
  66. Returns:
  67. str: JWT 令牌
  68. """
  69. return self.create_token(data, expires_delta, expires_minutes, token_type="access")
  70. def create_refresh_token(
  71. self,
  72. data: Dict[str, Any],
  73. expires_delta: Optional[timedelta] = None,
  74. expires_days: int = 7
  75. ) -> str:
  76. """创建刷新令牌
  77. Args:
  78. data: 要编码的数据
  79. expires_delta: 过期时间差
  80. expires_days: 过期天数
  81. Returns:
  82. str: JWT 刷新令牌
  83. """
  84. if expires_delta is None:
  85. expires_delta = timedelta(days=expires_days)
  86. return self.create_token(data, expires_delta, token_type="refresh")
  87. def create_verification_token(
  88. self,
  89. data: Dict[str, Any],
  90. expires_delta: Optional[timedelta] = None,
  91. expires_hours: int = 24
  92. ) -> str:
  93. """创建验证令牌(用于邮箱验证等)
  94. Args:
  95. data: 要编码的数据
  96. expires_delta: 过期时间差
  97. expires_hours: 过期小时数
  98. Returns:
  99. str: JWT 验证令牌
  100. """
  101. if expires_delta is None:
  102. expires_delta = timedelta(hours=expires_hours)
  103. return self.create_token(data, expires_delta, token_type="verify")
  104. def create_reset_token(
  105. self,
  106. data: Dict[str, Any],
  107. expires_delta: Optional[timedelta] = None,
  108. expires_minutes: int = 30
  109. ) -> str:
  110. """创建密码重置令牌
  111. Args:
  112. data: 要编码的数据
  113. expires_delta: 过期时间差
  114. expires_minutes: 过期分钟数
  115. Returns:
  116. str: JWT 重置令牌
  117. """
  118. if expires_delta is None:
  119. expires_delta = timedelta(minutes=expires_minutes)
  120. return self.create_token(data, expires_delta, token_type="reset")
  121. def create_token(
  122. self,
  123. data: Dict[str, Any],
  124. expires_delta: Optional[timedelta] = None,
  125. expires_minutes: Optional[int] = None,
  126. token_type: str = "access"
  127. ) -> str:
  128. """创建 JWT 令牌(通用方法)
  129. Args:
  130. data: 要编码的数据
  131. expires_delta: 过期时间差
  132. expires_minutes: 过期分钟数
  133. token_type: 令牌类型
  134. Returns:
  135. str: JWT 令牌
  136. """
  137. # 复制数据以避免修改原始数据
  138. payload = data.copy()
  139. # 设置过期时间
  140. if expires_delta:
  141. expire = datetime.now(timezone.utc) + expires_delta
  142. elif expires_minutes:
  143. expire = datetime.now(timezone.utc) + timedelta(minutes=expires_minutes)
  144. else:
  145. expire = datetime.now(timezone.utc) + timedelta(minutes=30) # 默认30分钟
  146. # 添加标准声明
  147. payload.update({
  148. "exp": int(expire.timestamp()),
  149. "iat": int(datetime.now(timezone.utc).timestamp()),
  150. "iss": "caiyouhui-api",
  151. "type": token_type
  152. })
  153. # 编码 header 和 payload
  154. header = json.dumps({"alg": self.algorithm, "typ": "JWT"})
  155. payload_str = json.dumps(payload)
  156. header_b64 = base64.urlsafe_b64encode(header.encode()).decode().rstrip('=')
  157. payload_b64 = base64.urlsafe_b64encode(payload_str.encode()).decode().rstrip('=')
  158. # 创建签名
  159. message = f"{header_b64}.{payload_b64}"
  160. signature = hmac.new(
  161. self.secret_key.encode(),
  162. message.encode(),
  163. hashlib.sha256
  164. ).digest()
  165. signature_b64 = base64.urlsafe_b64encode(signature).decode().rstrip('=')
  166. return f"{header_b64}.{payload_b64}.{signature_b64}"
  167. def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
  168. """验证 JWT 令牌
  169. Args:
  170. token: JWT 令牌
  171. Returns:
  172. Optional[Dict]: 解码后的数据,如果令牌无效则返回 None
  173. """
  174. try:
  175. parts = token.split('.')
  176. if len(parts) != 3:
  177. return None
  178. header_b64, payload_b64, signature_b64 = parts
  179. # 验证签名
  180. message = f"{header_b64}.{payload_b64}"
  181. expected_signature = hmac.new(
  182. self.secret_key.encode(),
  183. message.encode(),
  184. hashlib.sha256
  185. ).digest()
  186. expected_signature_b64 = base64.urlsafe_b64encode(expected_signature).decode().rstrip('=')
  187. # 使用常量时间比较
  188. if not secrets.compare_digest(signature_b64, expected_signature_b64):
  189. return None
  190. # 解码 payload
  191. payload_json = base64.urlsafe_b64decode(payload_b64 + '=' * (4 - len(payload_b64) % 4)).decode()
  192. payload = json.loads(payload_json)
  193. # 检查过期时间
  194. if 'exp' in payload and payload['exp'] < int(time.time()):
  195. return None
  196. return payload
  197. except Exception:
  198. return None
  199. def decode_token(self, token: str) -> Optional[Dict[str, Any]]:
  200. """解码令牌(不验证签名,用于调试)
  201. Args:
  202. token: JWT 令牌
  203. Returns:
  204. Optional[Dict]: 解码后的数据
  205. """
  206. try:
  207. parts = token.split('.')
  208. if len(parts) != 3:
  209. return None
  210. _, payload_b64, _ = parts
  211. # 解码 payload
  212. payload_json = base64.urlsafe_b64decode(payload_b64 + '=' * (4 - len(payload_b64) % 4)).decode()
  213. return json.loads(payload_json)
  214. except Exception:
  215. return None
  216. # 密码验证工具
  217. class PasswordValidator:
  218. """密码强度验证工具"""
  219. @staticmethod
  220. def validate_password_strength(password: str) -> Tuple[bool, str]:
  221. """验证密码强度"""
  222. # 检查最小长度
  223. if len(password) < 8:
  224. return False, "密码必须至少8个字符"
  225. # 检查最大长度
  226. if len(password) > 100:
  227. return False, "密码不能超过100个字符"
  228. # 检查是否包含大写字母
  229. if not re.search(r'[A-Z]', password):
  230. return False, "密码必须包含至少一个大写字母"
  231. # 检查是否包含小写字母
  232. if not re.search(r'[a-z]', password):
  233. return False, "密码必须包含至少一个小写字母"
  234. # 检查是否包含数字
  235. if not re.search(r'\d', password):
  236. return False, "密码必须包含至少一个数字"
  237. # 检查是否包含特殊字符
  238. if not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
  239. return False, "密码必须包含至少一个特殊字符"
  240. return True, "密码强度足够"
  241. # 令牌工具
  242. class TokenUtils:
  243. """令牌相关工具"""
  244. @staticmethod
  245. def generate_verification_code(length: int = 6) -> str:
  246. """生成数字验证码"""
  247. digits = '0123456789'
  248. return ''.join(secrets.choice(digits) for _ in range(length))
  249. @staticmethod
  250. def generate_reset_token(length: int = 32) -> str:
  251. """生成密码重置令牌"""
  252. return secrets.token_urlsafe(length)
  253. @staticmethod
  254. def generate_api_key(length: int = 32) -> str:
  255. """生成 API 密钥"""
  256. import string
  257. alphabet = string.ascii_letters + string.digits
  258. return ''.join(secrets.choice(alphabet) for _ in range(length))
  259. # 导出的便捷函数
  260. def create_access_token(
  261. data: Dict[str, Any],
  262. secret_key: str,
  263. expires_delta: Optional[timedelta] = None,
  264. expires_minutes: int = 30,
  265. algorithm: str = "HS256"
  266. ) -> str:
  267. """创建访问令牌(便捷函数)
  268. Args:
  269. data: 要编码的数据
  270. secret_key: 密钥
  271. expires_delta: 过期时间差
  272. expires_minutes: 过期分钟数
  273. algorithm: 算法
  274. Returns:
  275. str: JWT 令牌
  276. """
  277. jwt_manager = JWTManager(secret_key, algorithm)
  278. return jwt_manager.create_access_token(data, expires_delta, expires_minutes)
  279. def verify_access_token(token: str, secret_key: str, algorithm: str = "HS256") -> Optional[Dict[str, Any]]:
  280. """验证访问令牌(便捷函数)
  281. Args:
  282. token: JWT 令牌
  283. secret_key: 密钥
  284. algorithm: 算法
  285. Returns:
  286. Optional[Dict]: 解码后的数据
  287. """
  288. jwt_manager = JWTManager(secret_key, algorithm)
  289. return jwt_manager.verify_token(token)
  290. def get_password_hash(password: str) -> str:
  291. """哈希密码(便捷函数)
  292. Args:
  293. password: 明文密码
  294. Returns:
  295. str: 哈希后的密码
  296. """
  297. return PasswordHasher.hash_password(password)
  298. # 创建全局实例
  299. password_hasher = PasswordHasher()
  300. password_validator = PasswordValidator()
  301. token_utils = TokenUtils()
  302. # 导出函数(保持向后兼容)
  303. verify_password = password_hasher.verify_password
  304. # 导出的函数和类
  305. __all__ = [
  306. # 便捷函数
  307. "create_access_token",
  308. "verify_access_token",
  309. "verify_password",
  310. "get_password_hash",
  311. # 类
  312. "PasswordHasher",
  313. "JWTManager",
  314. "PasswordValidator",
  315. "TokenUtils",
  316. # 实例
  317. "password_hasher",
  318. "password_validator",
  319. "token_utils",
  320. ]