Pārlūkot izejas kodu

修改接口调用数据库操作实现异步

root 1 mēnesi atpakaļ
vecāks
revīzija
41def1c900

+ 15
- 0
.gitignore Parādīt failu

@@ -0,0 +1,15 @@
1
+# Python缓存文件
2
+# __pycache__/
3
+# *.py[cod]
4
+# *$py.class
5
+# *.so
6
+
7
+# 特定目录下的缓存文件(可选)
8
+app/__pycache__
9
+app/*/__pycache__/
10
+app/*/*/__pycache__/
11
+
12
+# 或者更通用的匹配所有__pycache__
13
+# **/__pycache__/
14
+
15
+*.db

+ 5
- 6
app/api/v1/admin.py Parādīt failu

@@ -1,19 +1,18 @@
1 1
 # app/api/admin.py
2 2
 from fastapi import APIRouter, Depends, HTTPException
3 3
 from fastapi.responses import JSONResponse
4
-from sqlalchemy.orm import Session
4
+from sqlalchemy.ext.asyncio import AsyncSession
5 5
 from sqlalchemy import text
6
-import json
7 6
 
8 7
 from ..dependencies.auth import get_current_user
9 8
 from ..models.user import User
10
-from ..database import get_db
9
+from ..database import get_async_db
11 10
 
12 11
 router = APIRouter(prefix="/admin", tags=["admin"])
13 12
 
14 13
 @router.get("/database/tables")
15 14
 async def get_database_tables(
16
-    db: Session = Depends(get_db),
15
+    db: AsyncSession = Depends(get_async_db),
17 16
     current_user: User = Depends(get_current_user)
18 17
 ):
19 18
     """获取所有表(需要管理员权限)"""
