11import asyncio
2- from typing import Callable , Tuple
3- from asyncio import Event
2+ from typing import Callable , Tuple , Optional
3+ from asyncio import Event , Task
44from fastapi import HTTPException , Request
5+ from models import ClientDisconnectedError
56
6-
7- async def test_client_connection (req_id : str , http_request : Request ) -> bool :
7+ async def check_client_connection (req_id : str , http_request : Request ) -> bool :
8+ """
9+ Checks if the client is still connected.
10+ Returns True if connected, False if disconnected.
11+ """
812 try :
913 if hasattr (http_request , '_receive' ):
1014 try :
11- receive_task = asyncio .create_task (http_request ._receive ())
15+ # Use a very short timeout to check for disconnect message
16+ # Cast to Coroutine to satisfy mypy, as _receive is awaitable
17+ from typing import Coroutine , Any
18+ receive_coro : Coroutine [Any , Any , Any ] = http_request ._receive () # type: ignore
19+ receive_task : Task = asyncio .create_task (receive_coro )
1220 done , pending = await asyncio .wait ([receive_task ], timeout = 0.01 )
21+
1322 if done :
1423 message = receive_task .result ()
1524 if message .get ("type" ) == "http.disconnect" :
1625 return False
1726 else :
27+ # Cancel the task if it didn't complete immediately
1828 receive_task .cancel ()
1929 try :
2030 await receive_task
2131 except asyncio .CancelledError :
2232 pass
2333 except Exception :
34+ # If checking fails, assume disconnected to be safe, or log and continue?
35+ # Usually if _receive fails it might mean connection issues.
2436 return False
37+
38+ # Fallback to is_disconnected() if available (Starlette/FastAPI)
39+ if await http_request .is_disconnected ():
40+ return False
41+
2542 return True
2643 except Exception :
2744 return False
2845
2946
30- async def setup_disconnect_monitoring (req_id : str , http_request : Request , result_future ) -> Tuple [Event , asyncio .Task , Callable ]:
47+ async def setup_disconnect_monitoring (req_id : str , http_request : Request , result_future ) -> Tuple [Event , Task , Callable ]:
48+ """
49+ Sets up a background task to monitor client disconnection.
50+ Returns:
51+ - client_disconnected_event: Event set when disconnect is detected
52+ - disconnect_check_task: The background task
53+ - check_client_disconnected: Helper function to raise error if disconnected
54+ """
3155 from server import logger
3256 client_disconnected_event = Event ()
3357
3458 async def check_disconnect_periodically ():
3559 while not client_disconnected_event .is_set ():
3660 try :
37- is_connected = await test_client_connection (req_id , http_request )
61+ is_connected = await check_client_connection (req_id , http_request )
3862 if not is_connected :
39- logger .info (f"[{ req_id } ] 主动检测到客户端断开连接。" )
40- client_disconnected_event .set ()
41- if not result_future .done ():
42- result_future .set_exception (HTTPException (status_code = 499 , detail = f"[{ req_id } ] 客户端关闭了请求" ))
43- break
44-
45- if await http_request .is_disconnected ():
46- logger .info (f"[{ req_id } ] 备用检测到客户端断开连接。" )
63+ logger .info (f"[{ req_id } ] Active disconnect check detected client disconnection." )
4764 client_disconnected_event .set ()
4865 if not result_future .done ():
49- result_future .set_exception (HTTPException (status_code = 499 , detail = f"[{ req_id } ] 客户端关闭了请求 " ))
66+ result_future .set_exception (HTTPException (status_code = 499 , detail = f"[{ req_id } ] Client closed request " ))
5067 break
68+
5169 await asyncio .sleep (0.3 )
5270 except asyncio .CancelledError :
71+ # Task cancelled, exit gracefully
5372 break
5473 except Exception as e :
55- logger .error (f"[{ req_id } ] (Disco Check Task) 错误 : { e } " )
74+ logger .error (f"[{ req_id } ] (Disco Check Task) Error : { e } " )
5675 client_disconnected_event .set ()
5776 if not result_future .done ():
5877 result_future .set_exception (HTTPException (status_code = 500 , detail = f"[{ req_id } ] Internal disconnect checker error: { e } " ))
@@ -62,10 +81,61 @@ async def check_disconnect_periodically():
6281
6382 def check_client_disconnected (stage : str = "" ):
6483 if client_disconnected_event .is_set ():
65- logger .info (f"[{ req_id } ] 在 '{ stage } ' 检测到客户端断开连接。" )
66- from models import ClientDisconnectedError
84+ logger .info (f"[{ req_id } ] Client disconnected detected at stage: '{ stage } '" )
6785 raise ClientDisconnectedError (f"[{ req_id } ] Client disconnected at stage: { stage } " )
6886 return False
6987
7088 return client_disconnected_event , disconnect_check_task , check_client_disconnected
7189
90+
91+ async def enhanced_disconnect_monitor (req_id : str , http_request : Request , completion_event : Event , logger ) -> bool :
92+ """
93+ Enhanced disconnect monitor for streaming responses.
94+ Returns True if client disconnected early.
95+ """
96+ client_disconnected_early = False
97+ while not completion_event .is_set ():
98+ try :
99+ is_connected = await check_client_connection (req_id , http_request )
100+ if not is_connected :
101+ logger .info (f"[{ req_id } ] (Monitor) ✅ Client disconnected during streaming, triggering completion event." )
102+ client_disconnected_early = True
103+ if not completion_event .is_set ():
104+ completion_event .set ()
105+ break
106+ await asyncio .sleep (0.3 )
107+ except asyncio .CancelledError :
108+ break
109+ except Exception as e :
110+ logger .error (f"[{ req_id } ] (Monitor) Enhanced disconnect checker error: { e } " )
111+ break
112+ return client_disconnected_early
113+
114+
115+ async def non_streaming_disconnect_monitor (req_id : str , http_request : Request , result_future : asyncio .Future , logger ) -> bool :
116+ """
117+ Disconnect monitor for non-streaming responses.
118+ Returns True if client disconnected early.
119+ """
120+ client_disconnected_early = False
121+ while not result_future .done ():
122+ try :
123+ is_connected = await check_client_connection (req_id , http_request )
124+ if not is_connected :
125+ logger .info (f"[{ req_id } ] (Monitor) ✅ Client disconnected during non-streaming processing." )
126+ client_disconnected_early = True
127+ if not result_future .done ():
128+ result_future .set_exception (
129+ HTTPException (
130+ status_code = 499 ,
131+ detail = f"[{ req_id } ] Client disconnected during processing" ,
132+ )
133+ )
134+ break
135+ await asyncio .sleep (0.3 )
136+ except asyncio .CancelledError :
137+ break
138+ except Exception as e :
139+ logger .error (f"[{ req_id } ] (Monitor) Non-streaming disconnect checker error: { e } " )
140+ break
141+ return client_disconnected_early
0 commit comments