# app/database.py
"""Async SQLAlchemy 2.0 setup: declarative base, engine, session factory and
the FastAPI database dependency."""
from contextlib import asynccontextmanager  # noqa: F401 - not used in this module
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase

from .config import settings


# ====================== 1. Declarative base ======================
class Base(DeclarativeBase):
    """SQLAlchemy 2.0 declarative base class for the ORM models."""

    pass


# ====================== 2. Sync -> async URL conversion ======================
# Driver suffixes in a URL scheme that already denote an async DBAPI.
_ASYNC_DRIVER_SUFFIXES = ("+asyncpg", "+aiomysql", "+aiosqlite")

# Driver-less URL schemes and the async scheme each one is rewritten to.
# "postgres" is a legacy alias still emitted by some hosting providers;
# SQLAlchemy itself no longer accepts it, so it is normalised here.
_SYNC_TO_ASYNC_SCHEME = {
    "postgresql": "postgresql+asyncpg",
    "postgres": "postgresql+asyncpg",
    "mysql": "mysql+aiomysql",
    "sqlite": "sqlite+aiosqlite",
}


def ensure_async_url(database_url: str) -> str:
    """Return *database_url* rewritten to use an async driver.

    URLs whose scheme already ends in one of the supported async drivers
    (``+asyncpg``, ``+aiomysql``, ``+aiosqlite``) are returned unchanged.
    Plain ``postgresql``/``postgres``, ``mysql`` and ``sqlite`` schemes are
    rewritten to their asyncpg / aiomysql / aiosqlite variants. Only the
    scheme is inspected and replaced; everything after ``://`` (credentials,
    host, path, query) is kept verbatim.

    Raises:
        ValueError: if the scheme is not supported. The message contains only
            the scheme, never the full URL, so credentials do not leak into
            logs or tracebacks.
    """
    scheme, separator, remainder = database_url.partition("://")
    # Already async: leave untouched. Checking the scheme (not the whole
    # URL) avoids false positives from passwords/paths containing "+aio...".
    if scheme.endswith(_ASYNC_DRIVER_SUFFIXES):
        return database_url

    async_scheme = _SYNC_TO_ASYNC_SCHEME.get(scheme) if separator else None
    if async_scheme is None:
        shown = scheme if separator else "<invalid URL>"
        raise ValueError(f"不支持的数据库类型: {shown}")
    # Rebuild from the parsed parts so only the scheme prefix is replaced.
    return f"{async_scheme}://{remainder}"


# ====================== 3. Async engine ======================
async_database_url = ensure_async_url(settings.DATABASE_URL)

_is_sqlite = async_database_url.startswith("sqlite+aiosqlite")

# Extra keyword arguments forwarded to the DBAPI connect() call.
connect_args = {"check_same_thread": False} if _is_sqlite else {}

# Queue-pool sizing only makes sense for server databases. For aiosqlite,
# SQLAlchemy chooses the pool itself (StaticPool for ":memory:", NullPool for
# file databases before 2.0.38), and those pool classes reject
# pool_size/max_overflow with a TypeError at engine creation.
_pool_kwargs = {} if _is_sqlite else {"pool_size": 20, "max_overflow": 10}

async_engine: AsyncEngine = create_async_engine(
    async_database_url,
    connect_args=connect_args,
    echo=settings.DEBUG,
    pool_pre_ping=True,
    future=True,  # 2.0-style behaviour (the only mode SQLAlchemy 2.x supports)
    **_pool_kwargs,
)


# ====================== 4. Async session factory ======================
# expire_on_commit=False keeps loaded attributes readable after commit without
# an implicit refresh (which would need IO that AsyncSession cannot do
# lazily); autoflush=False means pending changes are flushed only explicitly
# or on commit.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
)


# ====================== 5. Async dependency injection ======================
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` per dependency invocation (FastAPI-compatible).

    After the consumer resumes the generator normally, the session is
    committed. If an exception is thrown into the generator, the session is
    rolled back and the exception re-raised. The session is always closed.

    Note: intentionally *not* decorated with ``@asynccontextmanager``; a plain
    async generator is what FastAPI's dependency system expects.
    """
    session = AsyncSessionLocal()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
# ====================== 6. Database initialisation ======================
async def init_async_db():
    """Create the tables registered on ``Base.metadata`` and seed an admin user.

    Table creation is not guarded: any error there propagates to the caller.
    Seeding is best-effort: if no user named "admin" exists, one is inserted.
    Any exception during the lookup or insert is printed, the session is
    rolled back, and the function returns normally.
    """
    # These imports run before create_all(); presumably so the User model is
    # registered on Base.metadata by then (and to avoid import cycles at
    # module load) - TODO confirm against app.models.
    from .models.user import User
    from sqlalchemy import select
    from .core.security import password_hasher

    async with async_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    print("✅ 数据库表创建完成")

    async with AsyncSessionLocal() as session:
        try:
            lookup = select(User).where(User.username == "admin")
            existing_admin = (await session.execute(lookup)).scalar_one_or_none()
            if existing_admin:
                return  # already seeded; nothing to do

            # NOTE(review): hard-coded default superuser credentials. Anyone who
            # knows this file can log in until the password is changed; consider
            # sourcing it from configuration.
            session.add(
                User(
                    username="admin",
                    email="admin@caiyouhui.com",
                    hashed_password=password_hasher.hash_password("Admin123!"),
                    full_name="系统管理员",
                    is_active=True,
                    is_verified=True,
                    is_superuser=True,
                )
            )
            await session.commit()
            print("✅ 默认管理员用户已创建")
        except Exception as exc:
            print(f"⚠️ 创建管理员用户时出错: {exc}")
            await session.rollback()


# ====================== 7. Connection health check ======================
async def check_async_connection():
    """Probe the database with ``SELECT 1`` on a fresh connection.

    Returns:
        True when the probe succeeds; False when any exception occurs. The
        error is printed, not raised.
    """
    from sqlalchemy import text

    try:
        async with async_engine.connect() as connection:
            await connection.execute(text("SELECT 1"))
        print("✅ 数据库连接正常")
        return True
    except Exception as exc:
        print(f"❌ 数据库连接失败: {exc}")
        return False


# ====================== 8. Public API ======================
__all__ = [
    "Base",
    "async_engine",
    "AsyncSessionLocal",
    "get_async_db",
    "init_async_db",
    "check_async_connection",
]