CaiYouHui 后端 FastAPI 实现

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # app/api/admin.py
import json
import re

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from sqlalchemy import text
from sqlalchemy.orm import Session

from ..database import get_db
from ..dependencies.auth import get_current_user
from ..models.user import User
  10. router = APIRouter(prefix="/admin", tags=["admin"])
  11. @router.get("/database/tables")
  12. async def get_database_tables(
  13. db: Session = Depends(get_db),
  14. current_user: User = Depends(get_current_user)
  15. ):
  16. """获取所有表(需要管理员权限)"""
  17. if not current_user.is_superuser:
  18. raise HTTPException(status_code=403, detail="需要管理员权限")
  19. try:
  20. # 使用 SQLAlchemy 执行原生 SQL
  21. result = db.execute(text("""
  22. SELECT name, type, sql
  23. FROM sqlite_master
  24. WHERE type='table'
  25. AND name NOT LIKE 'sqlite_%'
  26. ORDER BY name;
  27. """))
  28. tables = []
  29. for row in result:
  30. tables.append({
  31. "name": row[0],
  32. "type": row[1],
  33. "sql": row[2]
  34. })
  35. return JSONResponse(content={"tables": tables})
  36. except Exception as e:
  37. raise HTTPException(status_code=500, detail=str(e))
  38. @router.get("/database/table/{table_name}")
  39. async def get_table_data(
  40. table_name: str,
  41. limit: int = 100,
  42. db: Session = Depends(get_db),
  43. current_user: User = Depends(get_current_user)
  44. ):
  45. """获取表数据(需要管理员权限)"""
  46. if not current_user.is_superuser:
  47. raise HTTPException(status_code=403, detail="需要管理员权限")
  48. try:
  49. # 先验证表存在且可访问
  50. result = db.execute(text(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';"))
  51. if not result.fetchone():
  52. raise HTTPException(status_code=404, detail="表不存在")
  53. # 获取表结构
  54. columns_result = db.execute(text(f"PRAGMA table_info({table_name});"))
  55. columns = []
  56. for row in columns_result:
  57. columns.append({
  58. "cid": row[0],
  59. "name": row[1],
  60. "type": row[2],
  61. "notnull": bool(row[3]),
  62. "default_value": row[4],
  63. "pk": bool(row[5])
  64. })
  65. # 获取数据
  66. data_result = db.execute(text(f"SELECT * FROM {table_name} LIMIT {limit};"))
  67. # 获取列名
  68. column_names = [desc[0] for desc in data_result.cursor.description]
  69. # 获取数据行
  70. rows = []
  71. for row in data_result:
  72. row_dict = {}
  73. for i, value in enumerate(row):
  74. # 处理特殊类型(如 datetime)
  75. if hasattr(value, 'isoformat'):
  76. row_dict[column_names[i]] = value.isoformat()
  77. else:
  78. row_dict[column_names[i]] = value
  79. rows.append(row_dict)
  80. # 统计总数
  81. count_result = db.execute(text(f"SELECT COUNT(*) FROM {table_name};"))
  82. total_count = count_result.scalar()
  83. return JSONResponse(content={
  84. "table": table_name,
  85. "columns": columns,
  86. "data": rows,
  87. "total_count": total_count,
  88. "showing": len(rows)
  89. })
  90. except HTTPException:
  91. raise
  92. except Exception as e:
  93. raise HTTPException(status_code=500, detail=str(e))
  94. @router.post("/database/query")
  95. async def execute_custom_query(
  96. query: str,
  97. db: Session = Depends(get_db),
  98. current_user: User = Depends(get_current_user)
  99. ):
  100. """执行自定义查询(需要管理员权限,生产环境请谨慎)"""
  101. if not current_user.is_superuser:
  102. raise HTTPException(status_code=403, detail="需要管理员权限")
  103. # 安全限制:不允许修改操作
  104. query_lower = query.strip().lower()
  105. dangerous_keywords = ["drop", "delete", "update", "insert", "alter", "truncate"]
  106. if any(keyword in query_lower for keyword in dangerous_keywords):
  107. raise HTTPException(status_code=400, detail="只允许SELECT查询")
  108. try:
  109. result = db.execute(text(query))
  110. # 如果是查询语句
  111. if query_lower.startswith("select"):
  112. # 获取列名
  113. column_names = [desc[0] for desc in result.cursor.description]
  114. # 获取数据
  115. rows = []
  116. for row in result:
  117. row_dict = {}
  118. for i, value in enumerate(row):
  119. if hasattr(value, 'isoformat'):
  120. row_dict[column_names[i]] = value.isoformat()
  121. else:
  122. row_dict[column_names[i]] = value
  123. rows.append(row_dict)
  124. return JSONResponse(content={
  125. "query": query,
  126. "columns": column_names,
  127. "data": rows,
  128. "row_count": len(rows)
  129. })
  130. else:
  131. # 非查询语句
  132. db.commit()
  133. return JSONResponse(content={
  134. "query": query,
  135. "affected_rows": result.rowcount,
  136. "message": "执行成功"
  137. })
  138. except Exception as e:
  139. raise HTTPException(status_code=500, detail=str(e))