@@ -47,7 +46,7 @@ async def get_database_tables(
47 46
 async def get_table_data(
48 47
     table_name: str,
49 48
     limit: int = 100,
50
-    db: Session = Depends(get_db),
49
+    db: AsyncSession = Depends(get_async_db),
51 50
     current_user: User = Depends(get_current_user)
52 51
 ):
53 52
     """获取表数据(需要管理员权限)"""
@@ -111,7 +110,7 @@ async def get_table_data(
111 110
 @router.post("/database/query")
112 111
 async def execute_custom_query(
113 112
     query: str,
114
-    db: Session = Depends(get_db),
113
+    db: AsyncSession = Depends(get_async_db),
115 114
     current_user: User = Depends(get_current_user)
116 115
 ):
117 116
     """执行自定义查询(需要管理员权限,生产环境请谨慎)"""

+ 8
- 9
app/api/v1/auth.py Parādīt failu

@@ -1,8 +1,7 @@
1 1
 # app/api/v1/auth.py
2 2
 from fastapi import APIRouter, HTTPException, status, Request, Depends
3 3
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4
-from sqlalchemy.orm import Session
5
-from typing import Optional
4
+from sqlalchemy.ext.asyncio import AsyncSession
6 5
 
7 6
 from app.schemas.user import UserCreate, UserLogin
8 7
 from app.schemas.token import TokenResponse, RefreshTokenRequest
@@ -12,7 +11,7 @@ from app.core.security import (
12 11
     verify_access_token
13 12
 )
14 13
 from app.config import settings
15
-from app.database import get_db
14
+from app.database import get_async_db
16 15
 from app.models.user import User
17 16
 from app.services.user_service import UserService
18 17
 
@@ -32,7 +31,7 @@ logger = logging.getLogger(__name__)
32 31
 async def register(
33 32
     user_data: UserCreate,
34 33
     request: Request,
35
-    db: Session = Depends(get_db)
34
+    db: AsyncSession = Depends(get_async_db)
36 35
 ):
37 36
     """用户注册"""
38 37
     user_service = UserService(db)
@@ -98,7 +97,7 @@ async def register(
98 97
 async def login(
99 98
     login_data: UserLogin,
100 99
     request: Request,
101
-    db: Session = Depends(get_db)
100
+    db: AsyncSession = Depends(get_async_db)
102 101
 ):
103 102
     """用户登录"""
104 103
     user_service = UserService(db)
@@ -174,7 +173,7 @@ async def login(
174 173
 @router.get("/me")
175 174
 async def get_current_user(
176 175
     credentials: HTTPAuthorizationCredentials = Depends(security),
177
-    db: Session = Depends(get_db)
176
+    db: AsyncSession = Depends(get_async_db)
178 177
 ):
179 178
     """获取当前用户信息"""
180 179
     token = credentials.credentials
@@ -221,7 +220,7 @@ async def get_current_user(
221 220
 @router.post("/refresh")
222 221
 async def refresh_token(
223 222
     refresh_data: RefreshTokenRequest,
224
-    db: Session = Depends(get_db)
223
+    db: AsyncSession = Depends(get_async_db)
225 224
 ):
226 225
     """刷新访问令牌"""
227 226
     refresh_token = refresh_data.refresh_token
@@ -277,7 +276,7 @@ async def refresh_token(
277 276
 @router.post("/logout")
278 277
 async def logout(
279 278
     refresh_data: RefreshTokenRequest,
280
-    db: Session = Depends(get_db),
279
+    db: AsyncSession = Depends(get_async_db),
281 280
     credentials: HTTPAuthorizationCredentials = Depends(security)
282 281
 ):
283 282
     """用户登出"""
@@ -286,7 +285,7 @@ async def logout(
286 285
     return {"message": "登出成功"}
287 286
 
288 287
 @router.get("/test-db")
289
-async def test_database(db: Session = Depends(get_db)):
288
+async def test_database(db: AsyncSession = Depends(get_async_db)):
290 289
     """测试数据库连接和用户统计"""
291 290
     user_service = UserService(db)
292 291
     

+ 47
- 34
app/api/v1/users.py Parādīt failu

@@ -1,11 +1,11 @@
1 1
 # app/api/v1/users.py
2 2
 from fastapi import APIRouter, Depends, HTTPException, status, Query
3
-from sqlalchemy.orm import Session
3
+from sqlalchemy.ext.asyncio import AsyncSession
4 4
 from typing import List, Optional
5
-from datetime import datetime
5
+from datetime import datetime, timezone, timedelta
6 6
 
7 7
 from app.schemas.user import UserProfile, UserUpdate, PasswordChange
8
-from app.database import get_db
8
+from app.database import get_async_db  # 使用异步依赖
9 9
 from app.services.user_service import UserService
10 10
 from app.core.security import password_validator
11 11
 
@@ -17,11 +17,13 @@ from app.config import settings
17 17
 router = APIRouter(prefix="/users", tags=["用户管理"])
18 18
 security = HTTPBearer()
19 19
 
20
-def get_current_user(
20
+
21
+# ==================== 异步认证依赖 ====================
22
+async def get_current_user(
21 23
     credentials: HTTPAuthorizationCredentials = Depends(security),
22
-    db: Session = Depends(get_db)
24
+    db: AsyncSession = Depends(get_async_db)  # 异步Session
23 25
 ):
24
-    """获取当前用户依赖"""
26
+    """获取当前用户依赖(异步版本)"""
25 27
     token = credentials.credentials
26 28
     
27 29
     # 验证令牌
@@ -33,8 +35,14 @@ def get_current_user(
33 35
         )
34 36
     
35 37
     username = payload.get("sub")
38
+    if not username:
39
+        raise HTTPException(
40
+            status_code=status.HTTP_401_UNAUTHORIZED,
41
+            detail="令牌无效"
42
+        )
43
+    
36 44
     user_service = UserService(db)
37
-    user = user_service.get_user_by_username(username)
45
+    user = await user_service.get_user_by_username(username)  # 使用await
38 46
     
39 47
     if not user:
40 48
         raise HTTPException(
@@ -50,38 +58,33 @@ def get_current_user(
50 58
     
51 59
     return user
52 60
 
61
+
62
+# ==================== 用户个人资料相关 ====================
53 63
 @router.get("/me", response_model=UserProfile)
54 64
 async def get_my_profile(
55 65
     current_user = Depends(get_current_user),
56
-    db: Session = Depends(get_db)
66
+    db: AsyncSession = Depends(get_async_db)  # 异步Session
57 67
 ):
58 68
     """获取我的资料"""
59
-    user_service = UserService(db)
60
-    user = await user_service.get_user_by_id(current_user.id)
61
-    
62
-    if not user:
63
-        raise HTTPException(
64
-            status_code=status.HTTP_404_NOT_FOUND,
65
-            detail="用户不存在"
66
-        )
67
-    
69
+    # 直接从current_user获取,无需再查询数据库
68 70
     return UserProfile(
69
-        id=user.id,
70
-        username=user.username,
71
-        email=user.email,
72
-        full_name=user.full_name,
73
-        is_active=user.is_active,
74
-        is_verified=user.is_verified,
75
-        created_at=user.created_at,
76
-        last_login=user.last_login,
77
-        avatar=user.avatar
71
+        id=current_user.id,
72
+        username=current_user.username,
73
+        email=current_user.email,
74
+        full_name=current_user.full_name,
75
+        is_active=current_user.is_active,
76
+        is_verified=current_user.is_verified,
77
+        created_at=current_user.created_at,
78
+        last_login=current_user.last_login,
79
+        avatar=current_user.avatar
78 80
     )
79 81
 
82
+
80 83
 @router.put("/me", response_model=UserProfile)
81 84
 async def update_my_profile(
82 85
     user_data: UserUpdate,
83 86
     current_user = Depends(get_current_user),
84
-    db: Session = Depends(get_db)
87
+    db: AsyncSession = Depends(get_async_db)
85 88
 ):
86 89
     """更新我的资料"""
87 90
     user_service = UserService(db)
@@ -97,6 +100,7 @@ async def update_my_profile(
97 100
                 detail="邮箱已被使用"
98 101
             )
99 102
     
103
+    # 更新用户信息
100 104
     updated_user = await user_service.update_user(current_user.id, update_dict)
101 105
     
102 106
     if not updated_user:
@@ -117,11 +121,12 @@ async def update_my_profile(
117 121
         avatar=updated_user.avatar
118 122
     )
119 123
 
124
+
120 125
 @router.post("/me/change-password")
121 126
 async def change_password(
122 127
     password_data: PasswordChange,
123 128
     current_user = Depends(get_current_user),
124
-    db: Session = Depends(get_db)
129
+    db: AsyncSession = Depends(get_async_db)
125 130
 ):
126 131
     """修改密码"""
127 132
     user_service = UserService(db)
@@ -147,13 +152,15 @@ async def change_password(
147 152
             detail=str(e)
148 153
         )
149 154
 
155
+
156
+# ==================== 管理员操作 ====================
150 157
 @router.get("/", response_model=List[UserProfile])
151 158
 async def list_users(
152 159
     skip: int = Query(0, ge=0),
153 160
     limit: int = Query(100, ge=1, le=100),
154 161
     active_only: bool = Query(True),
155 162
     current_user = Depends(get_current_user),
156
-    db: Session = Depends(get_db)
163
+    db: AsyncSession = Depends(get_async_db)
157 164
 ):
158 165
     """获取用户列表(需要管理员权限)"""
159 166
     if not current_user.is_superuser:
@@ -180,13 +187,14 @@ async def list_users(
180 187
         for user in users
181 188
     ]
182 189
 
190
+
183 191
 @router.get("/{user_id}", response_model=UserProfile)
184 192
 async def get_user(
185 193
     user_id: int,
186 194
     current_user = Depends(get_current_user),
187
-    db: Session = Depends(get_db)
195
+    db: AsyncSession = Depends(get_async_db)
188 196
 ):
189
-    """获取单个用户信息(需要管理员权限查看其他用户)"""
197
+    """获取单个用户信息"""
190 198
     user_service = UserService(db)
191 199
     
192 200
     # 普通用户只能查看自己的信息
@@ -196,6 +204,7 @@ async def get_user(
196 204
             detail="只能查看自己的用户信息"
197 205
         )
198 206
     
207
+    # 管理员可以查看任何人,普通用户只能查看自己(前面已校验)
199 208
     user = await user_service.get_user_by_id(user_id)
200 209
     
201 210
     if not user:
@@ -216,11 +225,12 @@ async def get_user(
216 225
         avatar=user.avatar
217 226
     )
218 227
 
228
+
219 229
 @router.delete("/{user_id}")
220 230
 async def delete_user(
221 231
     user_id: int,
222 232
     current_user = Depends(get_current_user),
223
-    db: Session = Depends(get_db)
233
+    db: AsyncSession = Depends(get_async_db)
224 234
 ):
225 235
     """删除用户(需要管理员权限)"""
226 236
     if not current_user.is_superuser:
@@ -247,11 +257,12 @@ async def delete_user(
247 257
     
248 258
     return {"message": "用户已禁用"}
249 259
 
260
+
250 261
 @router.get("/test")
251 262
 async def test_users():
252 263
     """测试用户管理 API"""
253 264
     return {
254
-        "message": "用户管理 API 正常运行(使用 SQLite 数据库)",
265
+        "message": "用户管理 API 正常运行(使用 SQLite 数据库异步版本)",
255 266
         "endpoints": {
256 267
             "GET /me": "获取我的资料",
257 268
             "PUT /me": "更新我的资料",
@@ -259,5 +270,7 @@ async def test_users():
259 270
             "GET /": "获取用户列表(管理员)",
260 271
             "GET /{user_id}": "获取单个用户",
261 272
             "DELETE /{user_id}": "删除用户(管理员)"
262
-        }
273
+        },
274
+        "async": True,
275
+        "database": "异步SQLAlchemy 2.0"
263 276
     }

+ 15
- 4
app/config.py Parādīt failu

@@ -11,7 +11,10 @@ class Settings:
11 11
     # 项目配置
12 12
     PROJECT_NAME: str = os.getenv("PROJECT_NAME", "CaiYouHui 采油会")
13 13
     VERSION: str = os.getenv("VERSION", "1.0.0")
14
+
15
+    # API配置
14 16
     API_V1_PREFIX: str = os.getenv("API_V1_PREFIX", "/api/v1")
17
+    APP_RUN_PORT: int = int(os.getenv("APP_RUN_PORT", "10003"))
15 18
     
16 19
     # 安全配置
17 20
     SECRET_KEY: str = os.getenv("SECRET_KEY", secrets.token_urlsafe(32))
@@ -19,8 +22,8 @@ class Settings:
19 22
     ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440"))  # 24小时
20 23
     
21 24
     # 数据库配置 - SQLite
22
-    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite:///./caiyouhui.db")
23
-    
25
+    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./caiyouhui.db")
26
+
24 27
     # CORS 配置
25 28
     BACKEND_CORS_ORIGINS: List[str] = os.getenv(
26 29
         "BACKEND_CORS_ORIGINS", 
@@ -29,11 +32,19 @@ class Settings:
29 32
     
30 33
     # 调试模式
31 34
     DEBUG: bool = os.getenv("DEBUG", "True").lower() == "true"
32
-    
35
+
36
+    # 日志配置
37
+    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
38
+    LOG_FILE: str = os.getenv("LOG_FILE", "logs/app.log")
39
+
33 40
     # 文件上传
34 41
     UPLOAD_DIR: str = os.getenv("UPLOAD_DIR", "./uploads")
35 42
     MAX_UPLOAD_SIZE: int = int(os.getenv("MAX_UPLOAD_SIZE", "10485760"))  # 10MB
36
-    
43
+
44
+    # 性能配置
45
+    DATABASE_POOL_SIZE: int = 20
46
+    DATABASE_MAX_OVERFLOW: int = 10
47
+
37 48
     # 邮箱验证(可选,后续添加)
38 49
     SMTP_ENABLED: bool = os.getenv("SMTP_ENABLED", "False").lower() == "true"
39 50
     SMTP_HOST: str = os.getenv("SMTP_HOST", "")

+ 6
- 6
app/core/auth.py Parādīt failu

@@ -1,6 +1,6 @@
1 1
 from passlib.context import CryptContext
2 2
 from jose import JWTError, jwt
3
-from datetime import datetime, timedelta
3
+from datetime import datetime, timezone, timedelta
4 4
 from typing import Optional, Dict, Any, Union
5 5
 import secrets
6 6
 import string
@@ -23,9 +23,9 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
23 23
     to_encode = data.copy()
24 24
     
25 25
     if expires_delta:
26
-        expire = datetime.utcnow() + expires_delta
26
+        expire = datetime.now(timezone.utc) + expires_delta
27 27
     else:
28
-        expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
28
+        expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
29 29
     
30 30
     to_encode.update({"exp": expire, "type": "access"})
31 31
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
@@ -34,7 +34,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -
34 34
 def create_refresh_token(data: dict) -> str:
35 35
     """创建刷新令牌"""
36 36
     to_encode = data.copy()
37
-    expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
37
+    expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
38 38
     
39 39
     to_encode.update({"exp": expire, "type": "refresh"})
40 40
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
@@ -43,7 +43,7 @@ def create_refresh_token(data: dict) -> str:
43 43
 def create_verification_token(email: str) -> str:
44 44
     """创建验证令牌"""
45 45
     to_encode = {"email": email, "type": "verify"}
46
-    expire = datetime.utcnow() + timedelta(hours=settings.VERIFICATION_TOKEN_EXPIRE_HOURS)
46
+    expire = datetime.now(timezone.utc) + timedelta(hours=settings.VERIFICATION_TOKEN_EXPIRE_HOURS)
47 47
     
48 48
     to_encode.update({"exp": expire})
49 49
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
@@ -52,7 +52,7 @@ def create_verification_token(email: str) -> str:
52 52
 def create_reset_token(email: str) -> str:
53 53
     """创建密码重置令牌"""
54 54
     to_encode = {"email": email, "type": "reset"}
55
-    expire = datetime.utcnow() + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES)
55
+    expire = datetime.now(timezone.utc) + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES)
56 56
     
57 57
     to_encode.update({"exp": expire})
58 58
     encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

+ 5
- 5
app/core/security.py Parādīt failu

@@ -6,7 +6,7 @@ import hmac
6 6
 import time
7 7
 import json
8 8
 from typing import Optional, Dict, Any, Tuple
9
-from datetime import datetime, timedelta
9
+from datetime import datetime, timezone, timedelta
10 10
 import re
11 11
 
12 12
 # 密码哈希工具
@@ -165,16 +165,16 @@ class JWTManager:
165 165
         
166 166
         # 设置过期时间
167 167
         if expires_delta:
168
-            expire = datetime.utcnow() + expires_delta
168
+            expire = datetime.now(timezone.utc) + expires_delta
169 169
         elif expires_minutes:
170
-            expire = datetime.utcnow() + timedelta(minutes=expires_minutes)
170
+            expire = datetime.now(timezone.utc) + timedelta(minutes=expires_minutes)
171 171
         else:
172
-            expire = datetime.utcnow() + timedelta(minutes=30)  # 默认30分钟
172
+            expire = datetime.now(timezone.utc) + timedelta(minutes=30)  # 默认30分钟
173 173
         
174 174
         # 添加标准声明
175 175
         payload.update({
176 176
             "exp": int(expire.timestamp()),
177
-            "iat": int(datetime.utcnow().timestamp()),
177
+            "iat": int(datetime.now(timezone.utc).timestamp()),
178 178
             "iss": "caiyouhui-api",
179 179
             "type": token_type
180 180
         })

+ 121
- 46
app/database.py Parādīt failu

@@ -1,64 +1,139 @@
1 1
 # app/database.py
2
-from sqlalchemy import create_engine
3
-from sqlalchemy.orm import sessionmaker, Session
4
-from sqlalchemy.ext.declarative import declarative_base
5
-from contextlib import contextmanager
6
-from typing import Generator
7
-import os
2
+from sqlalchemy.orm import DeclarativeBase
3
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
4
+from contextlib import asynccontextmanager
5
+from typing import AsyncGenerator
8 6
 
9 7
 from .config import settings
10 8
 
11
-# 创建数据库引擎
12
-engine = create_engine(
13
-    settings.DATABASE_URL,
14
-    connect_args={"check_same_thread": False} if "sqlite" in settings.DATABASE_URL else {},
15
-    echo=settings.DEBUG
9
+
10
+# ====================== 1. 基类定义 ======================
11
+class Base(DeclarativeBase):
12
+    """SQLAlchemy 2.0 声明式基类"""
13
+    pass
14
+
15
+
16
+# ====================== 2. 异步URL转换 ======================
17
+def ensure_async_url(database_url: str) -> str:
18
+    """确保使用异步数据库URL"""
19
+    # 如果已经是异步URL,直接返回
20
+    if any(x in database_url for x in ["+asyncpg", "+aiomysql", "+aiosqlite"]):
21
+        return database_url
22
+    
23
+    # 转换同步URL为异步URL
24
+    if database_url.startswith("postgresql://"):
25
+        return database_url.replace("postgresql://", "postgresql+asyncpg://")
26
+    elif database_url.startswith("mysql://"):
27
+        return database_url.replace("mysql://", "mysql+aiomysql://")
28
+    elif database_url.startswith("sqlite://"):
29
+        return database_url.replace("sqlite://", "sqlite+aiosqlite://")
30
+    else:
31
+        raise ValueError(f"不支持的数据库类型: {database_url}")
32
+
33
+
34
+# ====================== 3. 异步引擎配置 ======================
35
+async_database_url = ensure_async_url(settings.DATABASE_URL)
36
+
37
+# 连接参数
38
+connect_args = {}
39
+if "sqlite+aiosqlite" in async_database_url:
40
+    connect_args = {"check_same_thread": False}
41
+
42
+async_engine: AsyncEngine = create_async_engine(
43
+    async_database_url,
44
+    connect_args=connect_args,
45
+    echo=settings.DEBUG,
46
+    pool_size=20,
47
+    max_overflow=10,
48
+    pool_pre_ping=True,
49
+    future=True,  # 启用2.0特性
16 50
 )
17 51
 
18
-# 创建会话工厂
19
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
20 52
 
21
-# 声明基类
22
-Base = declarative_base()
53
+# ====================== 4. 异步会话工厂 ======================
54
+AsyncSessionLocal = async_sessionmaker(
55
+    bind=async_engine,
56
+    class_=AsyncSession,
57
+    expire_on_commit=False,
58
+    autoflush=False,
59
+)
60
+
23 61
 
24
-def get_db() -> Generator[Session, None, None]:
25
-    """数据库会话依赖注入"""
26
-    db = SessionLocal()
62
+# ====================== 5. 异步依赖注入 ======================
63
+async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
64
+    """
65
+    FastAPI兼容的异步数据库依赖
66
+    注意:移除了@asynccontextmanager装饰器
67
+    """
68
+    session = AsyncSessionLocal()
27 69
     try:
28
-        yield db
70
+        yield session
71
+        await session.commit()
72
+    except Exception:
73
+        await session.rollback()
74
+        raise
29 75
     finally:
30
-        db.close()
76
+        await session.close()
31 77
 
32
-def init_db():
33
-    """初始化数据库表"""
78
+
79
+# ====================== 6. 数据库初始化 ======================
80
+async def init_async_db():
81
+    """异步初始化数据库"""
34 82
     from .models.user import User
35
-    Base.metadata.create_all(bind=engine)
83
+    from sqlalchemy import select
84
+    from .core.security import password_hasher
85
+    
86
+    # 创建表
87
+    async with async_engine.begin() as conn:
88
+        await conn.run_sync(Base.metadata.create_all)
36 89
     print("✅ 数据库表创建完成")
37 90
     
38 91
     # 创建默认管理员用户
39
-    db = SessionLocal()
92
+    async with AsyncSessionLocal() as session:
93
+        try:
94
+            # 异步查询
95
+            stmt = select(User).where(User.username == "admin")
96
+            result = await session.execute(stmt)
97
+            admin = result.scalar_one_or_none()
98
+            
99
+            if not admin:
100
+                admin_user = User(
101
+                    username="admin",
102
+                    email="admin@caiyouhui.com",
103
+                    hashed_password=password_hasher.hash_password("Admin123!"),
104
+                    full_name="系统管理员",
105
+                    is_active=True,
106
+                    is_verified=True,
107
+                    is_superuser=True
108
+                )
109
+                session.add(admin_user)
110
+                await session.commit()
111
+                print("✅ 默认管理员用户已创建")
112
+        except Exception as e:
113
+            print(f"⚠️  创建管理员用户时出错: {e}")
114
+            await session.rollback()
115
+
116
+
117
+# ====================== 7. 连接健康检查 ======================
118
+async def check_async_connection():
119
+    """检查数据库连接"""
120
+    from sqlalchemy import text  # 需要导入text
40 121
     try:
41
-        # 检查是否已存在管理员
42
-        admin = db.query(User).filter(User.username == "admin").first()
43
-        if not admin:
44
-            from .core.security import password_hasher
45
-            admin_user = User(
46
-                username="admin",
47
-                email="admin@caiyouhui.com",
48
-                hashed_password=password_hasher.hash_password("Admin123!"),
49
-                full_name="系统管理员",
50
-                is_active=True,
51
-                is_verified=True,
52
-                is_superuser=True
53
-            )
54
-            db.add(admin_user)
55
-            db.commit()
56
-            print("✅ 默认管理员用户已创建")
122
+        async with async_engine.connect() as conn:
123
+            await conn.execute(text("SELECT 1"))
124
+        print("✅ 数据库连接正常")
125
+        return True
57 126
     except Exception as e:
58
-        print(f"⚠️  创建管理员用户时出错: {e}")
59
-        db.rollback()
60
-    finally:
61
-        db.close()
127
+        print(f"❌ 数据库连接失败: {e}")
128
+        return False
129
+
62 130
 
63
-# 导出
64
-__all__ = ["Base", "engine", "SessionLocal", "get_db", "init_db"]
131
+# ====================== 8. 导出 ======================
132
+__all__ = [
133
+    "Base",
134
+    "async_engine", 
135
+    "AsyncSessionLocal",
136
+    "get_async_db",
137
+    "init_async_db",
138
+    "check_async_connection",
139
+]

+ 3
- 3
app/dependencies/auth.py Parādīt failu

@@ -1,10 +1,10 @@
1 1
 from fastapi import Depends, HTTPException, status
2 2
 from fastapi.security import OAuth2PasswordBearer
3
-from sqlalchemy.orm import Session
3
+from sqlalchemy.ext.asyncio import AsyncSession
4 4
 from jose import JWTError, jwt
5 5
 from typing import Optional
6 6
 
7
-from ..database import get_db
7
+from ..database import get_async_db
8 8
 from ..models.user import User
9 9
 from ..schemas.token import TokenData
10 10
 from ..config import settings
@@ -16,7 +16,7 @@ oauth2_scheme = OAuth2PasswordBearer(
16 16
 
17 17
 async def get_current_user(
18 18
     token: Optional[str] = Depends(oauth2_scheme),
19
-    db: Session = Depends(get_db)
19
+    db: AsyncSession = Depends(get_async_db)
20 20
 ) -> Optional[User]:
21 21
     """获取当前用户"""
22 22
     if not token:

+ 1
- 1
app/dependencies/database.py Parādīt failu

@@ -6,7 +6,7 @@ from typing import Generator
6 6
 # ✅ 正确导入方式
7 7
 from app.database import SessionLocal
8 8
 
9
-def get_db() -> Generator[Session, None, None]:
9
+def get_async_db() -> Generator[Session, None, None]:
10 10
     """数据库会话依赖注入"""
11 11
     db = SessionLocal()
12 12
     try:

+ 1
- 1
app/logging_config.py Parādīt failu

@@ -2,7 +2,7 @@
2 2
 import logging
3 3
 import logging.handlers  # ✅ 这里导入 handlers
4 4
 import os
5
-from datetime import datetime
5
+from datetime import datetime, timezone, timedelta
6 6
 
7 7
 def setup_logging_with_rotation():
8 8
     """配置带轮转的日志系统"""

+ 406
- 46
app/main.py Parādīt failu

@@ -1,43 +1,155 @@
1
-# app/main.py - 简化版本
2
-from fastapi import FastAPI
1
+# app/main.py
2
+from contextlib import asynccontextmanager
3
+from fastapi import FastAPI, Request, status
4
+from fastapi.responses import JSONResponse
3 5
 from fastapi.middleware.cors import CORSMiddleware
4
-from app.logging_config import setup_logging_with_rotation
6
+from fastapi.exceptions import RequestValidationError
7
+from fastapi.middleware.gzip import GZipMiddleware
5 8
 import logging
6
-import os
9
+import time
10
+import traceback
11
+import sys
7 12
 
8 13
 # 配置
9
-from .config import settings
10
-
11
-# 配置日志
12
-logging.basicConfig(
13
-    level=logging.INFO,
14
-    # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
15
-    # datefmt='%Y-%m-%d %H:%M:%S',
16
-    # handlers=[
17
-    #     logging.StreamHandler(),  # 控制台
18
-    #     logging.FileHandler('app/logs/app.log', encoding='utf-8')  # 文件
19
-    # ]
20
-)
21
-# 为特定模块设置更详细的日志
22
-logging.getLogger("app.api.v1").setLevel(logging.DEBUG)
23
-logging.getLogger("app.core.security").setLevel(logging.DEBUG)
14
+from app.config import settings
24 15
 
16
+# 设置日志
25 17
 logger = logging.getLogger(__name__)
26 18
 logger.info("✅ 日志配置完成")
27 19
 
28
-# logger = setup_logging_with_rotation
29
-
20
+# ==================== 异步生命周期管理 ====================
21
+@asynccontextmanager
22
+async def lifespan(app: FastAPI):
23
+    """
24
+    应用生命周期管理器 - 完全异步
25
+    启动时初始化资源,关闭时清理资源
26
+    """
27
+    # 启动阶段
28
+    logger.info("🚀 应用启动中...")
29
+    
30
+    try:
31
+        # 1. 初始化数据库
32
+        from app.database import init_async_db, check_async_connection
33
+        if await check_async_connection():
34
+            await init_async_db()
35
+            logger.info("✅ 数据库初始化完成")
36
+        else:
37
+            logger.error("❌ 数据库连接失败")
38
+            # 可以根据需要决定是否终止启动
39
+            # raise RuntimeError("数据库连接失败")
40
+        
41
+        # 2. 初始化Redis等缓存(可选)
42
+        # try:
43
+        #     from app.core.redis import init_redis
44
+        #     await init_redis()
45
+        #     logger.info("✅ Redis连接完成")
46
+        # except ImportError:
47
+        #     logger.info("ℹ️  Redis未配置,跳过初始化")
48
+        
49
+        # 3. 初始化其他服务
50
+        # try:
51
+        #     from app.core.security import init_security
52
+        #     init_security()
53
+        #     logger.info("✅ 安全模块初始化完成")
54
+        # except Exception as e:
55
+        #     logger.warning(f"⚠️  安全模块初始化异常: {e}")
56
+        
57
+        logger.info(f"🎉 {settings.PROJECT_NAME} v{settings.VERSION} 启动完成")
58
+        yield
59
+        
60
+    except Exception as e:
61
+        logger.error(f"🔥 应用启动失败: {e}")
62
+        logger.error(traceback.format_exc())
63
+        sys.exit(1)
64
+    
65
+    finally:
66
+        # 关闭阶段
67
+        logger.info("🛑 应用关闭中...")
68
+        
69
+        # 清理资源
70
+        try:
71
+            from app.database import async_engine
72
+            await async_engine.dispose()
73
+            logger.info("✅ 数据库连接池已清理")
74
+        except Exception as e:
75
+            logger.warning(f"⚠️  数据库清理异常: {e}")
76
+        
77
+        logger.info("👋 应用已安全关闭")
30 78
 
31 79
 
32
-# 创建 FastAPI 应用
80
+# ==================== 创建FastAPI应用 ====================
33 81
 app = FastAPI(
34 82
     title=settings.PROJECT_NAME,
35 83
     version=settings.VERSION,
84
+    description=f"""
85
+    {settings.PROJECT_NAME} API 服务
86
+    
87
+    ## 功能特性
88
+    - ✅ 全异步架构(SQLAlchemy 2.0 + async/await)
89
+    - 🔒 JWT认证与授权
90
+    - 📊 性能监控中间件
91
+    - 🚀 自动API文档
92
+    
93
+    ## 环境
94
+    - **调试模式**: {'开启' if settings.DEBUG else '关闭'}
95
+    - **数据库**: {settings.DATABASE_URL.split('://')[0] if '://' in settings.DATABASE_URL else '未知'}
96
+    """,
36 97
     docs_url="/docs" if settings.DEBUG else None,
37
-    redoc_url="/redoc" if settings.DEBUG else None
98
+    redoc_url="/redoc" if settings.DEBUG else None,
99
+    openapi_url="/openapi.json" if settings.DEBUG else None,
100
+    lifespan=lifespan,  # 异步生命周期管理
101
+    swagger_ui_parameters={
102
+        "persistAuthorization": True,
103
+        "displayRequestDuration": True,
104
+        "filter": True,
105
+    }
38 106
 )
39 107
 
40
-# 配置CORS
108
+# ==================== 全局异常处理器 ====================
109
+@app.exception_handler(Exception)
110
+async def global_exception_handler(request: Request, exc: Exception):
111
+    """全局异常处理器"""
112
+    logger.error(f"🔥 未捕获的异常: {exc}", exc_info=True)
113
+    
114
+    # 在调试模式下返回详细错误
115
+    if settings.DEBUG:
116
+        return JSONResponse(
117
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
118
+            content={
119
+                "detail": str(exc),
120
+                "type": exc.__class__.__name__,
121
+                "traceback": traceback.format_exc().split('\n'),
122
+                "path": request.url.path,
123
+                "method": request.method,
124
+            }
125
+        )
126
+    
127
+    # 生产环境返回简化错误
128
+    return JSONResponse(
129
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
130
+        content={
131
+            "detail": "服务器内部错误",
132
+            "request_id": request.state.get("request_id", "unknown") if hasattr(request.state, "request_id") else "unknown"
133
+        }
134
+    )
135
+
136
+
137
+@app.exception_handler(RequestValidationError)
138
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
139
+    """请求验证异常处理器"""
140
+    logger.warning(f"⚠️  请求验证失败: {exc}")
141
+    
142
+    return JSONResponse(
143
+        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
144
+        content={
145
+            "detail": exc.errors(),
146
+            "body": exc.body,
147
+        }
148
+    )
149
+
150
+
151
+# ==================== 中间件配置 ====================
152
+# 1. CORS中间件
41 153
 if settings.BACKEND_CORS_ORIGINS:
42 154
     app.add_middleware(
43 155
         CORSMiddleware,
@@ -45,39 +157,287 @@ if settings.BACKEND_CORS_ORIGINS:
45 157
         allow_credentials=True,
46 158
         allow_methods=["*"],
47 159
         allow_headers=["*"],
160
+        expose_headers=["*"],
161
+        max_age=600,  # 预检请求缓存时间(秒)
48 162
     )
163
+    logger.info(f"✅ CORS已启用,允许的源: {settings.BACKEND_CORS_ORIGINS}")
49 164
 
50
-# 导入并注册路由
51
-try:
52
-    # 导入所有路由模块
53
-    # from .api.v1.admin import router as admin_router
54
-    from .api.v1.auth import router as auth_router
55
-    from .api.v1.users import router as users_router
56
-    # from .api.v1.verify import router as verify_router
165
+# 2. 请求ID中间件(用于日志追踪)
166
+@app.middleware("http")
167
+async def add_request_id(request: Request, call_next):
168
+    """为每个请求添加唯一ID"""
169
+    import uuid
170
+    request_id = str(uuid.uuid4())[:8]
171
+    request.state.request_id = request_id
57 172
     
58
-    # 注册路由
59
-    app.include_router(auth_router, prefix=settings.API_V1_PREFIX)
60
-    app.include_router(users_router, prefix=settings.API_V1_PREFIX)
61
-    # app.include_router(verify_router, prefix=settings.API_V1_PREFIX)
173
+    response = await call_next(request)
174
+    response.headers["X-Request-ID"] = request_id
175
+    return response
176
+
177
+
178
+# 3. 性能监控中间件
179
+@app.middleware("http")
180
+async def performance_middleware(request: Request, call_next):
181
+    """记录请求处理时间"""
182
+    start_time = time.time()
183
+    
184
+    try:
185
+        response = await call_next(request)
186
+        process_time = time.time() - start_time
187
+        
188
+        # 记录慢请求
189
+        if process_time > 1.0:  # 超过1秒
190
+            logger.warning(
191
+                f"🐌 慢请求: {request.method} {request.url.path} "
192
+                f"耗时: {process_time:.3f}s "
193
+                f"状态: {response.status_code} "
194
+                f"请求ID: {getattr(request.state, 'request_id', 'unknown')}"
195
+            )
196
+        elif settings.DEBUG:
197
+            logger.debug(
198
+                f"⚡ 请求: {request.method} {request.url.path} "
199
+                f"耗时: {process_time:.3f}s"
200
+            )
201
+        
202
+        # 添加处理时间到响应头
203
+        response.headers["X-Process-Time"] = str(process_time)
204
+        return response
205
+        
206
+    except Exception as e:
207
+        process_time = time.time() - start_time
208
+        logger.error(
209
+            f"🔥 请求异常: {request.method} {request.url.path} "
210
+            f"耗时: {process_time:.3f}s "
211
+            f"错误: {e}"
212
+        )
213
+        raise
214
+
215
+
216
+# 4. GZIP压缩中间件(提升性能)
217
+app.add_middleware(
218
+    GZipMiddleware,
219
+    minimum_size=1000,  # 只压缩大于1KB的响应
220
+)
221
+
222
+
223
+# ==================== 路由注册 ====================
224
+def register_routers():
225
+    """注册所有API路由"""
226
+    routers = []
62 227
     
63
-    logger.info("✅ API 路由注册成功")
228
+    # 尝试导入并注册路由模块
229
+    try:
230
+        from app.api.v1.auth import router as auth_router
231
+        routers.append(("认证模块", auth_router))
232
+    except ImportError as e:
233
+        logger.warning(f"⚠️  认证路由导入失败: {e}")
64 234
     
65
-except ImportError as e:
66
-    logger.warning(f"⚠️  部分路由模块未找到: {e}")
235
+    try:
236
+        from app.api.v1.users import router as users_router
237
+        routers.append(("用户管理", users_router))
238
+    except ImportError as e:
239
+        logger.warning(f"⚠️  用户路由导入失败: {e}")
240
+    
241
+    # 更多路由模块...
242
+    # try:
243
+    #     from app.api.v1.admin import router as admin_router
244
+    #     routers.append(("管理后台", admin_router))
245
+    # except ImportError:
246
+    #     pass
247
+    
248
+    # 注册所有路由
249
+    for name, router in routers:
250
+        app.include_router(router, prefix=settings.API_V1_PREFIX)
251
+        logger.info(f"✅ 路由注册: {name}")
252
+    
253
+    logger.info(f"🎯 共注册 {len(routers)} 个路由模块")
254
+
67 255
 
68
-# 系统级路由
69
-@app.get("/health")
256
+# 调用路由注册
257
+register_routers()
258
+
259
+
260
+# ==================== 系统级路由 ====================
261
+@app.get("/health", tags=["系统监控"])
70 262
 async def health_check():
71
-    return {"status": "healthy", "service": settings.PROJECT_NAME}
263
+    """
264
+    健康检查端点
265
+    
266
+    返回服务健康状态,可用于K8s探针、负载均衡器健康检查等
267
+    """
268
+    from app.database import check_async_connection
269
+    
270
+    health_status = {
271
+        "status": "healthy",
272
+        "service": settings.PROJECT_NAME,
273
+        "version": settings.VERSION,
274
+        "timestamp": time.time(),
275
+        "async": True,
276
+        "debug": settings.DEBUG,
277
+    }
278
+    
279
+    # 检查数据库连接
280
+    try:
281
+        db_healthy = await check_async_connection()
282
+        health_status["database"] = "healthy" if db_healthy else "unhealthy"
283
+    except Exception as e:
284
+        health_status["database"] = f"error: {str(e)}"
285
+    
286
+    # 检查Redis连接(可选)
287
+    # try:
288
+    #     from app.core.redis import check_redis_connection
289
+    #     redis_healthy = await check_redis_connection()
290
+    #     health_status["redis"] = "healthy" if redis_healthy else "unhealthy"
291
+    # except ImportError:
292
+    #     health_status["redis"] = "not_configured"
293
+    # except Exception as e:
294
+    #     health_status["redis"] = f"error: {str(e)}"
295
+    
296
+    # 如果有不健康的服务,返回503
297
+    if "unhealthy" in health_status.values() or "error" in str(health_status.values()):
298
+        return JSONResponse(
299
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
300
+            content=health_status
301
+        )
302
+    
303
+    return health_status
72 304
 
73
-@app.get("/")
305
+
306
+@app.get("/", tags=["系统"])
74 307
 async def root():
308
+    """
309
+    根端点
310
+    
311
+    返回API基本信息和可用端点
312
+    """
313
+    # 获取已注册的路由信息
314
+    routes_info = []
315
+    for route in app.routes:
316
+        if hasattr(route, "methods"):
317
+            routes_info.append({
318
+                "path": route.path,
319
+                "methods": list(route.methods) if route.methods else [],
320
+                "name": route.name or "",
321
+            })
322
+    
75 323
     return {
76
-        "message": f"Welcome to {settings.PROJECT_NAME} API",
324
+        "message": f"欢迎使用 {settings.PROJECT_NAME} API",
77 325
         "version": settings.VERSION,
78
-        "docs": "/docs"
326
+        "description": "基于FastAPI的现代化异步API服务",
327
+        "docs": "/docs" if settings.DEBUG else None,
328
+        "redoc": "/redoc" if settings.DEBUG else None,
329
+        "health_check": "/health",
330
+        "async": True,
331
+        "database": "SQLAlchemy 2.0 Async",
332
+        "routes_count": len(routes_info),
333
+        "quick_links": {
334
+            "认证": f"{settings.API_V1_PREFIX}/auth/login",
335
+            "用户信息": f"{settings.API_V1_PREFIX}/users/me",
336
+            "API文档": "/docs" if settings.DEBUG else "未启用",
337
+        }
79 338
     }
80 339
 
81
-if __name__ == "__main__":
340
+
341
+@app.get("/metrics", tags=["系统监控"])
342
+async def metrics():
343
+    """
344
+    监控指标端点(可用于Prometheus)
345
+    
346
+    返回应用性能指标
347
+    """
348
+    import psutil
349
+    import gc
350
+    
351
+    # 内存使用
352
+    process = psutil.Process()
353
+    mem_info = process.memory_info()
354
+    
355
+    # GC统计
356
+    gc_counts = gc.get_count()
357
+    
358
+    return {
359
+        "process": {
360
+            "pid": process.pid,
361
+            "cpu_percent": process.cpu_percent(),
362
+            "memory_mb": mem_info.rss / 1024 / 1024,
363
+            "threads": process.num_threads(),
364
+        },
365
+        "gc": {
366
+            "generation0": gc_counts[0],
367
+            "generation1": gc_counts[1],
368
+            "generation2": gc_counts[2],
369
+        },
370
+        "python": {
371
+            "version": sys.version,
372
+            "implementation": sys.implementation.name,
373
+        },
374
+        "timestamp": time.time(),
375
+    }
376
+
377
+
378
+@app.get("/config", tags=["系统"])
379
+async def show_config():
380
+    """
381
+    显示当前配置(调试用)
382
+    
383
+    注意:生产环境建议禁用此端点或限制访问
384
+    """
385
+    if not settings.DEBUG:
386
+        raise HTTPException(
387
+            status_code=status.HTTP_403_FORBIDDEN,
388
+            detail="仅调试模式下可用"
389
+        )
390
+    
391
+    # 安全地显示配置(隐藏敏感信息)
392
+    config_dict = {}
393
+    for key in dir(settings):
394
+        if not key.startswith("_"):
395
+            value = getattr(settings, key)
396
+            
397
+            # 隐藏敏感信息
398
+            if any(sensitive in key.lower() for sensitive in ["secret", "key", "password", "token"]):
399
+                config_dict[key] = "***HIDDEN***"
400
+            elif "url" in key.lower() and "database" in key.lower():
401
+                # 显示数据库类型但不显示完整URL
402
+                db_type = value.split("://")[0] if "://" in value else "unknown"
403
+                config_dict[key] = f"{db_type}://***HIDDEN***"
404
+            else:
405
+                config_dict[key] = value
406
+    
407
+    return {
408
+        "config": config_dict,
409
+        "environment": settings.ENVIRONMENT,
410
+        "debug": settings.DEBUG,
411
+    }
412
+
413
+
414
+# ==================== 启动入口 ====================
415
+def run_server():
416
+    """启动服务器(开发模式)"""
82 417
     import uvicorn
83
-    uvicorn.run(app, host="0.0.0.0", port=10003, reload=settings.DEBUG)
418
+    
419
+    logger.info("=" * 50)
420
+    logger.info(f"🚀 启动 {settings.PROJECT_NAME} v{settings.VERSION}")
421
+    logger.info(f"🐛 调试模式: {settings.DEBUG}")
422
+    logger.info(f"🌐 主机: 0.0.0.0")
423
+    logger.info(f"🔌 端口: {settings.APP_RUN_PORT}")
424
+    logger.info(f"🗄️  数据库: {settings.DATABASE_URL.split('://')[0] if '://' in settings.DATABASE_URL else '未知'}")
425
+    logger.info("=" * 50)
426
+    
427
+    uvicorn.run(
428
+        app,
429
+        host="0.0.0.0",
430
+        port=settings.APP_RUN_PORT,
431
+        reload=settings.DEBUG,
432
+        log_level="info" if settings.DEBUG else "warning",
433
+        access_log=True,
434
+        use_colors=True,
435
+        # 优化性能的配置
436
+        limit_concurrency=1000,
437
+        limit_max_requests=10000,
438
+        timeout_keep_alive=5,
439
+    )
440
+
441
+
442
+if __name__ == "__main__":
443
+    run_server()

+ 2
- 2
app/models/token.py Parādīt failu

@@ -1,7 +1,7 @@
1 1
 from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey
2 2
 from sqlalchemy.orm import relationship
3 3
 from sqlalchemy.sql import func
4
-from datetime import datetime
4
+from datetime import datetime, timezone, timedelta
5 5
 from ..database import Base
6 6
 
7 7
 class Token(Base):
@@ -25,7 +25,7 @@ class Token(Base):
25 25
     created_at = Column(DateTime(timezone=True), server_default=func.now())
26 26
     
27 27
     def is_expired(self):
28
-        return datetime.utcnow() > self.expires_at
28
+        return datetime.now(timezone.utc) > self.expires_at
29 29
     
30 30
     def __repr__(self):
31 31
         return f"<Token {self.token_type} for user {self.user_id}>"

+ 5
- 1
app/models/user.py Parādīt failu

@@ -1,7 +1,9 @@
1 1
 # app/models/user.py
2 2
 from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text
3 3
 from sqlalchemy.sql import func
4
-from ..database import Base
4
+
5
+# 从database.py导入Base(确保是DeclarativeBase)
6
+from app.database import Base
5 7
 
6 8
 class User(Base):
7 9
     __tablename__ = "users"
@@ -10,6 +12,8 @@ class User(Base):
10 12
     username = Column(String(50), unique=True, index=True, nullable=False)
11 13
     email = Column(String(100), unique=True, index=True, nullable=False)
12 14
     hashed_password = Column(String(255), nullable=False)
15
+    full_name = Column(String(100))
16
+    avatar = Column(String(255))
13 17
     
14 18
     # 用户状态
15 19
     is_active = Column(Boolean, default=True)

+ 1
- 1
app/schemas/auth.py Parādīt failu

@@ -1,6 +1,6 @@
1 1
 from pydantic import BaseModel, Field
2 2
 from typing import Optional
3
-from datetime import datetime
3
+from datetime import datetime, timezone, timedelta
4 4
 from .user import UserResponse
5 5
 
6 6
 class TokenBase(BaseModel):

+ 1
- 1
app/schemas/token.py Parādīt failu

@@ -1,7 +1,7 @@
1 1
 # app/schemas/token.py
2 2
 from pydantic import BaseModel, Field
3 3
 from typing import Optional
4
-from datetime import datetime
4
+from datetime import datetime, timezone, timedelta
5 5
 
6 6
 class Token(BaseModel):
7 7
     """令牌响应模型"""

+ 1
- 1
app/schemas/user.py Parādīt failu

@@ -1,7 +1,7 @@
1 1
 # app/schemas/user.py
2 2
 from pydantic import BaseModel, EmailStr, Field, validator
3 3
 from typing import Optional
4
-from datetime import datetime
4
+from datetime import datetime, timezone, timedelta
5 5
 
6 6
 class UserBase(BaseModel):
7 7
     """用户基础模型"""

+ 242
- 124
app/services/auth_service.py Parādīt failu

@@ -1,14 +1,15 @@
1 1
 from typing import Optional, Dict, Any, Tuple
2
-from datetime import datetime, timedelta
3
-from sqlalchemy.orm import Session
2
+from datetime import datetime, timezone, timedelta
3
+from sqlalchemy import select, update, delete, func, or_
4
+from sqlalchemy.ext.asyncio import AsyncSession
4 5
 from fastapi import HTTPException, status, BackgroundTasks
5 6
 import secrets
6 7
 import string
7 8
 
8
-from ..models.user import User
9
-from ..models.token import Token
9
+from app.models.user import User
10
+from app.models.token import Token
10 11
 from ..schemas.auth import LoginRequest, TokenResponse
11
-from ..core.security import (
12
+from app.core.security import (
12 13
     verify_password,
13 14
     create_access_token,
14 15
     create_refresh_token,
@@ -17,11 +18,11 @@ from ..core.security import (
17 18
     decode_token,
18 19
     generate_verification_code
19 20
 )
20
-from ..core.email import email_service
21
+from app.core.email import email_service
21 22
 from ..config import settings
22 23
 
23 24
 class AuthService:
24
-    def __init__(self, db: Session):
25
+    def __init__(self, db: AsyncSession):
25 26
         self.db = db
26 27
     
27 28
     async def register_user(
@@ -31,12 +32,16 @@ class AuthService:
31 32
         ip_address: Optional[str] = None,
32 33
         user_agent: Optional[str] = None
33 34
     ) -> User:
34
-        """注册新用户"""
35
-        # 检查用户是否存在
36
-        existing_user = self.db.query(User).filter(
37
-            (User.username == user_data["username"]) | 
38
-            (User.email == user_data["email"])
39
-        ).first()
35
+        """注册新用户(SQLAlchemy 2.0异步版)"""
36
+        # 1. 检查用户是否存在 - 使用异步查询
37
+        stmt = select(User).where(
38
+            or_(
39
+                User.username == user_data["username"],
40
+                User.email == user_data["email"]
41
+            )
42
+        )
43
+        result = await self.db.execute(stmt)
44
+        existing_user = result.scalar_one_or_none()
40 45
         
41 46
         if existing_user:
42 47
             if existing_user.username == user_data["username"]:
@@ -50,7 +55,7 @@ class AuthService:
50 55
                     detail="Email already registered"
51 56
                 )
52 57
         
53
-        # 创建用户
58
+        # 2. 创建用户
54 59
         from ..core.security import get_password_hash
55 60
         hashed_password = get_password_hash(user_data["password"])
56 61
         
@@ -64,16 +69,17 @@ class AuthService:
64 69
             is_verified=False
65 70
         )
66 71
         
67
-        # 生成验证码
72
+        # 3. 生成验证码
68 73
         verification_code = generate_verification_code()
69 74
         user.verification_code = verification_code
70
-        user.verification_code_expires = datetime.utcnow() + timedelta(hours=24)
75
+        user.verification_code_expires = datetime.now(timezone.utc) + timedelta(hours=24)
71 76
         
77
+        # 4. 异步保存用户
72 78
         self.db.add(user)
73
-        self.db.commit()
74
-        self.db.refresh(user)
79
+        await self.db.commit()
80
+        await self.db.refresh(user)
75 81
         
76
-        # 发送验证邮件(后台任务)
82
+        # 5. 发送验证邮件(后台任务)
77 83
         verification_token = create_verification_token(user.email)
78 84
         verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
79 85
         
@@ -93,11 +99,16 @@ class AuthService:
93 99
         ip_address: Optional[str] = None,
94 100
         user_agent: Optional[str] = None
95 101
     ) -> Tuple[TokenResponse, User]:
96
-        """用户登录"""
97
-        user = self.db.query(User).filter(
98
-            (User.username == login_data.username) | 
99
-            (User.email == login_data.username)
100
-        ).first()
102
+        """用户登录(SQLAlchemy 2.0异步版)"""
103
+        # 1. 查找用户 - 异步查询
104
+        stmt = select(User).where(
105
+            or_(
106
+                User.username == login_data.username,
107
+                User.email == login_data.username
108
+            )
109
+        )
110
+        result = await self.db.execute(stmt)
111
+        user = result.scalar_one_or_none()
101 112
         
102 113
         if not user:
103 114
             raise HTTPException(
@@ -105,69 +116,95 @@ class AuthService:
105 116
                 detail="Incorrect username or password"
106 117
             )
107 118
         
108
-        # 检查账户是否被锁定
109
-        if user.is_locked and user.locked_until and user.locked_until > datetime.utcnow():
119
+        # 2. 检查账户是否被锁定
120
+        if user.is_locked and user.locked_until and user.locked_until > datetime.now(timezone.utc):
110 121
             raise HTTPException(
111 122
                 status_code=status.HTTP_423_LOCKED,
112 123
                 detail=f"Account is locked until {user.locked_until}"
113 124
             )
114 125
         
115
-        # 验证密码
126
+        # 3. 验证密码
116 127
         if not verify_password(login_data.password, user.hashed_password):
117
-            # 记录失败尝试
118
-            user.failed_login_attempts += 1
128
+            # 记录失败尝试 - 异步更新
129
+            new_attempts = (user.failed_login_attempts or 0) + 1
130
+            
131
+            update_data = {
132
+                "failed_login_attempts": new_attempts,
133
+                "updated_at": datetime.now(timezone.utc)
134
+            }
119 135
             
120 136
             # 如果失败次数超过5次,锁定账户
121
-            if user.failed_login_attempts >= 5:
122
-                user.is_locked = True
123
-                user.locked_until = datetime.utcnow() + timedelta(minutes=30)
137
+            if new_attempts >= 5:
138
+                update_data.update({
139
+                    "is_locked" : True,
140
+                    "locked_until": datetime.now(timezone.utc) + timedelta(minutes=30)
141
+                })
124 142
             
125
-            self.db.commit()
143
+            update_stmt = (
144
+                update(User)
145
+                .where(User.id == user.id)
146
+                .values(**update_data)
147
+            )
148
+            
149
+            await self.db.execute(update_stmt)
150
+            await self.db.commit()
126 151
             
127 152
             raise HTTPException(
128 153
                 status_code=status.HTTP_401_UNAUTHORIZED,
129 154
                 detail="Incorrect username or password"
130 155
             )
131 156
         
132
-        # 检查邮箱是否已验证
157
+        # 4. 检查邮箱是否已验证
133 158
         if not user.is_verified:
134 159
             raise HTTPException(
135 160
                 status_code=status.HTTP_403_FORBIDDEN,
136 161
                 detail="Email not verified"
137 162
             )
138 163
         
139
-        # 检查账户是否激活
164
+        # 5. 检查账户是否激活
140 165
         if not user.is_active:
141 166
             raise HTTPException(
142 167
                 status_code=status.HTTP_403_FORBIDDEN,
143 168
                 detail="Account is not active"
144 169
             )
145 170
         
146
-        # 重置失败尝试次数
147
-        user.failed_login_attempts = 0
148
-        user.last_login = datetime.utcnow()
149
-        user.is_locked = False
150
-        user.locked_until = None
151
-        self.db.commit()
171
+        # 6. 重置失败尝试次数 - 异步更新
172
+        reset_stmt = (
173
+            update(User)
174
+            .where(User.id == user.id)
175
+            .values(
176
+                failed_login_attempts = 0,
177
+                last_login = datetime.now(timezone.utc),
178
+                is_locked = False,
179
+                locked_until = None,
180
+            )
181
+        )
182
+        
183
+        await self.db.execute(reset_stmt)
152 184
         
153
-        # 创建令牌
185
+        # 7. 创建令牌
154 186
         access_token = create_access_token({"sub": user.username, "user_id": user.id})
155 187
         refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
156 188
         
157
-        # 保存刷新令牌到数据库
189
+        # 8. 保存刷新令牌到数据库 - 异步添加
158 190
         refresh_token_entry = Token(
159 191
             token=refresh_token,
160 192
             token_type="refresh",
161
-            expires_at=datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
193
+            expires_at=datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
162 194
             user_id=user.id,
163 195
             ip_address=ip_address,
164 196
             user_agent=user_agent
165 197
         )
166 198
         
167 199
         self.db.add(refresh_token_entry)
168
-        self.db.commit()
169 200
         
170
-        # 构建响应
201
+        # 9. 提交所有更改
202
+        await self.db.commit()
203
+        
204
+        # 10. 刷新用户对象以获取最新数据
205
+        await self.db.refresh(user)
206
+        
207
+        # 11. 构建响应
171 208
         token_response = TokenResponse(
172 209
             access_token=access_token,
173 210
             refresh_token=refresh_token,
@@ -182,7 +219,8 @@ class AuthService:
182 219
         token: str,
183 220
         code: Optional[str] = None
184 221
     ) -> bool:
185
-        """验证邮箱"""
222
+        """验证邮箱(SQLAlchemy 2.0异步版)"""
223
+        # 1. 解码令牌
186 224
         payload = decode_token(token)
187 225
         
188 226
         if not payload or payload.get("type") != "verify":
@@ -198,7 +236,10 @@ class AuthService:
198 236
                 detail="Invalid token payload"
199 237
             )
200 238
         
201
-        user = self.db.query(User).filter(User.email == email).first()
239
+        # 2. 查找用户 - 异步查询
240
+        stmt = select(User).where(User.email == email)
241
+        result = await self.db.execute(stmt)
242
+        user = result.scalar_one_or_none()
202 243
         
203 244
         if not user:
204 245
             raise HTTPException(
@@ -212,25 +253,35 @@ class AuthService:
212 253
                 detail="Email already verified"
213 254
             )
214 255
         
215
-        # 验证码验证
256
+        # 3. 验证码验证
216 257
         if code:
258
+            now = datetime.now(timezone.utc)
217 259
             if (not user.verification_code or 
218 260
                 user.verification_code != code or
219 261
                 not user.verification_code_expires or
220
-                user.verification_code_expires < datetime.utcnow()):
262
+                user.verification_code_expires < now):
221 263
                 raise HTTPException(
222 264
                     status_code=status.HTTP_400_BAD_REQUEST,
223 265
                     detail="Invalid or expired verification code"
224 266
                 )
225 267
         
226
-        # 更新用户状态
227
-        user.is_verified = True
228
-        user.is_active = True
229
-        user.verification_code = None
230
-        user.verification_code_expires = None
231
-        self.db.commit()
268
+        # 4. 更新用户状态 - 异步更新
269
+        update_stmt = (
270
+            update(User)
271
+            .where(User.id == user.id)
272
+            .values(
273
+                is_verified = True,
274
+                is_active = True,
275
+                verification_code = None,
276
+                verification_code_expires = None,
277
+                updated_at = datetime.now(timezone.utc)
278
+            )
279
+        )
280
+        
281
+        await self.db.execute(update_stmt)
282
+        await self.db.commit()
232 283
         
233
-        # 发送欢迎邮件
284
+        # 5. 发送欢迎邮件(假设是异步的)
234 285
         await email_service.send_welcome_email(user.email, user.username)
235 286
         
236 287
         return True
@@ -240,8 +291,11 @@ class AuthService:
240 291
         email: str,
241 292
         background_tasks: BackgroundTasks
242 293
     ) -> bool:
243
-        """重新发送验证邮件"""
244
-        user = self.db.query(User).filter(User.email == email).first()
294
+        """重新发送验证邮件(SQLAlchemy 2.0异步版)"""
295
+        # 1. 查找用户 - 异步查询
296
+        stmt = select(User).where(User.email == email)
297
+        result = await self.db.execute(stmt)
298
+        user = result.scalar_one_or_none()
245 299
         
246 300
         if not user:
247 301
             # 出于安全考虑,即使用户不存在也返回成功
@@ -253,14 +307,24 @@ class AuthService:
253 307
                 detail="Email already verified"
254 308
             )
255 309
         
256
-        # 生成新的验证码
310
+        # 2. 生成新的验证码
257 311
         verification_code = generate_verification_code()
258
-        user.verification_code = verification_code
259
-        user.verification_code_expires = datetime.utcnow() + timedelta(hours=24)
260 312
         
261
-        self.db.commit()
313
+        # 3. 更新用户验证信息 - 异步更新
314
+        update_stmt = (
315
+            update(User)
316
+            .where(User.id == user.id)
317
+            .values(
318
+                verification_code = verification_code,
319
+                verification_code_expires = datetime.now(timezone.utc) + timedelta(hours=24),
320
+                updated_at=datetime.now(timezone.utc)
321
+            )
322
+        )
262 323
         
263
-        # 发送验证邮件
324
+        await self.db.execute(update_stmt)
325
+        await self.db.commit()
326
+        
327
+        # 4. 发送验证邮件
264 328
         verification_token = create_verification_token(user.email)
265 329
         verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
266 330
         
@@ -279,8 +343,11 @@ class AuthService:
279 343
         email: str,
280 344
         background_tasks: BackgroundTasks
281 345
     ) -> bool:
282
-        """请求密码重置"""
283
-        user = self.db.query(User).filter(User.email == email).first()
346
+        """请求密码重置(SQLAlchemy 2.0异步版)"""
347
+        # 1. 查找用户 - 异步查询
348
+        stmt = select(User).where(User.email == email)
349
+        result = await self.db.execute(stmt)
350
+        user = result.scalar_one_or_none()
284 351
         
285 352
         if not user:
286 353
             # 出于安全考虑,即使用户不存在也返回成功
@@ -292,22 +359,22 @@ class AuthService:
292 359
                 detail="Account is not active"
293 360
             )
294 361
         
295
-        # 生成重置令牌
362
+        # 2. 生成重置令牌
296 363
         reset_token = create_reset_token(user.email)
297 364
         reset_url = f"{settings.FRONTEND_URL}/reset-password?token={reset_token}"
298 365
         
299
-        # 保存重置令牌到数据库
366
+        # 3. 保存重置令牌到数据库 - 异步添加
300 367
         reset_token_entry = Token(
301 368
             token=reset_token,
302 369
             token_type="reset",
303
-            expires_at=datetime.utcnow() + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
370
+            expires_at=datetime.now(timezone.utc) + timedelta(minutes=settings.RESET_TOKEN_EXPIRE_MINUTES),
304 371
             user_id=user.id
305 372
         )
306 373
         
307 374
         self.db.add(reset_token_entry)
308
-        self.db.commit()
375
+        await self.db.commit()
309 376
         
310
-        # 发送重置邮件
377
+        # 4. 发送重置邮件
311 378
         background_tasks.add_task(
312 379
             email_service.send_password_reset_email,
313 380
             user.email,
@@ -322,8 +389,8 @@ class AuthService:
322 389
         token: str,
323 390
         new_password: str
324 391
     ) -> bool:
325
-        """重置密码"""
326
-        # 验证令牌
392
+        """重置密码(SQLAlchemy 2.0异步版)"""
393
+        # 1. 验证令牌
327 394
         payload = decode_token(token)
328 395
         
329 396
         if not payload or payload.get("type") != "reset":
@@ -339,13 +406,15 @@ class AuthService:
339 406
                 detail="Invalid token payload"
340 407
             )
341 408
         
342
-        # 检查令牌是否在数据库中且未过期
343
-        token_entry = self.db.query(Token).filter(
409
+        # 2. 检查令牌是否在数据库中且未过期 - 异步查询
410
+        stmt = select(Token).where(
344 411
             Token.token == token,
345 412
             Token.token_type == "reset",
346 413
             Token.is_revoked == False,
347
-            Token.expires_at > datetime.utcnow()
348
-        ).first()
414
+            Token.expires_at > datetime.now(timezone.utc)
415
+        )
416
+        result = await self.db.execute(stmt)
417
+        token_entry = result.scalar_one_or_none()
349 418
         
350 419
         if not token_entry:
351 420
             raise HTTPException(
@@ -353,7 +422,10 @@ class AuthService:
353 422
                 detail="Invalid or expired reset token"
354 423
             )
355 424
         
356
-        user = self.db.query(User).filter(User.email == email).first()
425
+        # 3. 查找用户 - 异步查询
426
+        user_stmt = select(User).where(User.email == email)
427
+        user_result = await self.db.execute(user_stmt)
428
+        user = user_result.scalar_one_or_none()
357 429
         
358 430
         if not user:
359 431
             raise HTTPException(
@@ -367,21 +439,41 @@ class AuthService:
367 439
                 detail="Account is not active"
368 440
             )
369 441
         
370
-        # 更新密码
442
+        # 4. 更新密码 - 异步更新
371 443
         from ..core.security import get_password_hash
372
-        user.hashed_password = get_password_hash(new_password)
373
-        user.last_password_change = datetime.utcnow()
374 444
         
375
-        # 撤销所有现有令牌
376
-        self.db.query(Token).filter(
377
-            Token.user_id == user.id,
378
-            Token.token_type.in_(["access", "refresh"])
379
-        ).update({"is_revoked": True})
445
+        update_user_stmt = (
446
+            update(User)
447
+            .where(User.id == user.id)
448
+            .values(
449
+                hashed_password = get_password_hash(new_password),
450
+                last_password_change = datetime.now(timezone.utc),
451
+                updated_at = datetime.now(timezone.utc)
452
+            )
453
+        )
454
+        
455
+        # 5. 撤销所有现有令牌 - 异步更新
456
+        revoke_tokens_stmt = (
457
+            update(Token)
458
+            .where(
459
+                Token.user_id == user.id,
460
+                Token.token_type.in_(["access", "refresh"])
461
+            )
462
+            .values(is_revoked = True)
463
+        )
380 464
         
381
-        # 标记重置令牌为已使用
382
-        token_entry.is_revoked = True
465
+        # 6. 标记重置令牌为已使用 - 异步更新
466
+        revoke_reset_token_stmt = (
467
+            update(Token)
468
+            .where(Token.token == token)
469
+            .values(is_revoked=True)
470
+        )
383 471
         
384
-        self.db.commit()
472
+        # 执行所有更新
473
+        await self.db.execute(update_user_stmt)
474
+        await self.db.execute(revoke_tokens_stmt)
475
+        await self.db.execute(revoke_reset_token_stmt)
476
+        await self.db.commit()
385 477
         
386 478
         return True
387 479
     
@@ -391,8 +483,8 @@ class AuthService:
391 483
         ip_address: Optional[str] = None,
392 484
         user_agent: Optional[str] = None
393 485
     ) -> TokenResponse:
394
-        """刷新访问令牌"""
395
-        # 验证刷新令牌
486
+        """刷新访问令牌(SQLAlchemy 2.0异步版)"""
487
+        # 1. 验证刷新令牌
396 488
         payload = decode_token(refresh_token)
397 489
         
398 490
         if not payload or payload.get("type") != "refresh":
@@ -401,13 +493,15 @@ class AuthService:
401 493
                 detail="Invalid refresh token"
402 494
             )
403 495
         
404
-        # 检查令牌是否在数据库中且有效
405
-        token_entry = self.db.query(Token).filter(
496
+        # 2. 检查令牌是否在数据库中且有效 - 异步查询
497
+        token_stmt = select(Token).where(
406 498
             Token.token == refresh_token,
407 499
             Token.token_type == "refresh",
408 500
             Token.is_revoked == False,
409
-            Token.expires_at > datetime.utcnow()
410
-        ).first()
501
+            Token.expires_at > datetime.now(timezone.utc)
502
+        )
503
+        token_result = await self.db.execute(token_stmt)
504
+        token_entry = token_result.scalar_one_or_none()
411 505
         
412 506
         if not token_entry:
413 507
             raise HTTPException(
@@ -424,11 +518,14 @@ class AuthService:
424 518
                 detail="Invalid token payload"
425 519
             )
426 520
         
427
-        user = self.db.query(User).filter(
521
+        # 3. 查找用户 - 异步查询
522
+        user_stmt = select(User).where(
428 523
             User.id == user_id,
429 524
             User.username == username,
430 525
             User.is_active == True
431
-        ).first()
526
+        )
527
+        user_result = await self.db.execute(user_stmt)
528
+        user = user_result.scalar_one_or_none()
432 529
         
433 530
         if not user:
434 531
             raise HTTPException(
@@ -436,29 +533,35 @@ class AuthService:
436 533
                 detail="User not found or inactive"
437 534
             )
438 535
         
439
-        # 创建新的访问令牌
536
+        # 4. 创建新的访问令牌
440 537
         access_token = create_access_token({"sub": user.username, "user_id": user.id})
441 538
         
442
-        # 可选:创建新的刷新令牌(刷新令牌轮换)
539
+        # 5. 创建新的刷新令牌(刷新令牌轮换)
443 540
         new_refresh_token = create_refresh_token({"sub": user.username, "user_id": user.id})
444 541
         
445
-        # 保存新的刷新令牌
542
+        # 6. 保存新的刷新令牌 - 异步添加
446 543
         new_refresh_token_entry = Token(
447 544
             token=new_refresh_token,
448
-            token_type="refresh",
449
-            expires_at=datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
450
-            user_id=user.id,
451
-            ip_address=ip_address,
452
-            user_agent=user_agent
545
+            token_type = "refresh",
546
+            expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
547
+            user_id = user.id,
548
+            ip_address = ip_address,
549
+            user_agent = user_agent
453 550
         )
454 551
         
455
-        # 标记旧令牌为已撤销
456
-        token_entry.is_revoked = True
552
+        # 7. 标记旧令牌为已撤销 - 异步更新
553
+        revoke_stmt = (
554
+            update(Token)
555
+            .where(Token.token == refresh_token)
556
+            .values(is_revoked = True)
557
+        )
457 558
         
559
+        # 执行更新和添加操作
560
+        await self.db.execute(revoke_stmt)
458 561
         self.db.add(new_refresh_token_entry)
459
-        self.db.commit()
562
+        await self.db.commit()
460 563
         
461
-        # 构建响应
564
+        # 8. 构建响应
462 565
         return TokenResponse(
463 566
             access_token=access_token,
464 567
             refresh_token=new_refresh_token,
@@ -470,16 +573,24 @@ class AuthService:
470 573
         self,
471 574
         refresh_token: str
472 575
     ) -> bool:
473
-        """用户登出"""
474
-        # 撤销刷新令牌
475
-        token_entry = self.db.query(Token).filter(
576
+        """用户登出(SQLAlchemy 2.0异步版)"""
577
+        # 1. 查找刷新令牌 - 异步查询
578
+        stmt = select(Token).where(
476 579
             Token.token == refresh_token,
477 580
             Token.token_type == "refresh"
478
-        ).first()
581
+        )
582
+        result = await self.db.execute(stmt)
583
+        token_entry = result.scalar_one_or_none()
479 584
         
585
+        # 2. 如果找到令牌,则撤销它 - 异步更新
480 586
         if token_entry:
481
-            token_entry.is_revoked = True
482
-            self.db.commit()
587
+            update_stmt = (
588
+                update(Token)
589
+                .where(Token.token == refresh_token)
590
+                .values(is_revoked = True)
591
+            )
592
+            await self.db.execute(update_stmt)
593
+            await self.db.commit()
483 594
         
484 595
         return True
485 596
     
@@ -487,12 +598,19 @@ class AuthService:
487 598
         self,
488 599
         user_id: int
489 600
     ) -> bool:
490
-        """撤销用户的所有令牌"""
491
-        self.db.query(Token).filter(
492
-            Token.user_id == user_id,
493
-            Token.token_type.in_(["access", "refresh"]),
494
-            Token.is_revoked == False
495
-        ).update({"is_revoked": True})
496
-        
497
-        self.db.commit()
498
-        return True
601
+        """撤销用户的所有令牌(SQLAlchemy 2.0异步版)"""
602
+        # 撤销用户的所有访问和刷新令牌 - 异步更新
603
+        stmt = (
604
+            update(Token)
605
+            .where(
606
+                Token.user_id == user_id,
607
+                Token.token_type.in_(["access", "refresh"]),
608
+                Token.is_revoked == False
609
+            )
610
+            .values(is_revoked = True)
611
+        )
612
+        
613
+        result = await self.db.execute(stmt)
614
+        await self.db.commit()
615
+        
616
+        return result.rowcount > 0

+ 115
- 66
app/services/user_service.py Parādīt failu

@@ -1,33 +1,36 @@
1 1
 # app/services/user_service.py
2
-from sqlalchemy.orm import Session
3
-from sqlalchemy import or_
4
-from typing import Optional, List
5
-from datetime import datetime
2
+from sqlalchemy import select, update, delete, func, or_
3
+from sqlalchemy.orm import selectinload
4
+from sqlalchemy.ext.asyncio import AsyncSession
5
+from typing import List, Optional, Dict, Any
6
+from datetime import datetime, timezone, timedelta
6 7
 
