Просмотр исходного кода

修改接口调用数据库操作实现异步

root 1 месяц назад
Родитель
Сommit
41def1c900

+ 15
- 0
.gitignore Просмотреть файл

1
+# Python缓存文件
2
+# __pycache__/
3
+# *.py[cod]
4
+# *$py.class
5
+# *.so
6
+
7
+# 特定目录下的缓存文件(可选)
8
+app/__pycache__
9
+app/*/__pycache__/
10
+app/*/*/__pycache__/
11
+
12
+# 或者更通用的匹配所有__pycache__
13
+# **/__pycache__/
14
+
15
+*.db

+ 5
- 6
app/api/v1/admin.py Просмотреть файл

1
 # app/api/admin.py
1
 # app/api/admin.py
2
 from fastapi import APIRouter, Depends, HTTPException
2
 from fastapi import APIRouter, Depends, HTTPException
3
 from fastapi.responses import JSONResponse
3
 from fastapi.responses import JSONResponse
4
-from sqlalchemy.orm import Session
4
+from sqlalchemy.ext.asyncio import AsyncSession
5
 from sqlalchemy import text
5
 from sqlalchemy import text
6
-import json
7
 
6
 
8
 from ..dependencies.auth import get_current_user
7
 from ..dependencies.auth import get_current_user
9
 from ..models.user import User
8
 from ..models.user import User
10
-from ..database import get_db
9
+from ..database import get_async_db
11
 
10
 
12
 router = APIRouter(prefix="/admin", tags=["admin"])
11
 router = APIRouter(prefix="/admin", tags=["admin"])
13
 
12
 
14
 @router.get("/database/tables")
13
 @router.get("/database/tables")
15
 async def get_database_tables(
14
 async def get_database_tables(
16
-    db: Session = Depends(get_db),
15
+    db: AsyncSession = Depends(get_async_db),
17
     current_user: User = Depends(get_current_user)
16
     current_user: User = Depends(get_current_user)
18
 ):
17
 ):
19
     """获取所有表(需要管理员权限)"""
18
     """获取所有表(需要管理员权限)"""
47
 async def get_table_data(
46
 async def get_table_data(
48
     table_name: str,
47
     table_name: str,
49
     limit: int = 100,
48
     limit: int = 100,
50
-    db: Session = Depends(get_db),
49
+    db: AsyncSession = Depends(get_async_db),
51
     current_user: User = Depends(get_current_user)
50
     current_user: User = Depends(get_current_user)
52
 ):
51
 ):
53
     """获取表数据(需要管理员权限)"""
52
     """获取表数据(需要管理员权限)"""
111
 @router.post("/database/query")
110
 @router.post("/database/query")
112
 async def execute_custom_query(
111
 async def execute_custom_query(
113
     query: str,
112
     query: str,
114
-    db: Session = Depends(get_db),
113
+    db: AsyncSession = Depends(get_async_db),
115
     current_user: User = Depends(get_current_user)
114
     current_user: User = Depends(get_current_user)
116
 ):
115
 ):
117
     """执行自定义查询(需要管理员权限,生产环境请谨慎)"""
116
     """执行自定义查询(需要管理员权限,生产环境请谨慎)"""

+ 8
- 9
app/api/v1/auth.py Просмотреть файл

1
 # app/api/v1/auth.py
1
 # app/api/v1/auth.py
2
 from fastapi import APIRouter, HTTPException, status, Request, Depends
2
 from fastapi import APIRouter, HTTPException, status, Request, Depends
3
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4
-from sqlalchemy.orm import Session
5
-from typing import Optional
4
+from sqlalchemy.ext.asyncio import AsyncSession
6
 
5
 
7
 from app.schemas.user import UserCreate, UserLogin
6
 from app.schemas.user import UserCreate, UserLogin
8
 from app.schemas.token import TokenResponse, RefreshTokenRequest
7
 from app.schemas.token import TokenResponse, RefreshTokenRequest
12
     verify_access_token
11
     verify_access_token
13
 )
12
 )
14
 from app.config import settings
13
 from app.config import settings
15
-from app.database import get_db
14
+from app.database import get_async_db
16
 from app.models.user import User
15
 from app.models.user import User
17
 from app.services.user_service import UserService
16
 from app.services.user_service import UserService
18
 
17
 
32
 async def register(
31
 async def register(
33
     user_data: UserCreate,
32
     user_data: UserCreate,
34
     request: Request,
33
     request: Request,
35
-    db: Session = Depends(get_db)
34
+    db: AsyncSession = Depends(get_async_db)
36
 ):
35
 ):
37
     """用户注册"""
36
     """用户注册"""
38
     user_service = UserService(db)
37
     user_service = UserService(db)
98
 async def login(
97
 async def login(
99
     login_data: UserLogin,
98
     login_data: UserLogin,
100
     request: Request,
99
     request: Request,
101
-    db: Session = Depends(get_db)
100
+    db: AsyncSession = Depends(get_async_db)
102
 ):
101
 ):
103
     """用户登录"""
102
     """用户登录"""
104
     user_service = UserService(db)
103
     user_service = UserService(db)
174
 @router.get("/me")
173
 @router.get("/me")
175
 async def get_current_user(
174
 async def get_current_user(
176
     credentials: HTTPAuthorizationCredentials = Depends(security),
175
     credentials: HTTPAuthorizationCredentials = Depends(security),
177
-    db: Session = Depends(get_db)
176
+    db: AsyncSession = Depends(get_async_db)
178
 ):
177
 ):
179
     """获取当前用户信息"""
178
     """获取当前用户信息"""
180
     token = credentials.credentials
179
     token = credentials.credentials
221
 @router.post("/refresh")
220
 @router.post("/refresh")
222
 async def refresh_token(
221
 async def refresh_token(
223
     refresh_data: RefreshTokenRequest,
222
     refresh_data: RefreshTokenRequest,
224
-    db: Session = Depends(get_db)
223
+    db: AsyncSession = Depends(get_async_db)
225
 ):
224
 ):
226
     """刷新访问令牌"""
225
     """刷新访问令牌"""
227
     refresh_token = refresh_data.refresh_token
226
     refresh_token = refresh_data.refresh_token
277
 @router.post("/logout")
276
 @router.post("/logout")
278
 async def logout(
277
 async def logout(
279
     refresh_data: RefreshTokenRequest,
278
     refresh_data: RefreshTokenRequest,
280
-    db: Session = Depends(get_db),
279
+    db: AsyncSession = Depends(get_async_db),
281
     credentials: HTTPAuthorizationCredentials = Depends(security)
280
     credentials: HTTPAuthorizationCredentials = Depends(security)
282
 ):
281
 ):
283
     """用户登出"""
282
     """用户登出"""
286
     return {"message": "登出成功"}
285
     return {"message": "登出成功"}
287
 
286
 
288
 @router.get("/test-db")
287
 @router.get("/test-db")
289
-async def test_database(db: Session = Depends(get_db)):
288
+async def test_database(db: AsyncSession = Depends(get_async_db)):
290
     """测试数据库连接和用户统计"""
289
     """测试数据库连接和用户统计"""
291
     user_service = UserService(db)
290
     user_service = UserService(db)
292
     
291
     

+ 47
- 34
app/api/v1/users.py Просмотреть файл

1
 # app/api/v1/users.py
1
 # app/api/v1/users.py
2
 from fastapi import APIRouter, Depends, HTTPException, status, Query
2
 from fastapi import APIRouter, Depends, HTTPException, status, Query
3
-from sqlalchemy.orm import Session
3
+from sqlalchemy.ext.asyncio import AsyncSession
4
 from typing import List, Optional
4
 from typing import List, Optional
5
-from datetime import datetime
5
+from datetime import datetime, timezone, timedelta
6
 
6
 
7
 from app.schemas.user import UserProfile, UserUpdate, PasswordChange
7
 from app.schemas.user import UserProfile, UserUpdate, PasswordChange
8
-from app.database import get_db
8
+from app.database import get_async_db  # 使用异步依赖
9
 from app.services.user_service import UserService
9
 from app.services.user_service import UserService
10
 from app.core.security import password_validator
10
 from app.core.security import password_validator
11
 
11
 
17
 router = APIRouter(prefix="/users", tags=["用户管理"])
17
 router = APIRouter(prefix="/users", tags=["用户管理"])
18
 security = HTTPBearer()
18
 security = HTTPBearer()
19
 
19
 
20
-def get_current_user(
20
+
21
+# ==================== 异步认证依赖 ====================
22
+async def get_current_user(
21
     credentials: HTTPAuthorizationCredentials = Depends(security),
23
     credentials: HTTPAuthorizationCredentials = Depends(security),
22
-    db: Session = Depends(get_db)
24
+    db: AsyncSession = Depends(get_async_db)  # 异步Session
23
 ):
25
 ):
24
-    """获取当前用户依赖"""
26
+    """获取当前用户依赖(异步版本)"""
25
     token = credentials.credentials
27
     token = credentials.credentials
26
     
28
     
27
     # 验证令牌
29
     # 验证令牌
33
         )
35
         )
34
     
36
     
35
     username = payload.get("sub")
37
     username = payload.get("sub")
38
+    if not username:
39
+        raise HTTPException(
40
+            status_code=status.HTTP_401_UNAUTHORIZED,
41
+            detail="令牌无效"
42
+        )
43
+    
36
     user_service = UserService(db)
44
     user_service = UserService(db)
37
-    user = user_service.get_user_by_username(username)
45
+    user = await user_service.get_user_by_username(username)  # 使用await
38
     
46
     
39
     if not user:
47
     if not user:
40
         raise HTTPException(
48
         raise HTTPException(
50
     
58
     
51
     return user
59
     return user
52
 
60
 
61
+
62
+# ==================== 用户个人资料相关 ====================
53
 @router.get("/me", response_model=UserProfile)
63
 @router.get("/me", response_model=UserProfile)
54
 async def get_my_profile(
64
 async def get_my_profile(
55
     current_user = Depends(get_current_user),
65
     current_user = Depends(get_current_user),
56
-    db: Session = Depends(get_db)
66
+    db: AsyncSession = Depends(get_async_db)  # 异步Session
57
 ):
67
 ):
58
     """获取我的资料"""
68
     """获取我的资料"""
59
-    user_service = UserService(db)
60
-    user = await user_service.get_user_by_id(current_user.id)
61
-    
62
-    if not user:
63
-        raise HTTPException(
64
-            status_code=status.HTTP_404_NOT_FOUND,
65
-            detail="用户不存在"
66
-        )
67
-    
69
+    # 直接从current_user获取,无需再查询数据库
68
     return UserProfile(
70
     return UserProfile(
69
-        id=user.id,
70
-        username=user.username,
71
-        email=user.email,
72
-        full_name=user.full_name,
73
-        is_active=user.is_active,
74
-        is_verified=user.is_verified,
75
-        created_at=user.created_at,
76
-        last_login=user.last_login,
77
-        avatar=user.avatar
71
+        id=current_user.id,
72
+        username=current_user.username,
73
+        email=current_user.email,
74
+        full_name=current_user.full_name,
75
+        is_active=current_user.is_active,
76
+        is_verified=current_user.is_verified,
77
+        created_at=current_user.created_at,
78
+        last_login=current_user.last_login,
79
+        avatar=current_user.avatar
78
     )
80
     )
79
 
81
 
82
+
80
 @router.put("/me", response_model=UserProfile)
83
 @router.put("/me", response_model=UserProfile)
81
 async def update_my_profile(
84
 async def update_my_profile(
82
     user_data: UserUpdate,
85
     user_data: UserUpdate,
83
     current_user = Depends(get_current_user),
86
     current_user = Depends(get_current_user),
84
-    db: Session = Depends(get_db)
87
+    db: AsyncSession = Depends(get_async_db)
85
 ):
88
 ):
86
     """更新我的资料"""
89
     """更新我的资料"""
87
     user_service = UserService(db)
90
     user_service = UserService(db)
97
                 detail="邮箱已被使用"
100
                 detail="邮箱已被使用"
98
             )
101
             )
99
     
102
     
