CaiYouHui后端fastapi实现

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
# app/api/admin.py
import re

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from ..database import get_async_db
from ..dependencies.auth import get_current_user
from ..models.user import User
  9. router = APIRouter(prefix="/admin", tags=["admin"])
  10. @router.get("/database/tables")
  11. async def get_database_tables(
  12. db: AsyncSession = Depends(get_async_db),
  13. current_user: User = Depends(get_current_user)
  14. ):
  15. """获取所有表(需要管理员权限)"""
  16. if not current_user.is_superuser:
  17. raise HTTPException(status_code=403, detail="需要管理员权限")
  18. try:
  19. # 使用 SQLAlchemy 执行原生 SQL
  20. result = db.execute(text("""
  21. SELECT name, type, sql
  22. FROM sqlite_master
  23. WHERE type='table'
  24. AND name NOT LIKE 'sqlite_%'
  25. ORDER BY name;
  26. """))
  27. tables = []
  28. for row in result:
  29. tables.append({
  30. "name": row[0],
  31. "type": row[1],
  32. "sql": row[2]
  33. })
  34. return JSONResponse(content={"tables": tables})
  35. except Exception as e:
  36. raise HTTPException(status_code=500, detail=str(e))
  37. @router.get("/database/table/{table_name}")
  38. async def get_table_data(
  39. table_name: str,
  40. limit: int = 100,
  41. db: AsyncSession = Depends(get_async_db),
  42. current_user: User = Depends(get_current_user)
  43. ):
  44. """获取表数据(需要管理员权限)"""
  45. if not current_user.is_superuser:
  46. raise HTTPException(status_code=403, detail="需要管理员权限")
  47. try:
  48. # 先验证表存在且可访问
  49. result = db.execute(text(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';"))
  50. if not result.fetchone():
  51. raise HTTPException(status_code=404, detail="表不存在")
  52. # 获取表结构
  53. columns_result = db.execute(text(f"PRAGMA table_info({table_name});"))
  54. columns = []
  55. for row in columns_result:
  56. columns.append({
  57. "cid": row[0],
  58. "name": row[1],
  59. "type": row[2],
  60. "notnull": bool(row[3]),
  61. "default_value": row[4],
  62. "pk": bool(row[5])
  63. })
  64. # 获取数据
  65. data_result = db.execute(text(f"SELECT * FROM {table_name} LIMIT {limit};"))
  66. # 获取列名
  67. column_names = [desc[0] for desc in data_result.cursor.description]
  68. # 获取数据行
  69. rows = []
  70. for row in data_result:
  71. row_dict = {}
  72. for i, value in enumerate(row):
  73. # 处理特殊类型(如 datetime)
  74. if hasattr(value, 'isoformat'):
  75. row_dict[column_names[i]] = value.isoformat()
  76. else:
  77. row_dict[column_names[i]] = value
  78. rows.append(row_dict)
  79. # 统计总数
  80. count_result = db.execute(text(f"SELECT COUNT(*) FROM {table_name};"))
  81. total_count = count_result.scalar()
  82. return JSONResponse(content={
  83. "table": table_name,
  84. "columns": columns,
  85. "data": rows,
  86. "total_count": total_count,
  87. "showing": len(rows)
  88. })
  89. except HTTPException:
  90. raise
  91. except Exception as e:
  92. raise HTTPException(status_code=500, detail=str(e))
  93. @router.post("/database/query")
  94. async def execute_custom_query(
  95. query: str,
  96. db: AsyncSession = Depends(get_async_db),
  97. current_user: User = Depends(get_current_user)
  98. ):
  99. """执行自定义查询(需要管理员权限,生产环境请谨慎)"""
  100. if not current_user.is_superuser:
  101. raise HTTPException(status_code=403, detail="需要管理员权限")
  102. # 安全限制:不允许修改操作
  103. query_lower = query.strip().lower()
  104. dangerous_keywords = ["drop", "delete", "update", "insert", "alter", "truncate"]
  105. if any(keyword in query_lower for keyword in dangerous_keywords):
  106. raise HTTPException(status_code=400, detail="只允许SELECT查询")
  107. try:
  108. result = db.execute(text(query))
  109. # 如果是查询语句
  110. if query_lower.startswith("select"):
  111. # 获取列名
  112. column_names = [desc[0] for desc in result.cursor.description]
  113. # 获取数据
  114. rows = []
  115. for row in result:
  116. row_dict = {}
  117. for i, value in enumerate(row):
  118. if hasattr(value, 'isoformat'):
  119. row_dict[column_names[i]] = value.isoformat()
  120. else:
  121. row_dict[column_names[i]] = value
  122. rows.append(row_dict)
  123. return JSONResponse(content={
  124. "query": query,
  125. "columns": column_names,
  126. "data": rows,
  127. "row_count": len(rows)
  128. })
  129. else:
  130. # 非查询语句
  131. db.commit()
  132. return JSONResponse(content={
  133. "query": query,
  134. "affected_rows": result.rowcount,
  135. "message": "执行成功"
  136. })
  137. except Exception as e:
  138. raise HTTPException(status_code=500, detail=str(e))