CaiYouHui 后端 FastAPI 实现

  1. # app/database.py
  2. from sqlalchemy.orm import DeclarativeBase
  3. from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
  4. from contextlib import asynccontextmanager
  5. from typing import AsyncGenerator
  6. from .config import settings
  7. # ====================== 1. 基类定义 ======================
  8. class Base(DeclarativeBase):
  9. """SQLAlchemy 2.0 声明式基类"""
  10. pass
  11. # ====================== 2. 异步URL转换 ======================
  12. def ensure_async_url(database_url: str) -> str:
  13. """确保使用异步数据库URL"""
  14. # 如果已经是异步URL,直接返回
  15. if any(x in database_url for x in ["+asyncpg", "+aiomysql", "+aiosqlite"]):
  16. return database_url
  17. # 转换同步URL为异步URL
  18. if database_url.startswith("postgresql://"):
  19. return database_url.replace("postgresql://", "postgresql+asyncpg://")
  20. elif database_url.startswith("mysql://"):
  21. return database_url.replace("mysql://", "mysql+aiomysql://")
  22. elif database_url.startswith("sqlite://"):
  23. return database_url.replace("sqlite://", "sqlite+aiosqlite://")
  24. else:
  25. raise ValueError(f"不支持的数据库类型: {database_url}")
  26. # ====================== 3. 异步引擎配置 ======================
  27. async_database_url = ensure_async_url(settings.DATABASE_URL)
  28. # 连接参数
  29. connect_args = {}
  30. if "sqlite+aiosqlite" in async_database_url:
  31. connect_args = {"check_same_thread": False}
  32. async_engine: AsyncEngine = create_async_engine(
  33. async_database_url,
  34. connect_args=connect_args,
  35. echo=settings.DEBUG,
  36. pool_size=20,
  37. max_overflow=10,
  38. pool_pre_ping=True,
  39. future=True, # 启用2.0特性
  40. )
  41. # ====================== 4. 异步会话工厂 ======================
  42. AsyncSessionLocal = async_sessionmaker(
  43. bind=async_engine,
  44. class_=AsyncSession,
  45. expire_on_commit=False,
  46. autoflush=False,
  47. )
  48. # ====================== 5. 异步依赖注入 ======================
  49. async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
  50. """
  51. FastAPI兼容的异步数据库依赖
  52. 注意:移除了@asynccontextmanager装饰器
  53. """
  54. session = AsyncSessionLocal()
  55. try:
  56. yield session
  57. await session.commit()
  58. except Exception:
  59. await session.rollback()
  60. raise
  61. finally:
  62. await session.close()
  63. # ====================== 6. 数据库初始化 ======================
  64. async def init_async_db():
  65. """异步初始化数据库"""
  66. from .models.user import User
  67. from sqlalchemy import select
  68. from .core.security import password_hasher
  69. # 创建表
  70. async with async_engine.begin() as conn:
  71. await conn.run_sync(Base.metadata.create_all)
  72. print("✅ 数据库表创建完成")
  73. # 创建默认管理员用户
  74. async with AsyncSessionLocal() as session:
  75. try:
  76. # 异步查询
  77. stmt = select(User).where(User.username == "admin")
  78. result = await session.execute(stmt)
  79. admin = result.scalar_one_or_none()
  80. if not admin:
  81. admin_user = User(
  82. username="admin",
  83. email="admin@caiyouhui.com",
  84. hashed_password=password_hasher.hash_password("Admin123!"),
  85. full_name="系统管理员",
  86. is_active=True,
  87. is_verified=True,
  88. is_superuser=True
  89. )
  90. session.add(admin_user)
  91. await session.commit()
  92. print("✅ 默认管理员用户已创建")
  93. except Exception as e:
  94. print(f"⚠️ 创建管理员用户时出错: {e}")
  95. await session.rollback()
  96. # ====================== 7. 连接健康检查 ======================
  97. async def check_async_connection():
  98. """检查数据库连接"""
  99. from sqlalchemy import text # 需要导入text
  100. try:
  101. async with async_engine.connect() as conn:
  102. await conn.execute(text("SELECT 1"))
  103. print("✅ 数据库连接正常")
  104. return True
  105. except Exception as e:
  106. print(f"❌ 数据库连接失败: {e}")
  107. return False
  108. # ====================== 8. 导出 ======================
  109. __all__ = [
  110. "Base",
  111. "async_engine",
  112. "AsyncSessionLocal",
  113. "get_async_db",
  114. "init_async_db",
  115. "check_async_connection",
  116. ]