7
-from ..models.user import User
8
-from ..core.security import password_hasher, password_validator
8
+from app.models.user import User
9
+from app.core.security import password_hasher, password_validator
9 10
 from ..schemas.user import UserCreate, UserUpdate, UserResponse
10 11
 
11 12
 class UserService:
12 13
     """用户服务"""
13 14
     
14
-    def __init__(self, db: Session):
15
+    def __init__(self, db: AsyncSession):  # 异步Session
15 16
         self.db = db
16 17
     
17 18
     async def create_user(self, user_data: UserCreate) -> User:
18
-        """创建用户"""
19
+        """创建用户(异步版本)"""
19 20
         # 验证密码强度
20 21
         is_valid, error_msg = password_validator.validate_password_strength(user_data.password)
21 22
         if not is_valid:
22 23
             raise ValueError(error_msg)
23 24
         
24
-        # 检查用户是否已存在
25
-        existing_user = self.db.query(User).filter(
25
+        # 检查用户是否已存在 - 使用异步查询
26
+        stmt = select(User).where(
26 27
             or_(
27 28
                 User.username == user_data.username,
28 29
                 User.email == user_data.email
29 30
             )
30
-        ).first()
31
+        )
32
+        result = await self.db.execute(stmt)
33
+        existing_user = result.scalar_one_or_none()
31 34
         
