| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- # app/database.py
- from sqlalchemy.orm import DeclarativeBase
- from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
- from contextlib import asynccontextmanager
- from typing import AsyncGenerator
-
- from .config import settings
-
-
# ====================== 1. Declarative base ======================
class Base(DeclarativeBase):
    """Declarative base class for this application's ORM models (SQLAlchemy 2.0 style)."""
-
-
# ====================== 2. Async URL conversion ======================
def ensure_async_url(database_url: str) -> str:
    """Return *database_url* with its scheme switched to an asyncio driver.

    Conversions performed:

    - a scheme that already names ``asyncpg``, ``aiomysql`` or ``aiosqlite``
      is returned unchanged;
    - ``postgresql://`` / ``postgres://`` -> ``postgresql+asyncpg://``;
    - ``mysql://`` -> ``mysql+aiomysql://``;
    - ``sqlite://`` -> ``sqlite+aiosqlite://``.

    A synchronous driver suffix (e.g. ``postgresql+psycopg2://`` or
    ``mysql+pymysql://``) is replaced by the async driver as well, and the
    dialect name is matched case-insensitively. Only the scheme is rewritten;
    credentials, host, path and query string are kept verbatim.

    Args:
        database_url: A SQLAlchemy-style database URL.

    Returns:
        The equivalent URL that uses an async driver.

    Raises:
        ValueError: If the URL has no ``scheme://`` prefix or names an
            unsupported dialect. The message shows only the scheme so that
            credentials embedded in the URL do not end up in logs.
    """
    async_drivers = {
        "postgresql": "asyncpg",
        "postgres": "asyncpg",  # legacy alias, normalised to "postgresql" below
        "mysql": "aiomysql",
        "sqlite": "aiosqlite",
    }

    # Inspect the scheme only: searching/replacing over the whole string could
    # match or rewrite text inside the password, path or query string.
    scheme, separator, remainder = database_url.partition("://")
    dialect, _, driver = scheme.partition("+")
    dialect = dialect.lower()

    if separator and driver.lower() in {"asyncpg", "aiomysql", "aiosqlite"}:
        # Already an async URL: leave it exactly as given.
        return database_url

    target_driver = async_drivers.get(dialect)
    if not separator or target_driver is None:
        raise ValueError(f"不支持的数据库类型: {scheme}")

    canonical_dialect = "postgresql" if dialect == "postgres" else dialect
    return f"{canonical_dialect}+{target_driver}://{remainder}"
-
-
# ====================== 3. Async engine configuration ======================
async_database_url = ensure_async_url(settings.DATABASE_URL)
_is_sqlite = async_database_url.startswith("sqlite+aiosqlite")

# Extra keyword arguments forwarded to the DBAPI connect() call.
connect_args = {"check_same_thread": False} if _is_sqlite else {}

# Pool sizing is only meaningful for server databases. SQLite engines may be
# built on a pool class that rejects these arguments (e.g. StaticPool for
# in-memory databases), in which case create_async_engine raises TypeError,
# so they are only passed for non-SQLite URLs.
_pool_options = {} if _is_sqlite else {"pool_size": 20, "max_overflow": 10}

# `future=True` is intentionally omitted: SQLAlchemy 2.0 behaviour is the only
# mode in 2.x, so the flag is a no-op there.
async_engine: AsyncEngine = create_async_engine(
    async_database_url,
    connect_args=connect_args,
    echo=settings.DEBUG,
    pool_pre_ping=True,  # test connections on checkout so stale ones are replaced
    **_pool_options,
)
-
-
# ====================== 4. Async session factory ======================
# Factory producing AsyncSession objects bound to `async_engine`.
# - autoflush=False: queries do not trigger an automatic flush of pending changes.
# - expire_on_commit=False: loaded attributes stay readable after commit without
#   another round-trip, which matters for async code where lazy refresh is awkward.
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    autoflush=False,
    expire_on_commit=False,
    class_=AsyncSession,
)
-
-
# ====================== 5. Async dependency injection ======================
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` for use as a FastAPI dependency.

    Written as a plain async generator (no ``@asynccontextmanager``) so it
    can be used directly as a yield-dependency. Once the consumer is done,
    pending work is committed. If an exception is raised into the generator,
    or the commit itself fails, the transaction is rolled back and the
    exception is re-raised. The session is closed in every case.
    """
    async with AsyncSessionLocal() as db:
        try:
            yield db
            await db.commit()
        except Exception:
            await db.rollback()
            raise
-
-
# ====================== 6. Database initialisation ======================
async def init_async_db(
    *,
    admin_username: str = "admin",
    admin_email: str = "admin@caiyouhui.com",
    admin_password: str = "Admin123!",
) -> None:
    """Create all tables, then make sure a default superuser account exists.

    The keyword-only arguments let a deployment override the bootstrap admin
    account. Their defaults are the values previously hard-coded here, so
    ``await init_async_db()`` behaves as before.

    NOTE(security): the default password is a literal visible in source
    control. Outside local development, pass ``admin_password`` from
    configuration, or rotate the account's password right after first start-up.

    Errors from table creation propagate to the caller. Errors while looking
    up or creating the admin user are printed and rolled back instead of
    raised (best effort), so start-up can continue.
    """
    # Local imports, presumably to avoid a circular import with modules that
    # import `Base` from here (verify before hoisting). Importing the model
    # BEFORE create_all is deliberate: if `User` is declared on `Base` (assumed),
    # this import is what registers its table on `Base.metadata`.
    from .models.user import User
    from sqlalchemy import select
    from .core.security import password_hasher

    # The metadata API is synchronous, so it is run through run_sync on the
    # async connection. begin() commits the DDL when the block exits.
    async with async_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    print("✅ 数据库表创建完成")

    # Create the default admin user unless one with this username exists.
    async with AsyncSessionLocal() as session:
        try:
            stmt = select(User).where(User.username == admin_username)
            result = await session.execute(stmt)
            if result.scalar_one_or_none() is None:
                session.add(
                    User(
                        username=admin_username,
                        email=admin_email,
                        hashed_password=password_hasher.hash_password(admin_password),
                        full_name="系统管理员",
                        is_active=True,
                        is_verified=True,
                        is_superuser=True,
                    )
                )
                await session.commit()
                print("✅ 默认管理员用户已创建")
        except Exception as e:
            # Deliberate best effort: report and undo, but do not abort start-up.
            print(f"⚠️ 创建管理员用户时出错: {e}")
            await session.rollback()
-
-
# ====================== 7. Connection health check ======================
async def check_async_connection():
    """Probe the database with ``SELECT 1`` and report the outcome.

    Returns:
        ``True`` if the probe query ran, ``False`` if any exception occurred
        while connecting or executing (the error is printed, not raised).
    """
    from sqlalchemy import text  # local import kept as in the original

    try:
        async with async_engine.connect() as connection:
            probe = text("SELECT 1")
            await connection.execute(probe)
            print("✅ 数据库连接正常")
            return True
    except Exception as exc:
        print(f"❌ 数据库连接失败: {exc}")
        return False
-
-
# ====================== 8. Public API ======================
# Names exported by a star-import of this module.
__all__ = [
    # ORM base plus engine / session plumbing
    "Base", "async_engine", "AsyncSessionLocal",
    # FastAPI dependency and lifecycle helpers
    "get_async_db", "init_async_db", "check_async_connection",
]
|