103
+    # 更新用户信息
100
     updated_user = await user_service.update_user(current_user.id, update_dict)
104
     updated_user = await user_service.update_user(current_user.id, update_dict)
101
     
105
     
102
     if not updated_user:
106
     if not updated_user:
117
         avatar=updated_user.avatar
121
         avatar=updated_user.avatar
118
     )
122
     )
119
 
123
 
124
+
120
 @router.post("/me/change-password")
125
 @router.post("/me/change-password")
121
 async def change_password(
126
 async def change_password(
122
     password_data: PasswordChange,
127
     password_data: PasswordChange,
123
     current_user = Depends(get_current_user),
128
     current_user = Depends(get_current_user),
124
-    db: Session = Depends(get_db)
129
+    db: AsyncSession = Depends(get_async_db)
125
 ):
130
 ):
126
     """修改密码"""
131
     """修改密码"""
127
     user_service = UserService(db)
132
     user_service = UserService(db)
147
             detail=str(e)
152
             detail=str(e)
148
         )
153
         )
149
 
154
 
155
+
156
+# ==================== 管理员操作 ====================
150
 @router.get("/", response_model=List[UserProfile])
157
 @router.get("/", response_model=List[UserProfile])
151
 async def list_users(
158
 async def list_users(
152
     skip: int = Query(0, ge=0),
159
     skip: int = Query(0, ge=0),
153
     limit: int = Query(100, ge=1, le=100),
160
     limit: int = Query(100, ge=1, le=100),
154
     active_only: bool = Query(True),
161
     active_only: bool = Query(True),
155
     current_user = Depends(get_current_user),
162
     current_user = Depends(get_current_user),
156
-    db: Session = Depends(get_db)
163
+    db: AsyncSession = Depends(get_async_db)
157
 ):
164
 ):
158
     """获取用户列表(需要管理员权限)"""
165
     """获取用户列表(需要管理员权限)"""
159
     if not current_user.is_superuser:
166
     if not current_user.is_superuser:
180
         for user in users
187
         for user in users
181
     ]
188
     ]
182
 
189
 
190
+
183
 @router.get("/{user_id}", response_model=UserProfile)
191
 @router.get("/{user_id}", response_model=UserProfile)
184
 async def get_user(
192
 async def get_user(
185
     user_id: int,
193
     user_id: int,
186
     current_user = Depends(get_current_user),
194
     current_user = Depends(get_current_user),
187
-    db: Session = Depends(get_db)
195
+    db: AsyncSession = Depends(get_async_db)
188
 ):
196
 ):
189
-    """获取单个用户信息(需要管理员权限查看其他用户)"""
197
+    """获取单个用户信息"""
190
     user_service = UserService(db)
198
     user_service = UserService(db)
191
     
199
     
192
     # 普通用户只能查看自己的信息
200
     # 普通用户只能查看自己的信息
196
             detail="只能查看自己的用户信息"
204
             detail="只能查看自己的用户信息"
197
         )
205
         )
198
     
206
     
207
+    # 管理员可以查看任何人,普通用户只能查看自己(前面已校验)
199
     user = await user_service.get_user_by_id(user_id)
208
     user = await user_service.get_user_by_id(user_id)
200
     
209
     
201
     if not user:
210
     if not user:
216
         avatar=user.avatar
225
         avatar=user.avatar
217
     )
226
     )
218
 
227
 
228
+
219
 @router.delete("/{user_id}")
229
 @router.delete("/{user_id}")
220
 async def delete_user(
230
 async def delete_user(
221
     user_id: int,
231
     user_id: int,
222
     current_user = Depends(get_current_user),
232
     current_user = Depends(get_current_user),
223
-    db: Session = Depends(get_db)
233
+    db: AsyncSession = Depends(get_async_db)
224
 ):
234
 ):
225
     """删除用户(需要管理员权限)"""
235
     """删除用户(需要管理员权限)"""
226
     if not current_user.is_superuser:
236
     if not current_user.is_superuser:
247
     
257
     
248
     return {"message": "用户已禁用"}
258
     return {"message": "用户已禁用"}
249
 
259
 
260
+
250
 @router.get("/test")
261
 @router.get("/test")
251
 async def test_users():
262
 async def test_users():
252
     """测试用户管理 API"""
263
     """测试用户管理 API"""
253
     return {
264
     return {
254
-        "message": "用户管理 API 正常运行(使用 SQLite 数据库)",
265
+        "message": "用户管理 API 正常运行(使用 SQLite 数据库异步版本)",
255
         "endpoints": {
266
         "endpoints": {
256
             "GET /me": "获取我的资料",
267
             "GET /me": "获取我的资料",
257
             "PUT /me": "更新我的资料",
268
             "PUT /me": "更新我的资料",
259
             "GET /": "获取用户列表(管理员)",
270
             "GET /": "获取用户列表(管理员)",
260
             "GET /{user_id}": "获取单个用户",
271
             "GET /{user_id}": "获取单个用户",
261
             "DELETE /{user_id}": "删除用户(管理员)"
272
             "DELETE /{user_id}": "删除用户(管理员)"
262
-        }
273
+        },
274
+        "async": True,
275
+        "database": "异步SQLAlchemy 2.0"
263
     }
276
     }

+ 15
- 4
app/config.py Просмотреть файл

11
     # 项目配置
11
     # 项目配置
12
     PROJECT_NAME: str = os.getenv("PROJECT_NAME", "CaiYouHui 采油会")
12
     PROJECT_NAME: str = os.getenv("PROJECT_NAME", "CaiYouHui 采油会")
13
     VERSION: str = os.getenv("VERSION", "1.0.0")
13
     VERSION: str = os.getenv("VERSION", "1.0.0")
14
+
15
+    # API配置
14
     API_V1_PREFIX: str = os.getenv("API_V1_PREFIX", "/api/v1")
16
     API_V1_PREFIX: str = os.getenv("API_V1_PREFIX", "/api/v1")
17
+    APP_RUN_PORT: int = int(os.getenv("APP_RUN_PORT", "10003"))
15
     
18
     
16
     # 安全配置
19
     # 安全配置
17
     SECRET_KEY: str = os.getenv("SECRET_KEY", secrets.token_urlsafe(32))
20
     SECRET_KEY: str = os.getenv("SECRET_KEY", secrets.token_urlsafe(32))
19
     ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))  # 24小时
22
     ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))  # 24小时
20
     
23
     
21
     # 数据库配置 - SQLite
24
     # 数据库配置 - SQLite
22
-    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./caiyouhui.db")
23
-    
25
+    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./caiyouhui.db")
26
+
24
     # CORS 配置
27
     # CORS 配置
25
     BACKEND_CORS_ORIGINS: List[str] = os.getenv(
28
     BACKEND_CORS_ORIGINS: List[str] = os.getenv(
26
         "BACKEND_CORS_ORIGINS", 
29
         "BACKEND_CORS_ORIGINS", 
29
     
32
     
30
     # 调试模式
33
     # 调试模式
31
     DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"
34
     DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"
32
-    
35
+
36
+    # 日志配置
37
+    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
38
+    LOG_FILE: str = os.getenv("LOG_FILE", "logs/app.log")
39
+
33
     # 文件上传
40
     # 文件上传
34
     UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads")
41
     UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads")
35
     MAX_UPLOAD_SIZE: int = int(os.getenv("MAX_UPLOAD_SIZE", "10485760"))  # 10MB
42
     MAX_UPLOAD_SIZE: int = int(os.getenv("MAX_UPLOAD_SIZE", "10485760"))  # 10MB
36
-    
43
+
44
+    # 性能配置
45
+    DATABASE_POOL_SIZE: int = 20
46
+    DATABASE_MAX_OVERFLOW: int = 10
47
+
37
     # 邮箱验证(可选,后续添加)
48
     # 邮箱验证(可选,后续添加)
38
     SMTP_ENABLED: bool = os.getenv("SMTP_ENABLED", "False").lower() == "true"
49
     SMTP_ENABLED: bool = os.getenv("SMTP_ENABLED", "False").lower() == "true"
39
     SMTP_HOST: str = os.getenv("SMTP_HOST", "")
50
     SMTP_HOST: str = os.getenv("SMTP_HOST", "")

+ 6
- 6
app/core/auth.py Просмотреть файл

1
 from passlib.context import CryptContext
1
 from passlib.context import CryptContext
2
 from jose import JWTError, jwt
2
 from jose import JWTError, jwt
3
-from datetime import datetime, timedelta
3
+from datetime import datetime, timezone, timedelta
4
 from typing import Optional, Dict, Any, Union
4
 from typing import Optional, Dict, Any, Union
5
 import secrets
5
 import secrets
6
 import string
6
 import string
23
     to_encode = data.copy()
23
     to_encode = data.copy()
24
     
24
     
25
     if expires_delta:
25
     if expires_delta:
26
-        expire = datetime.utcnow() + expires_delta
26
+        expire = datetime.now(timezone.utc) + expires_delta
27
     else:
27
     else:
28
-        expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
28
+        expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
29
     
29
     
30
     to_encode.update({"exp": expire, "type": "access"})
30
     to_encode.update({"exp": expire, "type": "access"})
31
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
31
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
34
 def create_refresh_token(data: dict) -> str:
34
 def create_refresh_token(data: dict) -> str:
35
     """创建刷新令牌"""
35
     """创建刷新令牌"""
36
     to_encode = data.copy()
36
     to_encode = data.copy()
37
-    expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
37
+    expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
38
     
38
     
39
     to_encode.update({"exp": expire, "type": "refresh"})
39
     to_encode.update({"exp": expire, "type": "refresh"})
40
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
40
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
43
 def create_verification_token(email: str) -> str:
43
 def create_verification_token(email: str) -> str:
44
     """创建验证令牌"""
44
     """创建验证令牌"""
45
     to_encode = {"email": email, "type": "verify"}
45
     to_encode = {"email": email, "type": "verify"}
46
-    expire = datetime.utcnow() + timedelta(hours=settings.VERIFICATION_TOKEN_EXPIRE_HOURS)
46
+    expire = datetime.now(timezone.utc) + timedelta(hours=settings.VERIFICATION_TOKEN_EXPIRE_HOURS)
47
     
47
     
48
     to_encode.update({"exp": expire})
48
     to_encode.update({"exp": expire})
49
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
49
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
52
 def create_reset_token(email: str) -> str:
52
 def create_reset_token(email: str) -> str:
53
     """创建密码重置令牌"""
53
     """创建密码重置令牌"""
54
     to_encode = {"email": email, "type": "reset"}
54
     to_encode = {"email": email, "type": "reset"}
55
-    expire = datetime.utcnow() + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES)
55
+    expire = datetime.now(timezone.utc) + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES)
56
     
56
     
57
     to_encode.update({"exp": expire})
57
     to_encode.update({"exp": expire})
58
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
58
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

+ 5
- 5
app/core/security.py Просмотреть файл

6
 import time
6
 import time
7
 import json
7
 import json
8
 from typing import Optional, Dict, Any, Tuple
8
 from typing import Optional, Dict, Any, Tuple
9
-from datetime import datetime, timedelta
9
+from datetime import datetime, timezone, timedelta
10
 import re
10
 import re
11
 
11
 
12
 # 密码哈希工具
12
 # 密码哈希工具
165
         
165
         
166
         # 设置过期时间
166
         # 设置过期时间
167
         if expires_delta:
167
         if expires_delta:
168
-            expire = datetime.utcnow() + expires_delta
168
+            expire = datetime.now(timezone.utc) + expires_delta
169
         elif expires_minutes:
169
         elif expires_minutes:
170
-            expire = datetime.utcnow() + timedelta(minutes=expires_minutes)
170
+            expire = datetime.now(timezone.utc) + timedelta(minutes=expires_minutes)
171
         else:
171
         else:
172
-            expire = datetime.utcnow() + timedelta(minutes=30)  # 默认30分钟
172
+            expire = datetime.now(timezone.utc) + timedelta(minutes=30)  # 默认30分钟
173
         
173
         
174
         # 添加标准声明
174
         # 添加标准声明