32 35
         if existing_user:
33 36
             if existing_user.username == user_data.username:
@@ -38,7 +41,7 @@ class UserService:
38 41
         # 哈希密码
39 42
         hashed_password = password_hasher.hash_password(user_data.password)
40 43
         
41
-        # 创建用户
44
+        # 创建用户对象
42 45
         user = User(
43 46
             username=user_data.username,
44 47
             email=user_data.email,
@@ -48,55 +51,84 @@ class UserService:
48 51
             is_verified=False
49 52
         )
50 53
         
54
+        # 异步保存
51 55
         self.db.add(user)
52
-        self.db.commit()
53
-        self.db.refresh(user)
56
+        await self.db.commit()
57
+        await self.db.refresh(user)
54 58
         
55 59
         return user
56 60
     
57 61
     async def get_user_by_id(self, user_id: int) -> Optional[User]:
58
-        """通过ID获取用户"""
59
-        return self.db.query(User).filter(User.id == user_id).first()
62
+        """通过ID获取用户(异步)"""
63
+        stmt = select(User).where(User.id == user_id)
64
+        result = await self.db.execute(stmt)
65
+        return result.scalar_one_or_none()
60 66
     
61 67
     async def get_user_by_username(self, username: str) -> Optional[User]:
62
-        """通过用户名获取用户"""
63
-        return self.db.query(User).filter(User.username == username).first()
68
+        """通过用户名获取用户(异步)"""
69
+        stmt = select(User).where(User.username == username)
70
+        result = await self.db.execute(stmt)
71
+        return result.scalar_one_or_none()
64 72
     
