Skip to content

Commit 2de48b0

Browse files
authored
Merge pull request #271 from NikkeTryHard/feature/monkeytype-strict-typing
refactor: comprehensive strict type annotations with enhanced test coverage
2 parents 1c69ae7 + 8463661 commit 2de48b0

109 files changed

Lines changed: 8562 additions & 1319 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,7 @@ coverage.xml
188188
*.py,cover
189189
.hypothesis/
190190
.pytest_cache/
191+
.testmondata*
191192

192193
# Environments
193194
.env
@@ -260,3 +261,12 @@ browser_utils/generated_*.js
260261
# Docker 环境的实际配置文件(保留示例文件)
261262
docker/.env
262263
docker/my_*.json
264+
265+
monkeytype.sqlite3
266+
267+
# Temporary debug/output files
268+
pyright_output.txt
269+
pyright_utils_full.txt
270+
temp_*.txt
271+
temp_*.md
272+
utils_errors.txt

api_utils/app.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,8 @@ async def lifespan(app: FastAPI):
228228
logger.info("Server startup complete.")
229229
state.is_initializing = False
230230
yield
231+
except asyncio.CancelledError:
232+
raise
231233
except Exception as e:
232234
logger.critical(f"Application startup failed: {e}", exc_info=True)
233235
await _shutdown_resources()

api_utils/client_connection.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,22 @@
11
import asyncio
22
from asyncio import Event, Task
33
from logging import Logger
4-
from typing import Any, Callable, Coroutine, Tuple
4+
from typing import Any, Callable, Coroutine, Dict, Protocol, Tuple
55

66
from fastapi import HTTPException, Request
77

88
from logging_utils import set_request_id
99
from models import ClientDisconnectedError
1010

1111