175
         payload.update({
175
         payload.update({
176
             "exp": int(expire.timestamp()),
176
             "exp": int(expire.timestamp()),
177
-            "iat": int(datetime.utcnow().timestamp()),
177
+            "iat": int(datetime.now(timezone.utc).timestamp()),
178
             "iss": "caiyouhui-api",
178
             "iss": "caiyouhui-api",
179
             "type": token_type
179
             "type": token_type
180
         })
180
         })

+ 121
- 46
app/database.py Просмотреть файл

1
 # app/database.py
1
 # app/database.py
2
-from sqlalchemy import create_engine
3
-from sqlalchemy.orm import sessionmaker, Session
4
-from sqlalchemy.ext.declarative import declarative_base
5
-from contextlib import contextmanager
6
-from typing import Generator
7
-import os
2
+from sqlalchemy.orm import DeclarativeBase
3
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
4
+from contextlib import asynccontextmanager
5
+from typing import AsyncGenerator
8
 
6
 
9
 from .config import settings
7
 from .config import settings
10
 
8
 
11
-# 创建数据库引擎
12
-engine = create_engine(
13
-    settings.DATABASE_URL,
14
-    connect_args={"check_same_thread": False} if "sqlite" in settings.DATABASE_URL else {},
15
-    echo=settings.DEBUG
9
+
10
+# ====================== 1. 基类定义 ======================
11
+class Base(DeclarativeBase):
12
+    """SQLAlchemy 2.0 声明式基类"""
13
+    pass
14
+
15
+
16
+# ====================== 2. 异步URL转换 ======================
17
+def ensure_async_url(database_url: str) -> str:
18
+    """确保使用异步数据库URL"""
19
+    # 如果已经是异步URL,直接返回
20
+    if any(x in database_url for x in ["+asyncpg", "+aiomysql", "+aiosqlite"]):
21
+        return database_url
22
+    
23
+    # 转换同步URL为异步URL
24
+    if database_url.startswith("postgresql://"):
25
+        return database_url.replace("postgresql://", "postgresql+asyncpg://")
26
+    elif database_url.startswith("mysql://"):
27
+        return database_url.replace("mysql://", "mysql+aiomysql://")
28
+    elif database_url.startswith("sqlite://"):
29
+        return database_url.replace("sqlite://", "sqlite+aiosqlite://")
30
+    else:
31
+        raise ValueError(f"不支持的数据库类型: {database_url}")
32
+
33
+
34
+# ====================== 3. 异步引擎配置 ======================
35
+async_database_url = ensure_async_url(settings.DATABASE_URL)
36
+
37
+# 连接参数
38
+connect_args = {}
39
+if "sqlite+aiosqlite" in async_database_url:
40
+    connect_args = {"check_same_thread": False}
41
+
42
+async_engine: AsyncEngine = create_async_engine(
43
+    async_database_url,
44
+    connect_args=connect_args,
45
+    echo=settings.DEBUG,
46
+    pool_size=20,
47
+    max_overflow=10,
48
+    pool_pre_ping=True,
49
+    future=True,  # 启用2.0特性
16
 )
50
 )
17
 
51
 
18
-# 创建会话工厂
19
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
20
 
52
 
21
-# 声明基类
22
-Base = declarative_base()
53
+# ====================== 4. 异步会话工厂 ======================
54
+AsyncSessionLocal = async_sessionmaker(
55
+    bind=async_engine,
56
+    class_=AsyncSession,
57
+    expire_on_commit=False,
58
+    autoflush=False,
59
+)
60
+
23
 
61
 
24
-def get_db() -> Generator[Session, None, None]:
25
-    """数据库会话依赖注入"""
26
-    db = SessionLocal()
62
+# ====================== 5. 异步依赖注入 ======================
63
+async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
64
+    """
65
+    FastAPI兼容的异步数据库依赖
66
+    注意:移除了@asynccontextmanager装饰器
67
+    """
68
+    session = AsyncSessionLocal()
27
     try:
69
     try:
28
-        yield db
70
+        yield session
71
+        await session.commit()
72
+    except Exception:
73
+        await session.rollback()
74
+        raise
29
     finally:
75
     finally:
30
-        db.close()
76
+        await session.close()
31
 
77
 
32
-def init_db():
33
-    """初始化数据库表"""
78
+
79
+# ====================== 6. 数据库初始化 ======================
80
+async def init_async_db():
81
+    """异步初始化数据库"""
34
     from .models.user import User
82
     from .models.user import User
35
-    Base.metadata.create_all(bind=engine)
83
+    from sqlalchemy import select
84
+    from .core.security import password_hasher
85
+    
86
+    # 创建表
87
+    async with async_engine.begin() as conn:
88
+        await conn.run_sync(Base.metadata.create_all)
36
     print("✅ 数据库表创建完成")
89
     print("✅ 数据库表创建完成")
37
     
90
     
38
     # 创建默认管理员用户
91
     # 创建默认管理员用户
39
-    db = SessionLocal()
92
+    async with AsyncSessionLocal() as session:
93
+        try:
94
+            # 异步查询
95
+            stmt = select(User).where(User.username == "admin")
96
+            result = await session.execute(stmt)
97
+            admin = result.scalar_one_or_none()
98
+            
99
+            if not admin:
100
+                admin_user = User(
101
+                    username="admin",
102
+                    email="admin@caiyouhui.com",
103
+                    hashed_password=password_hasher.hash_password("Admin123!"),
104
+                    full_name="系统管理员",
105
+                    is_active=True,
106
+                    is_verified=True,
107
+                    is_superuser=True
108
+                )
109
+                session.add(admin_user)
110
+                await session.commit()
111
+                print("✅ 默认管理员用户已创建")
112
+        except Exception as e:
113
+            print(f"⚠️  创建管理员用户时出错: {e}")
114
+            await session.rollback()
115
+
116
+
117
+# ====================== 7. 连接健康检查 ======================
118
+async def check_async_connection():
119
+    """检查数据库连接"""
120
+    from sqlalchemy import text  # 需要导入text
40
     try:
121
     try:
41
-        # 检查是否已存在管理员
42
-        admin = db.query(User).filter(User.username == "admin").first()
43
-        if not admin:
44
-            from .core.security import password_hasher
45
-            admin_user = User(
46
-                username="admin",
47
-                email="admin@caiyouhui.com",
48
-                hashed_password=password_hasher.hash_password("Admin123!"),
49
-                full_name="系统管理员",
50
-                is_active=True,
51
-                is_verified=True,
52
-                is_superuser=True
53
-            )
54
-            db.add(admin_user)
55
-            db.commit()
56
-            print("✅ 默认管理员用户已创建")
122
+        async with async_engine.connect() as conn:
123
+            await conn.execute(text("SELECT 1"))
124
+        print("✅ 数据库连接正常")
125
+        return True
57
     except Exception as e:
126
     except Exception as e:
58
-        print(f"⚠️  创建管理员用户时出错: {e}")
59
-        db.rollback()
60
-    finally:
61
-        db.close()
127
+        print(f"❌ 数据库连接失败: {e}")
128
+        return False
129
+
62
 
130
 
63
-# 导出
64
-__all__ = ["Base", "engine", "SessionLocal", "get_db", "init_db"]
131
+# ====================== 8. 导出 ======================
132
+__all__ = [
133
+    "Base",
134
+    "async_engine", 
135
+    "AsyncSessionLocal",
136
+    "get_async_db",
137
+    "init_async_db",
138
+    "check_async_connection",
139
+]

+ 3
- 3
app/dependencies/auth.py Просмотреть файл

1
 from fastapi import Depends, HTTPException, status
1
 from fastapi import Depends, HTTPException, status
2
 from fastapi.security import OAuth2PasswordBearer
2
 from fastapi.security import OAuth2PasswordBearer
3
-from sqlalchemy.orm import Session
3
+from sqlalchemy.ext.asyncio import AsyncSession
4
 from jose import JWTError, jwt
4
 from jose import JWTError, jwt
5
 from typing import Optional
5
 from typing import Optional
6
 
6
 
7
-from ..database import get_db
7
+from ..database import get_async_db
8
 from ..models.user import User
8
 from ..models.user import User
9
 from ..schemas.token import TokenData
9
 from ..schemas.token import TokenData
10
 from ..config import settings
10
 from ..config import settings
16
 
16
 
17
 async def get_current_user(
17
 async def get_current_user(
18
     token: Optional[str] = Depends(oauth2_scheme),
18
     token: Optional[str] = Depends(oauth2_scheme),
19
-    db: Session = Depends(get_db)
19
+    db: AsyncSession = Depends(get_async_db)
20
 ) -> Optional[User]:
20
 ) -> Optional[User]:
21
     """获取当前用户"""
21
     """获取当前用户"""
22
     if not token:
22
     if not token:

+ 1
- 1
app/dependencies/database.py Просмотреть файл

6
 # ✅ 正确导入方式
6
 # ✅ 正确导入方式
7
 from app.database import SessionLocal
7
 from app.database import SessionLocal
8
 
8
 
9
-def get_db() -> Generator[Session, None, None]:
9
+def get_async_db() -> Generator[Session, None, None]:
10
     """数据库会话依赖注入"""
10
     """数据库会话依赖注入"""
11
     db = SessionLocal()
11
     db = SessionLocal()
12
     try:
12
     try:

+ 1
- 1
app/logging_config.py Просмотреть файл

2
 import logging
2
 import logging
3
 import logging.handlers  # ✅ 这里导入 handlers
3
 import logging.handlers  # ✅ 这里导入 handlers
4
 import os
4
 import os
5
-from datetime import datetime
5
+from datetime import datetime, timezone, timedelta
6
 
6
 
7
 def setup_logging_with_rotation():
7
 def setup_logging_with_rotation():
8
     """配置带轮转的日志系统"""
8
     """配置带轮转的日志系统"""

+ 406
- 46
app/main.py Просмотреть файл

1
-# app/main.py - 简化版本
2
-from fastapi import FastAPI
1
+# app/main.py
2
+from contextlib import asynccontextmanager
3
+from fastapi import FastAPI, Request, status
4
+from fastapi.responses import JSONResponse
3
 from fastapi.middleware.cors import CORSMiddleware
5
 from fastapi.middleware.cors import CORSMiddleware
4
-from app.logging_config import setup_logging_with_rotation
6
+from fastapi.exceptions import RequestValidationError
7
+from fastapi.middleware.gzip import GZipMiddleware
5
 import logging
8
 import logging
6
-import os
9
+import time
10
+import traceback
11
+import sys
7
 
12
 
8
 # 配置
13
 # 配置
9
-from .config import settings
10
-
11
-# 配置日志
12
-logging.basicConfig(
13
-    level=logging.INFO,
14
-    # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
15
-    # datefmt='%Y-%m-%d %H:%M:%S',
16
-    # handlers=[
17
-    #     logging.StreamHandler(),  # 控制台
18
-    #     logging.FileHandler('app/logs/app.log', encoding='utf-8')  # 文件
19
-    # ]
20
-)
21
-# 为特定模块设置更详细的日志
22
-logging.getLogger("app.api.v1").setLevel(logging.DEBUG)
23
-logging.getLogger("app.core.security").setLevel(logging.DEBUG)
14
+from app.config import settings
24
 
15
 
16
+# 设置日志
25
 logger = logging.getLogger(__name__)
17
 logger = logging.getLogger(__name__)
26
 logger.info("✅ 日志配置完成")
18
 logger.info("✅ 日志配置完成")
27
 
19
 