65 73
     async def get_user_by_email(self, email: str) -> Optional[User]:
66
-        """通过邮箱获取用户"""
67
-        return self.db.query(User).filter(User.email == email).first()
68
-    
74
+        """通过邮箱获取用户(异步)"""
75
+        stmt = select(User).where(User.email == email)
76
+        result = await self.db.execute(stmt)
77
+        return result.scalar_one_or_none()
78
+
69 79
     async def authenticate_user(self, username: str, password: str) -> Optional[User]:
70
-        """验证用户"""
71
-        # 通过用户名或邮箱查找用户
72
-        user = self.db.query(User).filter(
80
+        """验证用户(SQLAlchemy 2.0异步版)"""
81
+        # 通过用户名或邮箱查找用户 - 使用异步select
82
+        stmt = select(User).where(
73 83
             or_(
74 84
                 User.username == username,
75 85
                 User.email == username
76 86
             )
77
-        ).first()
87
+        )
88
+        result = await self.db.execute(stmt)
89
+        user = result.scalar_one_or_none()
78 90
         
79 91
         if not user:
80 92
             return None
81 93
         
94
+        # 验证密码
82 95
         if not password_hasher.verify_password(password, user.hashed_password):
83
-            # 记录失败尝试
84
-            user.failed_login_attempts += 1
85
-            if user.failed_login_attempts >= 5:
86
-                user.is_locked = True
87
-            self.db.commit()
96
+            # 记录失败尝试 - 使用异步更新
97
+            new_attempts = (user.failed_login_attempts or 0) + 1
98
+            update_stmt = (
99
+                update(User)
100
+                .where(User.id == user.id)
101
+                .values(
102
+                    failed_login_attempts=new_attempts,
103
+                    is_locked=(new_attempts >= 5)  # 5次失败后锁定
104
+                )
105
+            )
106
+            await self.db.execute(update_stmt)
107
+            await self.db.commit()
88 108
             return None
