33from contextlib import AsyncExitStack
44from datetime import timedelta
55from types import TracebackType
6- from typing import Any , Generic , TypeVar
6+ from typing import Any , Generic , Protocol , TypeVar
77
88import anyio
99import httpx
2424 JSONRPCNotification ,
2525 JSONRPCRequest ,
2626 JSONRPCResponse ,
27+ ProgressNotification ,
2728 RequestParams ,
2829 ServerNotification ,
2930 ServerRequest ,
4243RequestId = str | int
4344
4445
46+ class ProgressCallbackFnT (Protocol ):
47+ """Protocol for progress notification callbacks."""
48+
49+ def __call__ (
50+ self , progress : float , total : float | None , message : str | None
51+ ) -> None :
52+ """Called when progress updates are received.
53+
54+ Args:
55+ progress: Current progress value
56+ total: Total progress value (if known), None if indeterminate
57+ message: Optional progress message
58+ """
59+ ...
60+
61+
4562class RequestResponder (Generic [ReceiveRequestT , SendResultT ]):
4663 """Handles responding to MCP requests and manages request lifecycle.
4764
@@ -169,6 +186,7 @@ class BaseSession(
169186 ]
170187 _request_id : int
171188 _in_flight : dict [RequestId , RequestResponder [ReceiveRequestT , SendResultT ]]
189+ _progress_callbacks : dict [RequestId , ProgressCallbackFnT ]
172190
173191 def __init__ (
174192 self ,
@@ -187,6 +205,7 @@ def __init__(
187205 self ._receive_notification_type = receive_notification_type
188206 self ._session_read_timeout_seconds = read_timeout_seconds
189207 self ._in_flight = {}
208+ self ._progress_callbacks = {}
190209 self ._exit_stack = AsyncExitStack ()
191210
192211 async def __aenter__ (self ) -> Self :
@@ -214,6 +233,7 @@ async def send_request(
214233 result_type : type [ReceiveResultT ],
215234 request_read_timeout_seconds : timedelta | None = None ,
216235 metadata : MessageMetadata = None ,
236+ progress_callback : ProgressCallbackFnT | None = None ,
217237 ) -> ReceiveResultT :
218238 """
219239 Sends a request and wait for a response. Raises an McpError if the
@@ -231,15 +251,25 @@ async def send_request(
231251 ](1 )
232252 self ._response_streams [request_id ] = response_stream
233253
254+ # Set up progress token if progress callback is provided
255+ request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
256+ if progress_callback is not None :
257+ # Use request_id as progress token
258+ if "params" not in request_data :
259+ request_data ["params" ] = {}
260+ if "_meta" not in request_data ["params" ]:
261+ request_data ["params" ]["_meta" ] = {}
262+ request_data ["params" ]["_meta" ]["progressToken" ] = request_id
263+ # Store the callback for this request
264+ self ._progress_callbacks [request_id ] = progress_callback
265+
234266 try :
235267 jsonrpc_request = JSONRPCRequest (
236268 jsonrpc = "2.0" ,
237269 id = request_id ,
238- ** request . model_dump ( by_alias = True , mode = "json" , exclude_none = True ) ,
270+ ** request_data ,
239271 )
240272
241- # TODO: Support progress callbacks
242-
243273 await self ._write_stream .send (
244274 SessionMessage (
245275 message = JSONRPCMessage (jsonrpc_request ), metadata = metadata
@@ -275,6 +305,7 @@ async def send_request(
275305
276306 finally :
277307 self ._response_streams .pop (request_id , None )
308+ self ._progress_callbacks .pop (request_id , None )
278309 await response_stream .aclose ()
279310 await response_stream_reader .aclose ()
280311
@@ -333,7 +364,6 @@ async def _receive_loop(self) -> None:
333364 by_alias = True , mode = "json" , exclude_none = True
334365 )
335366 )
336-
337367 responder = RequestResponder (
338368 request_id = message .message .root .id ,
339369 request_meta = validated_request .root .params .meta
@@ -362,6 +392,18 @@ async def _receive_loop(self) -> None:
362392 cancelled_id = notification .root .params .requestId
363393 if cancelled_id in self ._in_flight :
364394 await self ._in_flight [cancelled_id ].cancel ()
395+ # Handle progress notifications
396+ elif isinstance (notification .root , ProgressNotification ):
397+ progress_token = notification .root .params .progressToken
398+ # If there is a progress callback for this token,
399+ # call it with the progress information
400+ if progress_token in self ._progress_callbacks :
401+ callback = self ._progress_callbacks [progress_token ]
402+ callback (
403+ notification .root .params .progress ,
404+ notification .root .params .total ,
405+ notification .root .params .message ,
406+ )
365407 else :
366408 await self ._received_notification (notification )
367409 await self ._handle_incoming (notification )
0 commit comments