28
-# logger = setup_logging_with_rotation
29
-
20
+# ==================== 异步生命周期管理 ====================
21
+@asynccontextmanager
22
+async def lifespan(app: FastAPI):
23
+    """
24
+    应用生命周期管理器 - 完全异步
25
+    启动时初始化资源,关闭时清理资源
26
+    """
27
+    # 启动阶段
28
+    logger.info("🚀 应用启动中...")
29
+    
30
+    try:
31
+        # 1. 初始化数据库
32
+        from app.database import init_async_db, check_async_connection
33
+        if await check_async_connection():
34
+            await init_async_db()
35
+            logger.info("✅ 数据库初始化完成")
36
+        else:
37
+            logger.error("❌ 数据库连接失败")
38
+            # 可以根据需要决定是否终止启动
39
+            # raise RuntimeError("数据库连接失败")
40
+        
41
+        # 2. 初始化Redis等缓存(可选)
42
+        # try:
43
+        #     from app.core.redis import init_redis
44
+        #     await init_redis()
45
+        #     logger.info("✅ Redis连接完成")
46
+        # except ImportError:
47
+        #     logger.info("ℹ️  Redis未配置,跳过初始化")
48
+        
49
+        # 3. 初始化其他服务
50
+        # try:
51
+        #     from app.core.security import init_security
52
+        #     init_security()
53
+        #     logger.info("✅ 安全模块初始化完成")
54
+        # except Exception as e:
55
+        #     logger.warning(f"⚠️  安全模块初始化异常: {e}")
56
+        
57
+        logger.info(f"🎉 {settings.PROJECT_NAME} v{settings.VERSION} 启动完成")
58
+        yield
59
+        
60
+    except Exception as e:
61
+        logger.error(f"🔥 应用启动失败: {e}")
62
+        logger.error(traceback.format_exc())
63
+        sys.exit(1)
64
+    
65
+    finally:
66
+        # 关闭阶段
67
+        logger.info("🛑 应用关闭中...")
68
+        
69
+        # 清理资源
70
+        try:
71
+            from app.database import async_engine
72
+            await async_engine.dispose()
73
+            logger.info("✅ 数据库连接池已清理")
74
+        except Exception as e:
75
+            logger.warning(f"⚠️  数据库清理异常: {e}")
76
+        
77
+        logger.info("👋 应用已安全关闭")
30
 
78
 
31
 
79
 
32
-# 创建 FastAPI 应用
80
+# ==================== 创建FastAPI应用 ====================
33
 app = FastAPI(
81
 app = FastAPI(
34
     title=settings.PROJECT_NAME,
82
     title=settings.PROJECT_NAME,
35
     version=settings.VERSION,
83
     version=settings.VERSION,
84
+    description=f"""
85
+    {settings.PROJECT_NAME} API 服务
86
+    
87
+    ## 功能特性
88
+    - ✅ 全异步架构(SQLAlchemy 2.0 + async/await)
89
+    - 🔒 JWT认证与授权
90
+    - 📊 性能监控中间件
91
+    - 🚀 自动API文档
92
+    
93
+    ## 环境
94
+    - **调试模式**: {'开启' if settings.DEBUG else '关闭'}
95
+    - **数据库**: {settings.DATABASE_URL.split('://')[0] if '://' in settings.DATABASE_URL else '未知'}
96
+    """,
36
     docs_url="/docs" if settings.DEBUG else None,
97
     docs_url="/docs" if settings.DEBUG else None,
37
-    redoc_url="/redoc" if settings.DEBUG else None
98
+    redoc_url="/redoc" if settings.DEBUG else None,
99
+    openapi_url="/openapi.json" if settings.DEBUG else None,
100
+    lifespan=lifespan,  # 异步生命周期管理
101
+    swagger_ui_parameters={
102
+        "persistAuthorization": True,
103
+        "displayRequestDuration": True,
104
+        "filter": True,
105
+    }
38
 )
106
 )
39
 
107
 
40
-# 配置CORS
108
+# ==================== 全局异常处理器 ====================
109
+@app.exception_handler(Exception)
110
+async def global_exception_handler(request: Request, exc: Exception):
111
+    """全局异常处理器"""
112
+    logger.error(f"🔥 未捕获的异常: {exc}", exc_info=True)
113
+    
114
+    # 在调试模式下返回详细错误
115
+    if settings.DEBUG:
116
+        return JSONResponse(
117
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
118
+            content={
119
+                "detail": str(exc),
120
+                "type": exc.__class__.__name__,
121
+                "traceback": traceback.format_exc().split('\n'),
122
+                "path": request.url.path,
123
+                "method": request.method,
124
+            }
125
+        )
126
+    
127
+    # 生产环境返回简化错误
128
+    return JSONResponse(
129
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
130
+        content={
131
+            "detail": "服务器内部错误",
132
+            "request_id": request.state.get("request_id", "unknown") if hasattr(request.state, "request_id") else "unknown"
133
+        }
134
+    )
135
+
136
+
137
+@app.exception_handler(RequestValidationError)
138
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
139
+    """请求验证异常处理器"""
140
+    logger.warning(f"⚠️  请求验证失败: {exc}")
141
+    
142
+    return JSONResponse(
143
+        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
144
+        content={
145
+            "detail": exc.errors(),
146
+            "body": exc.body,
147
+        }
148
+    )
149
+
150
+
151
+# ==================== 中间件配置 ====================
152
+# 1. CORS中间件
41
 if settings.BACKEND_CORS_ORIGINS:
153
 if settings.BACKEND_CORS_ORIGINS:
42
     app.add_middleware(
154
     app.add_middleware(
43
         CORSMiddleware,
155
         CORSMiddleware,
45
         allow_credentials=True,
157
         allow_credentials=True,
46
         allow_methods=["*"],
158
         allow_methods=["*"],
47
         allow_headers=["*"],
159
         allow_headers=["*"],
160
+        expose_headers=["*"],
161
+        max_age=600,  # 预检请求缓存时间(秒)
48
     )
162
     )
163
+    logger.info(f"✅ CORS已启用,允许的源: {settings.BACKEND_CORS_ORIGINS}")
49
 
164
 
50
-# 导入并注册路由
51
-try:
52
-    # 导入所有路由模块
53
-    # from .api.v1.admin import router as admin_router
54
-    from .api.v1.auth import router as auth_router
55
-    from .api.v1.users import router as users_router
56
-    # from .api.v1.verify import router as verify_router
165
+# 2. 请求ID中间件(用于日志追踪)
166
+@app.middleware("http")
167
+async def add_request_id(request: Request, call_next):
168
+    """为每个请求添加唯一ID"""
169
+    import uuid
170
+    request_id = str(uuid.uuid4())[:8]
171
+    request.state.request_id = request_id
57
     
172
     
58
-    # 注册路由
59
-    app.include_router(auth_router, prefix=settings.API_V1_PREFIX)
60
-    app.include_router(users_router, prefix=settings.API_V1_PREFIX)
61
-    # app.include_router(verify_router, prefix=settings.API_V1_PREFIX)
173
+    response = await call_next(request)
174
+    response.headers["X-Request-ID"] = request_id
175
+    return response
176
+
177
+
178
+# 3. 性能监控中间件
179
+@app.middleware("http")
180
+async def performance_middleware(request: Request, call_next):
181
+    """记录请求处理时间"""
182
+    start_time = time.time()
183
+    
184
+    try:
185
+        response = await call_next(request)
186
+        process_time = time.time() - start_time
187
+        
188
+        # 记录慢请求
189
+        if process_time > 1.0:  # 超过1秒
190
+            logger.warning(
191
+                f"🐌 慢请求: {request.method} {request.url.path} "
192
+                f"耗时: {process_time:.3f}s "
193
+                f"状态: {response.status_code} "
194
+                f"请求ID: {getattr(request.state, 'request_id', 'unknown')}"
195
+            )
196
+        elif settings.DEBUG:
197
+            logger.debug(
198
+                f"⚡ 请求: {request.method} {request.url.path} "
199
+                f"耗时: {process_time:.3f}s"
200
+            )
201
+        
202
+        # 添加处理时间到响应头
203
+        response.headers["X-Process-Time"] = str(process_time)
204
+        return response
205
+        
206
+    except Exception as e:
207
+        process_time = time.time() - start_time
208
+        logger.error(
209
+            f"🔥 请求异常: {request.method} {request.url.path} "
210
+            f"耗时: {process_time:.3f}s "
211
+            f"错误: {e}"
212
+        )
213
+        raise
214
+
215
+
216
+# 4. GZIP压缩中间件(提升性能)
217
+app.add_middleware(
218
+    GZipMiddleware,
219
+    minimum_size=1000,  # 只压缩大于1KB的响应
220
+)
221
+
222
+
223
+# ==================== 路由注册 ====================
224
+def register_routers():
225
+    """注册所有API路由"""
226
+    routers = []
62
     
227
     
63
-    logger.info("✅ API 路由注册成功")
228
+    # 尝试导入并注册路由模块
229
+    try:
230
+        from app.api.v1.auth import router as auth_router
231
+        routers.append(("认证模块", auth_router))
232
+    except ImportError as e:
233
+        logger.warning(f"⚠️  认证路由导入失败: {e}")
64
     
234
     
65
-except ImportError as e:
66
-    logger.warning(f"⚠️  部分路由模块未找到: {e}")
235
+    try:
236
+        from app.api.v1.users import router as users_router
237
+        routers.append(("用户管理", users_router))
238
+    except ImportError as e:
239
+        logger.warning(f"⚠️  用户路由导入失败: {e}")
240
+    
241
+    # 更多路由模块...
242
+    # try:
243
+    #     from app.api.v1.admin import router as admin_router
244
+    #     routers.append(("管理后台", admin_router))
245
+    # except ImportError:
246
+    #     pass
247
+    
248
+    # 注册所有路由
249
+    for name, router in routers:
250
+        app.include_router(router, prefix=settings.API_V1_PREFIX)
251
+        logger.info(f"✅ 路由注册: {name}")
252
+    
253
+    logger.info(f"🎯 共注册 {len(routers)} 个路由模块")
254
+
67
 
255
 
68
-# 系统级路由
69
-@app.get("/health")
256
+# 调用路由注册
257
+register_routers()
258
+
259
+
260
+# ==================== 系统级路由 ====================
261
+@app.get("/health", tags=["系统监控"])
70
 async def health_check():
262
 async def health_check():
71
-    return {"status": "healthy", "service": settings.PROJECT_NAME}
263
+    """
264
+    健康检查端点
265
+    
266
+    返回服务健康状态,可用于K8s探针、负载均衡器健康检查等
267
+    """
268
+    from app.database import check_async_connection
269
+    
270
+    health_status = {
271
+        "status": "healthy",
272
+        "service": settings.PROJECT_NAME,
273
+        "version": settings.VERSION,
274
+        "timestamp": time.time(),
275
+        "async": True,
276
+        "debug": settings.DEBUG,
277
+    }
278
+    
279
+    # 检查数据库连接
280
+    try:
281
+        db_healthy = await check_async_connection()
282
+        health_status["database"] = "healthy" if db_healthy else "unhealthy"
283
+    except Exception as e:
284
+        health_status["database"] = f"error: {str(e)}"
285
+    
286
+    # 检查Redis连接(可选)
287
+    # try:
288
+    #     from app.core.redis import check_redis_connection
289
+    #     redis_healthy = await check_redis_connection()
290
+    #     health_status["redis"] = "healthy" if redis_healthy else "unhealthy"
291
+    # except ImportError:
292
+    #     health_status["redis"] = "not_configured"
293
+    # except Exception as e:
294
+    #     health_status["redis"] = f"error: {str(e)}"
295
+    
296
+    # 如果有不健康的服务,返回503
297
+    if "unhealthy" in health_status.values() or "error" in str(health_status.values()):
298
+        return JSONResponse(
299
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
300
+            content=health_status
301
+        )
302
+    
303
+    return health_status
72
 
304
 
73
-@app.get("/")
305
+
306
+@app.get("/", tags=["系统"])
74
 async def root():
307
 async def root():
