11import asyncio
22from asyncio import Event , Task
33from logging import Logger
4- from typing import Any , Callable , Coroutine , Tuple
4+ from typing import Any , Callable , Coroutine , Dict , Protocol , Tuple
55
66from fastapi import HTTPException , Request
77
88from logging_utils import set_request_id
99from models import ClientDisconnectedError
1010
1111
12+ class SupportsReceive (Protocol ):
13+ """Protocol for request objects that support _receive method."""
14+
15+ def _receive (self ) -> Coroutine [Any , Any , Dict [str , Any ]]:
16+ """Internal method to receive messages from ASGI."""
17+ ...
18+
19+
1220async def check_client_connection (req_id : str , http_request : Request ) -> bool :
1321 """
1422 Checks if the client is still connected.
@@ -18,9 +26,12 @@ async def check_client_connection(req_id: str, http_request: Request) -> bool:
1826 if hasattr (http_request , "_receive" ):
1927 try :
2028 # Use a very short timeout to check for disconnect message
21- # Cast to Coroutine to satisfy mypy, as _receive is awaitable
22- receive_coro : Coroutine [Any , Any , Any ] = http_request ._receive () # type: ignore
23- receive_task : Task [Any ] = asyncio .create_task (receive_coro )
29+ # _receive is a private Starlette/FastAPI method that returns a coroutine
30+ receive_obj = http_request # type: ignore[misc]
31+ receive_coro : Coroutine [Any , Any , Dict [str , Any ]] = (
32+ receive_obj ._receive ()
33+ ) # type: ignore[misc]
34+ receive_task : Task [Dict [str , Any ]] = asyncio .create_task (receive_coro )
2435 done , pending = await asyncio .wait ([receive_task ], timeout = 0.01 )
2536
2637 if done :
@@ -34,6 +45,8 @@ async def check_client_connection(req_id: str, http_request: Request) -> bool:
3445 await receive_task
3546 except asyncio .CancelledError :
3647 pass
48+ except asyncio .CancelledError :
49+ raise
3750 except Exception :
3851 # If checking fails, assume disconnected to be safe, or log and continue?
3952 # Usually if _receive fails it might mean connection issues.
@@ -44,6 +57,8 @@ async def check_client_connection(req_id: str, http_request: Request) -> bool:
4457 return False
4558
4659 return True
60+ except asyncio .CancelledError :
61+ raise
4762 except Exception :
4863 return False
4964
0 commit comments