12+
class SupportsReceive(Protocol):
13+
"""Protocol for request objects that support _receive method."""
14+
15+
def _receive(self) -> Coroutine[Any, Any, Dict[str, Any]]:
16+
"""Internal method to receive messages from ASGI."""
17+
...
18+
19+
1220
async def check_client_connection(req_id: str, http_request: Request) -> bool:
1321
"""
1422
Checks if the client is still connected.
@@ -18,9 +26,12 @@ async def check_client_connection(req_id: str, http_request: Request) -> bool:
1826
if hasattr(http_request, "_receive"):
1927
try:
2028
# Use a very short timeout to check for disconnect message
21-
# Cast to Coroutine to satisfy mypy, as _receive is awaitable
22-
receive_coro: Coroutine[Any, Any, Any] = http_request._receive() # type: ignore
23-
receive_task: Task[Any] = asyncio.create_task(receive_coro)
29+
# _receive is a private Starlette/FastAPI method that returns a coroutine
30+
receive_obj = http_request # type: ignore[misc]
31+
receive_coro: Coroutine[Any, Any, Dict[str, Any]] = (
32+
receive_obj._receive()
33+
) # type: ignore[misc]
34+
receive_task: Task[Dict[str, Any]] = asyncio.create_task(receive_coro)
2435
done, pending = await asyncio.wait([receive_task], timeout=0.01)
2536

2637
if done:
@@ -34,6 +45,8 @@ async def check_client_connection(req_id: str, http_request: Request) -> bool:
3445
await receive_task
3546
except asyncio.CancelledError:
3647
pass
48+
except asyncio.CancelledError:
49+
raise
3750
except Exception:
3851
# If checking fails, assume disconnected to be safe, or log and continue?
3952
# Usually if _receive fails it might mean connection issues.
@@ -44,6 +57,8 @@ async def check_client_connection(req_id: str, http_request: Request) -> bool:
4457
return False
4558

4659
return True
60+
except asyncio.CancelledError:
61+
raise
4762
except Exception:
4863
return False
4964

api_utils/context_init.py

Lines changed: 20 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from typing import cast
2+
13
from logging_utils import set_request_id
24
from models import ChatCompletionRequest
35

@@ -13,20 +15,23 @@ async def initialize_request_context(
1315
state.logger.info("开始处理请求...")
1416
state.logger.info(f" 请求参数 - Model: {request.model}, Stream: {request.stream}")
1517

16-
context: RequestContext = {
17-
"logger": state.logger,
18-
"page": state.page_instance,
19-
"is_page_ready": state.is_page_ready,
20-
"parsed_model_list": state.parsed_model_list,
21-
"current_ai_studio_model_id": state.current_ai_studio_model_id,
22-
"model_switching_lock": state.model_switching_lock,
23-
"page_params_cache": state.page_params_cache,
24-
"params_cache_lock": state.params_cache_lock,
25-
"is_streaming": request.stream,
26-
"model_actually_switched": False,
27-
"requested_model": request.model,
28-
"model_id_to_use": None,
29-
"needs_model_switching": False,
30-
}
18+
context: RequestContext = cast(
19+
RequestContext,
20+
{
21+
"logger": state.logger,
22+
"page": state.page_instance,
23+
"is_page_ready": state.is_page_ready,
24+
"parsed_model_list": state.parsed_model_list,
25+
"current_ai_studio_model_id": state.current_ai_studio_model_id,
26+
"model_switching_lock": state.model_switching_lock,
27+
"page_params_cache": state.page_params_cache,
28+
"params_cache_lock": state.params_cache_lock,
29+
"is_streaming": request.stream,
30+
"model_actually_switched": False,
31+
"requested_model": request.model,
32+
"model_id_to_use": None,
33+
"needs_model_switching": False,
34+
},
35+
)
3136

3237
return context

api_utils/context_types.py

Lines changed: 45 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,53 @@
1-
from typing import Any, List, Optional, TypedDict
1+
import logging
2+
from asyncio import Future, Lock
3+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict
24

35
from playwright.async_api import Page as AsyncPage
46

7+
if TYPE_CHECKING:
8+
from fastapi import Request
9+
from fastapi.responses import JSONResponse
510

6-
class RequestContext(TypedDict, total=False):
7-
logger: Any
8-
page: Optional[AsyncPage]
11+
from models.chat import ChatCompletionRequest
12+
13+
14+
class QueueItem(TypedDict):
15+
"""Type definition for items in the request queue.
16+
17+
This defines the structure of each item put into the request_queue,
18+
ensuring type safety for queue operations.
19+
"""
20+
21+
req_id: str
22+
request_data: "ChatCompletionRequest"
23+
http_request: "Request"
24+
result_future: "Future[JSONResponse]"
25+
enqueue_time: float
26+
cancelled: bool
27+
28+
29+
class RequestContext(TypedDict):
30+
"""Request context with all keys always present after initialization.
31+
32+
All keys are required (always exist in the dict) after context_init.py initialization.
33+
Optional[] types indicate that the VALUE can be None, not that the key might not exist.
34+
"""
35+
36+
# Core components (always set by context_init.py)
37+
logger: logging.Logger
38+
page: Optional[AsyncPage] # Value can be None if browser not ready
939
is_page_ready: bool
10-
parsed_model_list: List[dict]
11-
current_ai_studio_model_id: Optional[str]
12-
model_switching_lock: Any
13-
page_params_cache: dict
14-
params_cache_lock: Any
40+
parsed_model_list: List[Dict[str, Any]]
41+
current_ai_studio_model_id: Optional[str] # Value can be None initially
42+
43+
# Locks (always set by server_state)
44+
model_switching_lock: Lock
45+
page_params_cache: Dict[str, Any]
46+
params_cache_lock: Lock
47+
48+
# Request-specific state (always initialized)
1549
is_streaming: bool
1650
model_actually_switched: bool
17-
requested_model: Optional[str]
18-
model_id_to_use: Optional[str]
51+
requested_model: Optional[str] # Value can be None if not specified
52+
model_id_to_use: Optional[str] # Value set during model analysis
1953
needs_model_switching: bool

api_utils/dependencies.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@
66
from asyncio import Event, Lock, Queue
77
from typing import Any, Dict, List, Set
88

9+
from api_utils.context_types import QueueItem
10+
911

1012
def get_logger() -> logging.Logger:
1113
from server import logger
@@ -19,7 +21,7 @@ def get_log_ws_manager():
1921
return log_ws_manager
2022

2123

22-
def get_request_queue() -> Queue:
24+
def get_request_queue() -> "Queue[QueueItem]":
2325
from server import request_queue
2426

2527
return request_queue

api_utils/mcp_adapter.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import json
23
import os
34
from typing import Any, Dict
@@ -30,6 +31,8 @@ async def execute_mcp_tool(name: str, params: Dict[str, Any]) -> str:
3031
resp.raise_for_status()
3132
try:
3233
data = resp.json()
34+
except asyncio.CancelledError:
35+
raise
3336
except Exception:
3437
data = {"raw": resp.text}
3538
return json.dumps(data, ensure_ascii=False)
@@ -47,6 +50,8 @@ async def execute_mcp_tool_with_endpoint(
4750
resp.raise_for_status()
4851
try:
4952
data = resp.json()
53+
except asyncio.CancelledError:
54+
raise
5055
except Exception:
5156
data = {"raw": resp.text}
5257
return json.dumps(data, ensure_ascii=False)

api_utils/model_switching.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from playwright.async_api import Page as AsyncPage
24

35
from api_utils.server_state import state
@@ -19,7 +21,9 @@ async def analyze_model_requirements(
1921
logger.info(f"请求使用模型: {requested_model_id}")
2022

2123
if parsed_model_list:
22-
valid_model_ids = [m.get("id") for m in parsed_model_list]
24+
valid_model_ids = [
25+
str(m.get("id")) for m in parsed_model_list if m.get("id")
26+
]
2327
if requested_model_id not in valid_model_ids:
2428
from .error_utils import bad_request
2529

@@ -50,6 +54,10 @@ async def handle_model_switching(
5054
model_switching_lock = context["model_switching_lock"]
5155
model_id_to_use = context["model_id_to_use"]
5256

57+
# Assert non-None values required for model switching
58+
assert page is not None, "Page must be ready for model switching"
59+
assert model_id_to_use is not None, "Target model ID must be set"
60+
5361
async with model_switching_lock:
5462
if state.current_ai_studio_model_id != model_id_to_use:
5563
logger.info(
@@ -64,19 +72,25 @@ async def handle_model_switching(
6472
context["current_ai_studio_model_id"] = model_id_to_use
6573
logger.info(f"模型切换成功: {state.current_ai_studio_model_id}")
6674
else:
75+
# Current model ID should exist when switching fails
76+
current_model = state.current_ai_studio_model_id or "unknown"
6777
await _handle_model_switch_failure(
6878
req_id,
6979
page,
7080
model_id_to_use,
71-
state.current_ai_studio_model_id,
81+
current_model,
7282
logger,
7383
)
7484

7585
return context
7686

7787

7888
async def _handle_model_switch_failure(
79-
req_id: str, page: AsyncPage, model_id_to_use: str, model_before_switch: str, logger
89+
req_id: str,
90+
page: AsyncPage,
91+
model_id_to_use: str,
92+
model_before_switch: str,
93+
logger: logging.Logger,
8094
) -> None:
8195
set_request_id(req_id)
8296
logger.warning(f"模型切换至 {model_id_to_use} 失败。")

api_utils/page_response.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,6 @@
11
import asyncio
2+
import logging
3+
from typing import Callable
24

35
from playwright.async_api import Error as PlaywrightAsyncError
46
from playwright.async_api import Page as AsyncPage
@@ -9,7 +11,10 @@
911

1012

1113
async def locate_response_elements(
12-
page: AsyncPage, req_id: str, logger, check_client_disconnected
14+
page: AsyncPage,
15+
req_id: str,
16+
logger: logging.Logger,
17+
check_client_disconnected: Callable[[str], bool],
1318
) -> None:
1419
"""定位响应容器与文本元素,包含超时与错误处理。"""
1520
set_request_id(req_id)
@@ -26,6 +31,8 @@ async def locate_response_elements(
2631
from .error_utils import upstream_error
2732

2833
raise upstream_error(req_id, f"定位AI Studio响应元素失败: {locate_err}")
34+
except asyncio.CancelledError:
35+
raise
2936
except Exception as locate_exc:
3037
from .error_utils import server_error
3138

0 commit comments

Comments (0)