308
+    """
309
+    根端点
310
+    
311
+    返回API基本信息和可用端点
312
+    """
313
+    # 获取已注册的路由信息
314
+    routes_info = []
315
+    for route in app.routes:
316
+        if hasattr(route, "methods"):
317
+            routes_info.append({
318
+                "path": route.path,
319
+                "methods": list(route.methods) if route.methods else [],
320
+                "name": route.name or "",
321
+            })
322
+    
75
     return {
323
     return {
76
-        "message": f"Welcome to {settings.PROJECT_NAME} API",
324
+        "message": f"欢迎使用 {settings.PROJECT_NAME} API",
77
         "version": settings.VERSION,
325
         "version": settings.VERSION,
78
-        "docs": "/docs"
326
+        "description": "基于FastAPI的现代化异步API服务",
327
+        "docs": "/docs" if settings.DEBUG else None,
328
+        "redoc": "/redoc" if settings.DEBUG else None,
329
+        "health_check": "/health",
330
+        "async": True,
331
+        "database": "SQLAlchemy 2.0 Async",
332
+        "routes_count": len(routes_info),
333
+        "quick_links": {
334
+            "认证": f"{settings.API_V1_PREFIX}/auth/login",
335
+            "用户信息": f"{settings.API_V1_PREFIX}/users/me",
336
+            "API文档": "/docs" if settings.DEBUG else "未启用",
337
+        }
79
     }
338
     }
80
 
339
 
81
-if __name__ == "__main__":
340
+
341
+@app.get("/metrics", tags=["系统监控"])
342
+async def metrics():
343
+    """
344
+    监控指标端点(可用于Prometheus)
345
+    
346
+    返回应用性能指标
347
+    """
348
+    import psutil
349
+    import gc
350
+    
351
+    # 内存使用
352
+    process = psutil.Process()
353
+    mem_info = process.memory_info()
354
+    
355
+    # GC统计
356
+    gc_counts = gc.get_count()
357
+    
358
+    return {
359
+        "process": {
360
+            "pid": process.pid,
361
+            "cpu_percent": process.cpu_percent(),
362
+            "memory_mb": mem_info.rss / 1024 / 1024,
363
+            "threads": process.num_threads(),
364
+        },
365
+        "gc": {
366
+            "generation0": gc_counts[0],
367
+            "generation1": gc_counts[1],
368
+            "generation2": gc_counts[2],
369
+        },
370
+        "python": {
371
+            "version": sys.version,
372
+            "implementation": sys.implementation.name,
373
+        },
374
+        "timestamp": time.time(),
375
+    }
376
+
377
+
378
+@app.get("/config", tags=["系统"])
379
+async def show_config():
380
+    """
381
+    显示当前配置(调试用)
382
+    
383
+    注意:生产环境建议禁用此端点或限制访问
384
+    """
385
+    if not settings.DEBUG:
386
+        raise HTTPException(
387
+            status_code=status.HTTP_403_FORBIDDEN,
388
+            detail="仅调试模式下可用"
389
+        )
390
+    
391
+    # 安全地显示配置(隐藏敏感信息)
392
+    config_dict = {}
393
+    for key in dir(settings):
394
+        if not key.startswith("_"):
395
+            value = getattr(settings, key)
396
+            
397
+            # 隐藏敏感信息
398
+            if any(sensitive in key.lower() for sensitive in ["secret", "key", "password", "token"]):
399
+                config_dict[key] = "***HIDDEN***"
400
+            elif "url" in key.lower() and "database" in key.lower():
401
+                # 显示数据库类型但不显示完整URL
402
+                db_type = value.split("://")[0] if "://" in value else "unknown"
403
+                config_dict[key] = f"{db_type}://***HIDDEN***"
404
+            else:
405
+                config_dict[key] = value
406
+    
407
+    return {
408
+        "config": config_dict,
409
+        "environment": settings.ENVIRONMENT,
410
+        "debug": settings.DEBUG,
411
+    }
412
+
413
+
414
+# ==================== 启动入口 ====================
415
+def run_server():
416
+    """启动服务器(开发模式)"""
82
     import uvicorn
417
     import uvicorn
83
-    uvicorn.run(app, host="0.0.0.0", port=10003, reload=settings.DEBUG)
418
+    
419
+    logger.info("=" * 50)
420
+    logger.info(f"🚀 启动 {settings.PROJECT_NAME} v{settings.VERSION}")
421
+    logger.info(f"🐛 调试模式: {settings.DEBUG}")
422
+    logger.info(f"🌐 主机: 0.0.0.0")
423
+    logger.info(f"🔌 端口: {settings.APP_RUN_PORT}")
424
+    logger.info(f"🗄️  数据库: {settings.DATABASE_URL.split('://')[0] if '://' in settings.DATABASE_URL else '未知'}")
425
+    logger.info("=" * 50)
426
+    
427
+    uvicorn.run(
428
+        app,
429
+        host="0.0.0.0",
430
+        port=settings.APP_RUN_PORT,
431
+        reload=settings.DEBUG,
432
+        log_level="info" if settings.DEBUG else "warning",
433
+        access_log=True,
434
+        use_colors=True,
435
+        # 优化性能的配置
436
+        limit_concurrency=1000,
437
+        limit_max_requests=10000,
438
+        timeout_keep_alive=5,
439
+    )
440
+
441
+
442
+if __name__ == "__main__":
443
+    run_server()

+ 2
- 2
app/models/token.py Просмотреть файл

1
 from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey
1
 from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey
2
 from sqlalchemy.orm import relationship
2
 from sqlalchemy.orm import relationship
3
 from sqlalchemy.sql import func
3
 from sqlalchemy.sql import func
4
-from datetime import datetime
4
+from datetime import datetime, timezone, timedelta
5
 from ..database import Base
5
 from ..database import Base
6
 
6
 
7
 class Token(Base):
7
 class Token(Base):
25
     created_at = Column(DateTime(timezone=True), server_default=func.now())
25
     created_at = Column(DateTime(timezone=True), server_default=func.now())
26
     
26
     
27
     def is_expired(self):
27
     def is_expired(self):
28
-        return datetime.utcnow() > self.expires_at
28
+        return datetime.now(timezone.utc) > self.expires_at
29
     
29
     
30
     def __repr__(self):
30
     def __repr__(self):
31
         return f"<Token {self.token_type} for user {self.user_id}>"
31
         return f"<Token {self.token_type} for user {self.user_id}>"

+ 5
- 1
app/models/user.py Просмотреть файл

1
 # app/models/user.py
1
 # app/models/user.py
2
 from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text
2
 from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text
3
 from sqlalchemy.sql import func
3
 from sqlalchemy.sql import func
4
-from ..database import Base
4
+
5
+# 从database.py导入Base(确保是DeclarativeBase)
6
+from app.database import Base
5
 
7
 
6
 class User(Base):
8
 class User(Base):
7
     __tablename__ = "users"
9
     __tablename__ = "users"
10
     username = Column(String(50), unique=True, index=True, nullable=False)
12
     username = Column(String(50), unique=True, index=True, nullable=False)
11
     email = Column(String(100), unique=True, index=True, nullable=False)
13
     email = Column(String(100), unique=True, index=True, nullable=False)
12
     hashed_password = Column(String(255), nullable=False)
14
     hashed_password = Column(String(255), nullable=False)
15
+    full_name = Column(String(100))
16
+    avatar = Column(String(255))
13
     
17
     
14
     # 用户状态
18
     # 用户状态
15
     is_active = Column(Boolean, default=True)
19
     is_active = Column(Boolean, default=True)

+ 1
- 1
app/schemas/auth.py Просмотреть файл

1
 from pydantic import BaseModel, Field
1
 from pydantic import BaseModel, Field
2
 from typing import Optional
2
 from typing import Optional
3
-from datetime import datetime
3
+from datetime import datetime, timezone, timedelta
4
 from .user import UserResponse
4
 from .user import UserResponse
5
 
5
 
6
 class TokenBase(BaseModel):
6
 class TokenBase(BaseModel):

+ 1
- 1
app/schemas/token.py Просмотреть файл

1
 # app/schemas/token.py
1
 # app/schemas/token.py
2
 from pydantic import BaseModel, Field
2
 from pydantic import BaseModel, Field
3
 from typing import Optional
3
 from typing import Optional
4
-from datetime import datetime
4
+from datetime import datetime, timezone, timedelta
5
 
5
 
6
 class Token(BaseModel):
6
 class Token(BaseModel):
7
     """令牌响应模型"""
7
     """令牌响应模型"""

+ 1
- 1
app/schemas/user.py Просмотреть файл

1
 # app/schemas/user.py
1
 # app/schemas/user.py
2
 from pydantic import BaseModel, EmailStr, Field, validator
2
 from pydantic import BaseModel, EmailStr, Field, validator
3
 from typing import Optional
3
 from typing import Optional
4
-from datetime import datetime
4
+from datetime import datetime, timezone, timedelta
5
 
5
 
6
 class UserBase(BaseModel):
6
 class UserBase(BaseModel):
7
     """用户基础模型"""
7
     """用户基础模型"""

+ 242
- 124
app/services/auth_service.py Просмотреть файл

1
 from typing import Optional, Dict, Any, Tuple
1
 from typing import Optional, Dict, Any, Tuple
2
-from datetime import datetime, timedelta
3
-from sqlalchemy.orm import Session
2
+from datetime import datetime, timezone, timedelta
3
+from sqlalchemy import select, update, delete, func, or_
4
+from sqlalchemy.ext.asyncio import AsyncSession
4
 from fastapi import HTTPException, status, BackgroundTasks
5
 from fastapi import HTTPException, status, BackgroundTasks
5
 import secrets
6
 import secrets
6
 import string
7
 import string
7
 
8
 
8
-from ..models.user import User
9
-from ..models.token import Token
9
+from app.models.user import User
10
+from app.models.token import Token
10
 from ..schemas.auth import LoginRequest, TokenResponse
11
 from ..schemas.auth import LoginRequest, TokenResponse
11
-from ..core.security import (
12
+from app.core.security import (
12
     verify_password,
13
     verify_password,
13
     create_access_token,
14
     create_access_token,
14
     create_refresh_token,
15
     create_refresh_token,
17
     decode_token,
18
     decode_token,
18
     generate_verification_code
19
     generate_verification_code
19
 )
20
 )
20
-from ..core.email import email_service
21
+from app.core.email import email_service
21
 from ..config import settings
22
 from ..config import settings
22
 
23
 
23
 class AuthService:
24
 class AuthService:
24
-    def __init__(self, db: Session):
25
+    def __init__(self, db: AsyncSession):
25
         self.db = db
26
         self.db = db
26
     
27
     
27
     async def register_user(
28
     async def register_user(
31
         ip_address: Optional[str] = None,
32
         ip_address: Optional[str] = None,
32
         user_agent: Optional[str] = None
33
         user_agent: Optional[str] = None
33
     ) -> User:
34
     ) -> User:
34
-        """注册新用户"""
35
-        # 检查用户是否存在
36
-        existing_user = self.db.query(User).filter(
37
-            (User.username == user_data["username"]) | 
38
-            (User.email == user_data["email"])
39
-        ).first()
35
+        """注册新用户(SQLAlchemy 2.0异步版)"""
36
+        # 1. 检查用户是否存在 - 使用异步查询
37
+        stmt = select(User).where(
38
+            or_(
39
+                User.username == user_data["username"],
40
+                User.email == user_data["email"]
41
+            )
42
+        )
43
+        result = await self.db.execute(stmt)
44
+        existing_user = result.scalar_one_or_none()
40
         
45
         
41
         if existing_user:
46
         if existing_user:
42
             if existing_user.username == user_data["username"]:
47
             if existing_user.username == user_data["username"]:
50
                     detail="Email already registered"
55
                     detail="Email already registered"
51
                 )
56
                 )
52
         
57
         
53
-        # 创建用户
58
+        # 2. 创建用户
54
         from ..core.security import get_password_hash
59
         from ..core.security import get_password_hash
55
         hashed_password = get_password_hash(user_data["password"])
60
         hashed_password = get_password_hash(user_data["password"])
56
         
61
         
64
             is_verified=False
69
             is_verified=False
65
         )
70
         )
66
         
71
         
67
-        # 生成验证码
72
+        # 3. 生成验证码
68
         verification_code = generate_verification_code()
73
         verification_code = generate_verification_code()
69
         user.verification_code = verification_code
74
         user.verification_code = verification_code
70
-        user.verification_code_expires = datetime.utcnow() + timedelta(hours=24)
75
+        user.verification_code_expires = datetime.now(timezone.utc) + timedelta(hours=24)
71
         
76
         
77
+        # 4. 异步保存用户
72
         self.db.add(user)
78
         self.db.add(user)