89 109
         
90
-        # 重置失败尝试
91
-        user.failed_login_attempts = 0
92
-        user.last_login = datetime.now()
93
-        user.is_locked = False
94
-        self.db.commit()
110
+        # 登录成功 - 重置失败尝试并更新最后登录时间
111
+        update_stmt = (
112
+            update(User)
113
+            .where(User.id == user.id)
114
+            .values(
115
+                failed_login_attempts = 0,
116
+                last_login = datetime.now(timezone.utc),
117
+                is_locked = False
118
+            )
119
+        )
95 120
         
96
-        return user
97
-    
121
+        await self.db.execute(update_stmt)
122
+        await self.db.commit()
123
+        
124
+        # 4. 刷新用户对象以获取最新数据
125
+        await self.db.refresh(user)
126
+
127
+        return user 
128
+
129
+    # ==================== 用户操作 ====================
98 130
     async def update_user(self, user_id: int, update_data: dict) -> Optional[User]:
99
-        """更新用户信息"""
131
+        """更新用户信息(异步)"""
100 132
         user = await self.get_user_by_id(user_id)
101 133
         if not user:
102 134
             return None
@@ -105,14 +137,15 @@ class UserService:
105 137
             if hasattr(user, key) and value is not None:
106 138
                 setattr(user, key, value)