73
-        self.db.commit()
74
-        self.db.refresh(user)
79
+        await self.db.commit()
80
+        await self.db.refresh(user)
75
         
81
         
76
-        # 发送验证邮件(后台任务)
82
+        # 5. 发送验证邮件(后台任务)
77
         verification_token = create_verification_token(user.email)
83
         verification_token = create_verification_token(user.email)
78
         verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
84
         verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
79
         
85
         
93
         ip_address: Optional[str] = None,
99
         ip_address: Optional[str] = None,
94
         user_agent: Optional[str] = None
100
         user_agent: Optional[str] = None
95
     ) -> Tuple[TokenResponse, User]:
101
     ) -> Tuple[TokenResponse, User]:
96
-        """用户登录"""
97
-        user = self.db.query(User).filter(
98
-            (User.username == login_data.username) | 
99
-            (User.email == login_data.username)
100
-        ).first()
102
+        """用户登录(SQLAlchemy 2.0异步版)"""
103
+        # 1. 查找用户 - 异步查询
104
+        stmt = select(User).where(
105
+            or_(
106
+                User.username == login_data.username,
107
+                User.email == login_data.username
108
+            )
109
+        )
110
+        result = await self.db.execute(stmt)
111
+        user = result.scalar_one_or_none()
101
         
112
         
102
         if not user:
113
         if not user:
103
             raise HTTPException(
114
             raise HTTPException(
105
                 detail="Incorrect username or password"
116
                 detail="Incorrect username or password"
106
             )
117
             )
107
         
118
         
108
-        # 检查账户是否被锁定
109
-        if user.is_locked and user.locked_until and user.locked_until > datetime.utcnow():
119
+        # 2. 检查账户是否被锁定
120
+        if user.is_locked and user.locked_until and user.locked_until > datetime.now(timezone.utc):
110
             raise HTTPException(
121
             raise HTTPException(
111
                 status_code=status.HTTP_423_LOCKED,
122
                 status_code=status.HTTP_423_LOCKED,
112
                 detail=f"Account is locked until {user.locked_until}"
123
                 detail=f"Account is locked until {user.locked_until}"
113
             )
124
             )
114
         
125
         
115
-        # 验证密码
126
+        # 3. 验证密码
116
         if not verify_password(login_data.password, user.hashed_password):
127
         if not verify_password(login_data.password, user.hashed_password):
117
-            # 记录失败尝试
118
-            user.failed_login_attempts += 1
128
+            # 记录失败尝试 - 异步更新
129
+            new_attempts = (user.failed_login_attempts or 0) + 1
130
+            
131
+            update_data = {
132
+                "failed_login_attempts": new_attempts,
133
+                "updated_at": datetime.now(timezone.utc)
134
+            }
119
             
135
             
120
             # 如果失败次数超过5次,锁定账户
136
             # 如果失败次数超过5次,锁定账户
121
-            if user.failed_login_attempts >= 5:
122
-                user.is_locked = True
123
-                user.locked_until = datetime.utcnow() + timedelta(minutes=30)
137
+            if new_attempts >= 5:
138
+                update_data.update({
139
+                    "is_locked" : True,
140
+                    "locked_until": datetime.now(timezone.utc) + timedelta(minutes=30)
141
+                })
124
             
142
             
125
-            self.db.commit()
143
+            update_stmt = (
144
+                update(User)
145
+                .where(User.id == user.id)
146
+                .values(**update_data)
147
+            )
148
+            
149
+            await self.db.execute(update_stmt)
150
+            await self.db.commit()
126
             
151
             
127
             raise HTTPException(
152
             raise HTTPException(
128
                 status_code=status.HTTP_401_UNAUTHORIZED,
153
                 status_code=status.HTTP_401_UNAUTHORIZED,
129
                 detail="Incorrect username or password"
154
                 detail="Incorrect username or password"
130
             )
155
             )
131
         
156
         
132
-        # 检查邮箱是否已验证
157
+        # 4. 检查邮箱是否已验证
133
         if not user.is_verified:
158
         if not user.is_verified:
134
             raise HTTPException(
159
             raise HTTPException(
135
                 status_code=status.HTTP_403_FORBIDDEN,
160
                 status_code=status.HTTP_403_FORBIDDEN,
136
                 detail="Email not verified"
161
                 detail="Email not verified"
137
             )
162
             )
138
         
163
         
139
-        # 检查账户是否激活
164
+        # 5. 检查账户是否激活
140
         if not user.is_active:
165
         if not user.is_active:
141
             raise HTTPException(
166
             raise HTTPException(
142
                 status_code=status.HTTP_403_FORBIDDEN,
167
                 status_code=status.HTTP_403_FORBIDDEN,
143
                 detail="Account is not active"
168
                 detail="Account is not active"
144
             )
169
             )
145
         
170
         
146
-        # 重置失败尝试次数
147
-        user.failed_login_attempts = 0
148
-        user.last_login = datetime.utcnow()
149
-        user.is_locked = False
150
-        user.locked_until = None
151
-        self.db.commit()
171
+        # 6. 重置失败尝试次数 - 异步更新
172
+        reset_stmt = (
173
+            update(User)
174
+            .where(User.id == user.id)
175
+            .values(
176
+                failed_login_attempts = 0,
177
+                last_login = datetime.now(timezone.utc),
178
+                is_locked = False,
179
+                locked_until = None,
180
+            )
181
+        )
182
+        
183
+        await self.db.execute(reset_stmt)
152
         
184
         
153
-        # 创建令牌
185
+        # 7. 创建令牌
154
         access_token = create_access_token({"sub": user.username, "user_id": user.id})
186
         access_token = create_access_token({"sub": user.username, "user_id": user.id})
155
         refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
187
         refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
156
         
188
         
157
-        # 保存刷新令牌到数据库
189
+        # 8. 保存刷新令牌到数据库 - 异步添加
158
         refresh_token_entry = Token(
190
         refresh_token_entry = Token(
159
             token=refresh_token,
191
             token=refresh_token,
160
             token_type="refresh",
192
             token_type="refresh",
161
-            expires_at=datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
193
+            expires_at=datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
162
             user_id=user.id,
194
             user_id=user.id,
163
             ip_address=ip_address,
195
             ip_address=ip_address,
164
             user_agent=user_agent
196
             user_agent=user_agent
165
         )
197
         )
166
         
198
         
167
         self.db.add(refresh_token_entry)
199
         self.db.add(refresh_token_entry)
168
-        self.db.commit()
169
         
200
         
170
-        # 构建响应
201
+        # 9. 提交所有更改
202
+        await self.db.commit()
203
+        
204
+        # 10. 刷新用户对象以获取最新数据
205
+        await self.db.refresh(user)
206
+        
207
+        # 11. 构建响应
171
         token_response = TokenResponse(
208
         token_response = TokenResponse(
172
             access_token=access_token,
209
             access_token=access_token,
173
             refresh_token=refresh_token,
210
             refresh_token=refresh_token,
182
         token: str,
219
         token: str,
183
         code: Optional[str] = None
220
         code: Optional[str] = None
184
     ) -> bool:
221
     ) -> bool:
185
-        """验证邮箱"""
222
+        """验证邮箱(SQLAlchemy 2.0异步版)"""
223
+        # 1. 解码令牌
186
         payload = decode_token(token)
224
         payload = decode_token(token)
187
         
225
         
188
         if not payload or payload.get("type") != "verify":
226
         if not payload or payload.get("type") != "verify":
198
                 detail="Invalid token payload"
236
                 detail="Invalid token payload"
199
             )
237
             )
200
         
238
         
201
-        user = self.db.query(User).filter(User.email == email).first()
239
+        # 2. 查找用户 - 异步查询
240
+        stmt = select(User).where(User.email == email)
241
+        result = await self.db.execute(stmt)
242
+        user = result.scalar_one_or_none()
202
         
243
         
203
         if not user:
244
         if not user:
204
             raise HTTPException(
245
             raise HTTPException(
212
                 detail="Email already verified"
253
                 detail="Email already verified"
213
             )
254
             )
214
         
255
         
215
-        # 验证码验证
256
+        # 3. 验证码验证
216
         if code:
257
         if code:
258
+            now = datetime.now(timezone.utc)
217
             if (not user.verification_code or 
259
             if (not user.verification_code or 
218
                 user.verification_code != code or
260
                 user.verification_code != code or
219
                 not user.verification_code_expires or
261
                 not user.verification_code_expires or
220
-                user.verification_code_expires < datetime.utcnow()):
262
+                user.verification_code_expires < now):
221
                 raise HTTPException(
263
                 raise HTTPException(
222
                     status_code=status.HTTP_400_BAD_REQUEST,
264
                     status_code=status.HTTP_400_BAD_REQUEST,
223
                     detail="Invalid or expired verification code"
265
                     detail="Invalid or expired verification code"
224
                 )
266
                 )
225
         
267
         
226
-        # 更新用户状态
227
-        user.is_verified = True
228
-        user.is_active = True
229
-        user.verification_code = None
230
-        user.verification_code_expires = None
231
-        self.db.commit()
268
+        # 4. 更新用户状态 - 异步更新
269
+        update_stmt = (
270
+            update(User)
271
+            .where(User.id == user.id)
272
+            .values(
273
+                is_verified = True,
274
+                is_active = True,
275
+                verification_code = None,
276
+                verification_code_expires = None,
277
+                updated_at = datetime.now(timezone.utc)
278
+            )
279
+        )
280
+        
281
+        await self.db.execute(update_stmt)
282
+        await self.db.commit()
232
         
283
         
233
-        # 发送欢迎邮件
284
+        # 5. 发送欢迎邮件(假设是异步的)
234
         await email_service.send_welcome_email(user.email, user.username)
285
         await email_service.send_welcome_email(user.email, user.username)
235
         
286
         
236
         return True
287
         return True
240
         email: str,
291
         email: str,
241
         background_tasks: BackgroundTasks
292
         background_tasks: BackgroundTasks
242
     ) -> bool:
293
     ) -> bool:
243
-        """重新发送验证邮件"""
244
-        user = self.db.query(User).filter(User.email == email).first()
294
+        """重新发送验证邮件(SQLAlchemy 2.0异步版)"""
295
+        # 1. 查找用户 - 异步查询
296
+        stmt = select(User).where(User.email == email)
297
+        result = await self.db.execute(stmt)
298
+        user = result.scalar_one_or_none()
245
         
299
         
246
         if not user:
300
         if not user:
247
             # 出于安全考虑,即使用户不存在也返回成功
301
             # 出于安全考虑,即使用户不存在也返回成功
253
                 detail="Email already verified"
307
                 detail="Email already verified"
254
             )
308
             )
255
         
309
         
256
-        # 生成新的验证码
310
+        # 2. 生成新的验证码
257
         verification_code = generate_verification_code()
311
         verification_code = generate_verification_code()
258
-        user.verification_code = verification_code
259
-        user.verification_code_expires = datetime.utcnow() + timedelta(hours=24)
260
         
312
         
261
-        self.db.commit()
313
+        # 3. 更新用户验证信息 - 异步更新
314
+        update_stmt = (
315
+            update(User)
316
+            .where(User.id == user.id)
317
+            .values(
318
+                verification_code = verification_code,
319
+                verification_code_expires = datetime.now(timezone.utc) + timedelta(hours=24),
320
+                updated_at=datetime.now(timezone.utc)
321
+            )
322
+        )
262
         
323
         
263
-        # 发送验证邮件
324
+        await self.db.execute(update_stmt)
325
+        await self.db.commit()
326
+        
327
+        # 4. 发送验证邮件
264
         verification_token = create_verification_token(user.email)
328
         verification_token = create_verification_token(user.email)
265
         verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
329
         verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
266
         
330
         
279
         email: str,
343
         email: str,
280
         background_tasks: BackgroundTasks
344
         background_tasks: BackgroundTasks
281
     ) -> bool:
345
     ) -> bool:
282
-        """请求密码重置"""
283
-        user = self.db.query(User).filter(User.email == email).first()
346
+        """请求密码重置(SQLAlchemy 2.0异步版)"""
347
+        # 1. 查找用户 - 异步查询
348
+        stmt = select(User).where(User.email == email)
349
+        result = await self.db.execute(stmt)
350
+        user = result.scalar_one_or_none()
284
         
351
         
285
         if not user:
352
         if not user:
286
             # 出于安全考虑,即使用户不存在也返回成功
353
             # 出于安全考虑,即使用户不存在也返回成功
292
                 detail="Account is not active"
359
                 detail="Account is not active"
293
             )
360
             )
294
         
361
         
295
-        # 生成重置令牌
362
+        # 2. 生成重置令牌
296
         reset_token = create_reset_token(user.email)
363
         reset_token = create_reset_token(user.email)
297
         reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"
364
         reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"
298
         
365
         
299
-        # 保存重置令牌到数据库
366
+        # 3. 保存重置令牌到数据库 - 异步添加
300
         reset_token_entry = Token(
367
         reset_token_entry = Token(
301
             token=reset_token,
368
             token=reset_token,
302
             token_type="reset",
369
             token_type="reset",
303
-            expires_at=datetime.utcnow() + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
370
+            expires_at=datetime.now(timezone.utc) + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
304
             user_id=user.id
371
             user_id=user.id
305
         )
372
         )
306
         
373
         
307
         self.db.add(reset_token_entry)
374
         self.db.add(reset_token_entry)
308
-        self.db.commit()
375
+        await self.db.commit()
309
         
376
         
310
-        # 发送重置邮件
377
+        # 4. 发送重置邮件
311
         background_tasks.add_task(
378
         background_tasks.add_task(
312
             email_service.send_password_reset_email,
379
             email_service.send_password_reset_email,
313
             user.email,
380
             user.email,
322
         token: str,
389
         token: str,
323
         new_password: str
390
         new_password: str
324
     ) -> bool:
391
     ) -> bool:
325
-        """重置密码"""
326
-        # 验证令牌
392
+        """重置密码(SQLAlchemy 2.0异步版)"""
393
+        # 1. 验证令牌
327
         payload = decode_token(token)
394
         payload = decode_token(token)
328
         
395
         
329
         if not payload or payload.get("type") != "reset":
396
         if not payload or payload.get("type") != "reset":
339
                 detail="Invalid token payload"
406
                 detail="Invalid token payload"
340
             )
407
             )
341
         
408
         
342
-        # 检查令牌是否在数据库中且未过期
343
-        token_entry = self.db.query(Token).filter(
409
+        # 2. 检查令牌是否在数据库中且未过期 - 异步查询
410
+        stmt = select(Token).where(
344
             Token.token == token,
411
             Token.token == token,
345
             Token.token_type == "reset",
412
             Token.token_type == "reset",
346
             Token.is_revoked == False,
413
             Token.is_revoked == False,
347
-            Token.expires_at > datetime.utcnow()
348
-        ).first()
414
+            Token.expires_at > datetime.now(timezone.utc)
415
+        )
416
+        result = await self.db.execute(stmt)
417
+        token_entry = result.scalar_one_or_none()
349
         
418
         
350
         if not token_entry:
419
         if not token_entry:
351
             raise HTTPException(
420
             raise HTTPException(
353
                 detail="Invalid or expired reset token"
422
                 detail="Invalid or expired reset token"
354
             )
423
             )
355
         
424
         
356
-        user = self.db.query(User).filter(User.email == email).first()
425
+        # 3. 查找用户 - 异步查询
426
+        user_stmt = select(User).where(User.email == email)
427
+        user_result = await self.db.execute(user_stmt)
428
+        user = user_result.scalar_one_or_none()
357
         
429
         
358
         if not user:
430
         if not user:
359
             raise HTTPException(
431
             raise HTTPException(
367
                 detail="Account is not active"
439
                 detail="Account is not active"
368
             )
440
             )
369
         
441
         
370
-        # 更新密码
442
+        # 4. 更新密码 - 异步更新
371
         from ..core.security import get_password_hash
443
         from ..core.security import get_password_hash
372
-        user.hashed_password = get_password_hash(new_password)
373
-        user.last_password_change = datetime.utcnow()
374
         
444
         
375
-        # 撤销所有现有令牌
376
-        self.db.query(Token).filter(
377
-            Token.user_id == user.id,
378
-            Token.token_type.in_(["access", "refresh"])
379
-        ).update({"is_revoked": True})
445
+        update_user_stmt = (
446
+            update(User)
447
+            .where(User.id == user.id)
448
+            .values(
449
+                hashed_password = get_password_hash(new_password),
450
+                last_password_change = datetime.now(timezone.utc),
451
+                updated_at = datetime.now(timezone.utc)
452
+            )
453
+        )
454
+        
455
+        # 5. 撤销所有现有令牌 - 异步更新
456
+        revoke_tokens_stmt = (
457
+            update(Token)
458
+            .where(
459
+                Token.user_id == user.id,
460
+                Token.token_type.in_(["access", "refresh"])
461
+            )
462
+            .values(is_revoked = True)
463
+        )
380
         
464
         
381
-        # 标记重置令牌为已使用
382
-        token_entry.is_revoked = True
465
+        # 6. 标记重置令牌为已使用 - 异步更新
466
+        revoke_reset_token_stmt = (
467
+            update(Token)
468
+            .where(Token.token == token)
469
+            .values(is_revoked=True)
470
+        )
383
         
471
         
384
-        self.db.commit()
472
+        # 执行所有更新
473
+        await self.db.execute(update_user_stmt)
474
+        await self.db.execute(revoke_tokens_stmt)
475
+        await self.db.execute(revoke_reset_token_stmt)
476
+        await self.db.commit()
385
         
477
         
386
         return True
478
         return True
387
     
479
     
391
         ip_address: Optional[str] = None,
483
         ip_address: Optional[str] = None,
392
         user_agent: Optional[str] = None
484
         user_agent: Optional[str] = None
393
     ) -> TokenResponse:
485
     ) -> TokenResponse:
394
-        """刷新访问令牌"""
395
-        # 验证刷新令牌
486
+        """刷新访问令牌(SQLAlchemy 2.0异步版)"""
487
+        # 1. 验证刷新令牌
396
         payload = decode_token(refresh_token)
488
         payload = decode_token(refresh_token)
397
         
489
         
398
         if not payload or payload.get("type") != "refresh":
490
         if not payload or payload.get("type") != "refresh":
401
                 detail="Invalid refresh token"
493
                 detail="Invalid refresh token"
402
             )
494
             )
403
         
495
         
404
-        # 检查令牌是否在数据库中且有效
405
-        token_entry = self.db.query(Token).filter(
496
+        # 2. 检查令牌是否在数据库中且有效 - 异步查询
497
+        token_stmt = select(Token).where(
406
             Token.token == refresh_token,
498
             Token.token == refresh_token,
407
             Token.token_type == "refresh",
499
             Token.token_type == "refresh",
408
             Token.is_revoked == False,
500
             Token.is_revoked == False,
409
-            Token.expires_at > datetime.utcnow()
410
-        ).first()
501
+            Token.expires_at > datetime.now(timezone.utc)
502
+        )
503
+        token_result = await self.db.execute(token_stmt)
504
+        token_entry = token_result.scalar_one_or_none()
411
         
505
         
412
         if not token_entry:
506
         if not token_entry:
413
             raise HTTPException(
507
             raise HTTPException(
424
                 detail="Invalid token payload"
518
                 detail="Invalid token payload"
425
             )
519
             )
426
         
520
         
427
-        user = self.db.query(User).filter(
521
+        # 3. 查找用户 - 异步查询
522
+        user_stmt = select(User).where(
428
             User.id == user_id,
523
             User.id == user_id,
429
             User.username == username,
524
             User.username == username,
430
             User.is_active == True
525
             User.is_active == True
431
-        ).first()
526
+        )
527
+        user_result = await self.db.execute(user_stmt)
528
+        user = user_result.scalar_one_or_none()
432
         
529
         
433
         if not user:
530
         if not user:
434
             raise HTTPException(
531
             raise HTTPException(
436
                 detail="User not found or inactive"
533
                 detail="User not found or inactive"
437
             )
534
             )
438
         
535
         
439
-        # 创建新的访问令牌
536
+        # 4. 创建新的访问令牌
440
         access_token = create_access_token({"sub": user.username, "user_id": user.id})
537
         access_token = create_access_token({"sub": user.username, "user_id": user.id})
441
         
538
         
442
-        # 可选:创建新的刷新令牌(刷新令牌轮换)
539
+        # 5. 创建新的刷新令牌(刷新令牌轮换)
443
         new_refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
540
         new_refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
444
         
541
         
445
-        # 保存新的刷新令牌
542
+        # 6. 保存新的刷新令牌 - 异步添加
446
         new_refresh_token_entry = Token(
543
         new_refresh_token_entry = Token(
447
             token=new_refresh_token,
544
             token=new_refresh_token,
448
-            token_type="refresh",
449
-            expires_at=datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
450
-            user_id=user.id,
451
-            ip_address=ip_address,
452
-            user_agent=user_agent
545
+            token_type = "refresh",
546
+            expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
547
+            user_id = user.id,
548
+            ip_address = ip_address,
549
+            user_agent = user_agent
453
         )
550
         )
454
         
551
         
455
-        # 标记旧令牌为已撤销
456
-        token_entry.is_revoked = True
552
+        # 7. 标记旧令牌为已撤销 - 异步更新
553
+        revoke_stmt = (
554
+            update(Token)
555
+            .where(Token.token == refresh_token)
556
+            .values(is_revoked = True)
557
+        )
457
         
558
         
559
+        # 执行更新和添加操作
560
+        await self.db.execute(revoke_stmt)
458
         self.db.add(new_refresh_token_entry)
561
         self.db.add(new_refresh_token_entry)
459
-        self.db.commit()
562
+        await self.db.commit()
460
         
563
         
461
-        # 构建响应
564
+        # 8. 构建响应
462
         return TokenResponse(
565
         return TokenResponse(
463
             access_token=access_token,
566
             access_token=access_token,
464
             refresh_token=new_refresh_token,
567
             refresh_token=new_refresh_token,
470
         self,
573
         self,
471
         refresh_token: str
574
         refresh_token: str
472
     ) -> bool:
575
     ) -> bool:
473
-        """用户登出"""
474
-        # 撤销刷新令牌
475
-        token_entry = self.db.query(Token).filter(
576
+        """用户登出(SQLAlchemy 2.0异步版)"""
577
+        # 1. 查找刷新令牌 - 异步查询
578
+        stmt = select(Token).where(
476
             Token.token == refresh_token,
579
             Token.token == refresh_token,
477
             Token.token_type == "refresh"
580
             Token.token_type == "refresh"
478
-        ).first()
581
+        )
582
+        result = await self.db.execute(stmt)
583
+        token_entry = result.scalar_one_or_none()
479
         
584
         
585
+        # 2. 如果找到令牌,则撤销它 - 异步更新
480
         if token_entry:
586
         if token_entry:
481
-            token_entry.is_revoked = True
482
-            self.db.commit()
587
+            update_stmt = (
588
+                update(Token)
589
+                .where(Token.token == refresh_token)
590
+                .values(is_revoked = True)
591
+            )
592
+            await self.db.execute(update_stmt)
593
+            await self.db.commit()
483
         
594
         
484
         return True
595
         return True
485
     
596
     
487
         self,
598
         self,
488
         user_id: int
599
         user_id: int
489
     ) -> bool:
600
     ) -> bool:
490
-        """撤销用户的所有令牌"""
491
-        self.db.query(Token).filter(
492
-            Token.user_id == user_id,
493
-            Token.token_type.in_(["access", "refresh"]),
494
-            Token.is_revoked == False
495
-        ).update({"is_revoked": True})
496
-        
497
-        self.db.commit()
498
-        return True
601
+        """撤销用户的所有令牌(SQLAlchemy 2.0异步版)"""
602
+        # 撤销用户的所有访问和刷新令牌 - 异步更新
603
+        stmt = (
604
+            update(Token)
605
+            .where(
606
+                Token.user_id == user_id,
607
+                Token.token_type.in_(["access", "refresh"]),
608
+                Token.is_revoked == False
609
+            )
610
+            .values(is_revoked = True)
611
+        )
612
+        
613
+        result = await self.db.execute(stmt)
614
+        await self.db.commit()
615
+        
616
+        return result.rowcount > 0