107 139
         
108
-        user.updated_at = datetime.now()
109
-        self.db.commit()
110
-        self.db.refresh(user)
140
+        user.updated_at = datetime.now(timezone.utc)
141
+        await self.db.commit()
142
+        await self.db.refresh(user)
111 143
         
112 144
         return user
113 145
     
114 146
     async def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
115
-        """修改密码"""
147
+        """修改用户密码(异步)"""
148
+        # 获取用户
116 149
         user = await self.get_user_by_id(user_id)
117 150
         if not user:
118 151
             return False
@@ -128,42 +161,58 @@ class UserService:
128 161
         
129 162
         # 更新密码
130 163
         user.hashed_password = password_hasher.hash_password(new_password)
131
-        user.last_password_change = datetime.now()
132
-        self.db.commit()
164
+        user.last_password_change = datetime.now(timezone.utc)
165
+        await self.db.commit()
133 166
         
134 167
         return True
135 168
     
136
-    async def list_users(
137
-        self, 
138
-        skip: int = 0, 
139
-        limit: int = 100,
140
-        active_only: bool = True
141
-    ) -> List[User]:
142
-        """列出用户"""
143
-        query = self.db.query(User)
169
+    # ==================== 用户列表 ====================
170
+    async def list_users(self, skip: int = 0, limit: int = 100, active_only: bool = True) -> List[User]:
171
+        """获取用户列表(异步)"""
172
+        stmt = select(User)
144 173
         
145 174
         if active_only:
146
-            query = query.filter(User.is_active == True)
175
+            stmt = stmt.where(User.is_active == True)
176
+        
177
+        stmt = stmt.offset(skip).limit(limit)
147 178
         
148
-        return query.offset(skip).limit(limit).all()
179
+        result = await self.db.execute(stmt)
180
+        users = result.scalars().all()
181
+        return users
149 182
     
150 183
     async def delete_user(self, user_id: int) -> bool:
151 184
         """删除用户(软删除)"""
152
-        user = await self.get_user_by_id(user_id)
153
-        if not user:
154
-            return False
185
+        """删除/禁用用户(异步)"""
186
+        # 通常我们不会真正删除,而是标记为禁用
187
+        stmt = (
188
+            update(User)
189
+            .where(User.id == user_id)
190
+            .values(is_active = False, updated_at = datetime.now(timezone.utc))
191
+        )
155 192
         
156
-        user.is_active = False
157
-        user.updated_at = datetime.now()
158
-        self.db.commit()
193
+        result = await self.db.execute(stmt)
194
+        await self.db.commit()
159 195
         
160
-        return True
196
+        # 检查是否影响了一行
197
+        return result.rowcount > 0
161 198
     
162 199
     async def count_users(self, active_only: bool = True) -> int:
163
-        """统计用户数量"""
164
-        query = self.db.query(User)
165
-        
200
+        """统计用户数量(异步)"""
166 201
         if active_only:
167
-            query = query.filter(User.is_active == True)
202
+            stmt = select(func.count()).where(User.is_active == True)
203
+        else:
204
+            stmt = select(func.count())
205
+        
206
+        result = await self.db.execute(stmt)
207
+        return result.scalar() or 0
208
+    
209
+    async def update_last_login(self, user_id: int) -> None:
210
+        """更新最后登录时间(异步)"""
211
+        stmt = (
212
+            update(User)
213
+            .where(User.id == user_id)
214
+            .values(last_login=datetime.now(timezone.utc))
215
+        )
168 216
         
169
-        return query.count()
217
+        await self.db.execute(stmt)
218
+        await self.db.commit()