+ 115
- 66
app/services/user_service.py Просмотреть файл

1
 # app/services/user_service.py
1
 # app/services/user_service.py
2
-from sqlalchemy.orm import Session
3
-from sqlalchemy import or_
4
-from typing import Optional, List
5
-from datetime import datetime
2
+from sqlalchemy import select, update, delete, func, or_
3
+from sqlalchemy.orm import selectinload
4
+from sqlalchemy.ext.asyncio import AsyncSession
5
+from typing import List, Optional, Dict, Any
6
+from datetime import datetime, timezone, timedelta
6
 
7
 
7
-from ..models.user import User
8
-from ..core.security import password_hasher, password_validator
8
+from app.models.user import User
9
+from app.core.security import password_hasher, password_validator
9
 from ..schemas.user import UserCreate, UserUpdate, UserResponse
10
 from ..schemas.user import UserCreate, UserUpdate, UserResponse
10
 
11
 
11
 class UserService:
12
 class UserService:
12
     """用户服务"""
13
     """用户服务"""
13
     
14
     
14
-    def __init__(self, db: Session):
15
+    def __init__(self, db: AsyncSession):  # 异步Session
15
         self.db = db
16
         self.db = db
16
     
17
     
17
     async def create_user(self, user_data: UserCreate) -> User:
18
     async def create_user(self, user_data: UserCreate) -> User:
18
-        """创建用户"""
19
+        """创建用户(异步版本)"""
19
         # 验证密码强度
20
         # 验证密码强度
20
         is_valid, error_msg = password_validator.validate_password_strength(user_data.password)
21
         is_valid, error_msg = password_validator.validate_password_strength(user_data.password)
21
         if not is_valid:
22
         if not is_valid:
22
             raise ValueError(error_msg)
23
             raise ValueError(error_msg)
23
         
24
         
24
-        # 检查用户是否已存在
25
-        existing_user = self.db.query(User).filter(
25
+        # 检查用户是否已存在 - 使用异步查询
26
+        stmt = select(User).where(
26
             or_(
27
             or_(
27
                 User.username == user_data.username,
28
                 User.username == user_data.username,
28
                 User.email == user_data.email
29
                 User.email == user_data.email
29
             )
30
             )
30
-        ).first()
31
+        )
32
+        result = await self.db.execute(stmt)
33
+        existing_user = result.scalar_one_or_none()
31
         
34
         
32
         if existing_user:
35
         if existing_user:
33
             if existing_user.username == user_data.username:
36
             if existing_user.username == user_data.username:
38
         # 哈希密码
41
         # 哈希密码
39
         hashed_password = password_hasher.hash_password(user_data.password)
42
         hashed_password = password_hasher.hash_password(user_data.password)
40
         
43
         
41
-        # 创建用户
44
+        # 创建用户对象
42
         user = User(
45
         user = User(
43
             username=user_data.username,
46
             username=user_data.username,
44
             email=user_data.email,
47
             email=user_data.email,
48
             is_verified=False
51
             is_verified=False
49
         )
52
         )
50
         
53
         
54
+        # 异步保存
51
         self.db.add(user)
55
         self.db.add(user)
52
-        self.db.commit()
53
-        self.db.refresh(user)
56
+        await self.db.commit()
57
+        await self.db.refresh(user)
54
         
58
         
55
         return user
59
         return user
56
     
60
     
57
     async def get_user_by_id(self, user_id: int) -> Optional[User]:
61
     async def get_user_by_id(self, user_id: int) -> Optional[User]:
58
-        """通过ID获取用户"""
59
-        return self.db.query(User).filter(User.id == user_id).first()
62
+        """通过ID获取用户(异步)"""
63
+        stmt = select(User).where(User.id == user_id)
64
+        result = await self.db.execute(stmt)
65
+        return result.scalar_one_or_none()
60
     
66
     
61
     async def get_user_by_username(self, username: str) -> Optional[User]:
67
     async def get_user_by_username(self, username: str) -> Optional[User]:
62
-        """通过用户名获取用户"""
63
-        return self.db.query(User).filter(User.username == username).first()
68
+        """通过用户名获取用户(异步)"""
69
+        stmt = select(User).where(User.username == username)
70
+        result = await self.db.execute(stmt)
71
+        return result.scalar_one_or_none()
64
     
72
     
65
     async def get_user_by_email(self, email: str) -> Optional[User]:
73
     async def get_user_by_email(self, email: str) -> Optional[User]:
66
-        """通过邮箱获取用户"""
67
-        return self.db.query(User).filter(User.email == email).first()
68
-    
74
+        """通过邮箱获取用户(异步)"""
75
+        stmt = select(User).where(User.email == email)
76
+        result = await self.db.execute(stmt)
77
+        return result.scalar_one_or_none()
78
+
69
     async def authenticate_user(self, username: str, password: str) -> Optional[User]:
79
     async def authenticate_user(self, username: str, password: str) -> Optional[User]:
70
-        """验证用户"""
71
-        # 通过用户名或邮箱查找用户
72
-        user = self.db.query(User).filter(
80
+        """验证用户(SQLAlchemy 2.0异步版)"""
81
+        # 通过用户名或邮箱查找用户 - 使用异步select
82
+        stmt = select(User).where(
73
             or_(
83
             or_(
74
                 User.username == username,
84
                 User.username == username,
75
                 User.email == username
85
                 User.email == username
76
             )
86
             )
77
-        ).first()
87
+        )
88
+        result = await self.db.execute(stmt)
89
+        user = result.scalar_one_or_none()
78
         
90
         
79
         if not user:
91
         if not user:
80
             return None
92
             return None
81
         
93
         
94
+        # 验证密码
82
         if not password_hasher.verify_password(password, user.hashed_password):
95
         if not password_hasher.verify_password(password, user.hashed_password):
83
-            # 记录失败尝试
84
-            user.failed_login_attempts += 1
85
-            if user.failed_login_attempts >= 5:
86
-                user.is_locked = True
87
-            self.db.commit()
96
+            # 记录失败尝试 - 使用异步更新
97
+            new_attempts = (user.failed_login_attempts or 0) + 1
98
+            update_stmt = (
99
+                update(User)
100
+                .where(User.id == user.id)
101
+                .values(
102
+                    failed_login_attempts=new_attempts,
103
+                    is_locked=(new_attempts >= 5)  # 5次失败后锁定
104
+                )
105
+            )
106
+            await self.db.execute(update_stmt)
107
+            await self.db.commit()
88
             return None
108
             return None
89
         
109
         
90
-        # 重置失败尝试
91
-        user.failed_login_attempts = 0
92
-        user.last_login = datetime.now()
93
-        user.is_locked = False
94
-        self.db.commit()
110
+        # 登录成功 - 重置失败尝试并更新最后登录时间
111
+        update_stmt = (
112
+            update(User)
113
+            .where(User.id == user.id)
114
+            .values(
115
+                failed_login_attempts = 0,
116
+                last_login = datetime.now(timezone.utc),
117
+                is_locked = False
118
+            )
119
+        )
95
         
120
         
96
-        return user
97
-    
121
+        await self.db.execute(update_stmt)
122
+        await self.db.commit()
123
+        
124
+        # 4. 刷新用户对象以获取最新数据
125
+        await self.db.refresh(user)
126
+
127
+        return user 
128
+
129
+    # ==================== 用户操作 ====================
98
     async def update_user(self, user_id: int, update_data: dict) -> Optional[User]:
130
     async def update_user(self, user_id: int, update_data: dict) -> Optional[User]:
99
-        """更新用户信息"""
131
+        """更新用户信息(异步)"""
100
         user = await self.get_user_by_id(user_id)
132
         user = await self.get_user_by_id(user_id)
101
         if not user:
133
         if not user:
102
             return None
134
             return None
105
             if hasattr(user, key) and value is not None:
137
             if hasattr(user, key) and value is not None:
106
                 setattr(user, key, value)
138
                 setattr(user, key, value)
107
         
139
         
108
-        user.updated_at = datetime.now()
109
-        self.db.commit()
110
-        self.db.refresh(user)
140
+        user.updated_at = datetime.now(timezone.utc)
141
+        await self.db.commit()
142
+        await self.db.refresh(user)
111
         
143
         
112
         return user
144
         return user
113
     
145
     
114
     async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
146
     async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
115
-        """修改密码"""
147
+        """修改用户密码(异步)"""
148
+        # 获取用户
116
         user = await self.get_user_by_id(user_id)
149
         user = await self.get_user_by_id(user_id)
117
         if not user:
150
         if not user:
118
             return False
151
             return False
128
         
161
         
129
         # 更新密码
162
         # 更新密码
130
         user.hashed_password = password_hasher.hash_password(new_password)
163
         user.hashed_password = password_hasher.hash_password(new_password)
131
-        user.last_password_change = datetime.now()
132
-        self.db.commit()
164
+        user.last_password_change = datetime.now(timezone.utc)
165
+        await self.db.commit()
133
         
166
         
134
         return True
167
         return True
135
     
168
     
136
-    async def list_users(
137
-        self, 
138
-        skip: int = 0, 
139
-        limit: int = 100,
140
-        active_only: bool = True
141
-    ) -> List[User]:
142
-        """列出用户"""
143
-        query = self.db.query(User)
169
+    # ==================== 用户列表 ====================
170
+    async def list_users(self, skip: int = 0, limit: int = 100, active_only: bool = True) -> List[User]:
171
+        """获取用户列表(异步)"""
172
+        stmt = select(User)
144
         
173
         
145
         if active_only:
174
         if active_only:
146
-            query = query.filter(User.is_active == True)
175
+            stmt = stmt.where(User.is_active == True)
176
+        
177
+        stmt = stmt.offset(skip).limit(limit)
147
         
178
         
148
-        return query.offset(skip).limit(limit).all()
179
+        result = await self.db.execute(stmt)
180
+        users = result.scalars().all()
181
+        return users
149
     
182
     
150
     async def delete_user(self, user_id: int) -> bool:
183
     async def delete_user(self, user_id: int) -> bool:
151
         """删除用户(软删除)"""
184
         """删除用户(软删除)"""
152
-        user = await self.get_user_by_id(user_id)
153
-        if not user:
154
-            return False
185
+        """删除/禁用用户(异步)"""
186
+        # 通常我们不会真正删除,而是标记为禁用
187
+        stmt = (
188
+            update(User)
189
+            .where(User.id == user_id)
190
+            .values(is_active = False, updated_at = datetime.now(timezone.utc))
191
+        )
155
         
192
         
156
-        user.is_active = False
157
-        user.updated_at = datetime.now()
158
-        self.db.commit()
193
+        result = await self.db.execute(stmt)
194
+        await self.db.commit()
159
         
195
         
160
-        return True
196
+        # 检查是否影响了一行
197
+        return result.rowcount > 0
161
     
198
     
162
     async def count_users(self, active_only: bool = True) -> int:
199
     async def count_users(self, active_only: bool = True) -> int:
163
-        """统计用户数量"""
164
-        query = self.db.query(User)
165
-        
200
+        """统计用户数量(异步)"""
166
         if active_only:
201
         if active_only:
167
-            query = query.filter(User.is_active == True)
202
+            stmt = select(func.count()).where(User.is_active == True)
203
+        else:
204
+            stmt = select(func.count())
205
+        
206
+        result = await self.db.execute(stmt)
207
+        return result.scalar() or 0
208
+    
209
+    async def update_last_login(self, user_id: int) -> None:
210
+        """更新最后登录时间(异步)"""
211
+        stmt = (
212
+            update(User)
213
+            .where(User.id == user_id)
214
+            .values(last_login=datetime.now(timezone.utc))
215
+        )
168
         
216
         
169
-        return query.count()
217
+        await self.db.execute(stmt)
218
+        await self.db.commit()