11import json
22import logging
3+
34from collections .abc import AsyncIterator , Awaitable , Callable
4- from typing import TYPE_CHECKING , Any
5+ from typing import TYPE_CHECKING , Any , TypeVar
56
67from google .protobuf .json_format import MessageToDict , Parse
78
5253
5354logger = logging .getLogger (__name__ )
5455
56+ TResponse = TypeVar ('TResponse' )
57+
58+
5559@trace_class (kind = SpanKind .SERVER )
5660class RestDispatcher :
5761 """Dispatches incoming REST requests to the appropriate handler methods.
@@ -108,29 +112,24 @@ def _build_call_context(self, request: Request) -> ServerCallContext:
108112 call_context .tenant = request .path_params ['tenant' ]
109113 return call_context
110114
111- @rest_error_handler
112- @validate_version (constants .PROTOCOL_VERSION_1_0 )
113- async def on_message_send (self , request : Request ) -> Response :
114- """Handles the 'message/send' REST method."""
115+ async def _handle_non_streaming (
116+ self ,
117+ request : Request ,
118+ handler_func : Callable [[ServerCallContext ], Awaitable [TResponse ]],
119+ ) -> TResponse :
120+ """Centralized error handling and context management for unary calls."""
115121 context = self ._build_call_context (request )
116- body = await request .body ()
117- params = a2a_pb2 .SendMessageRequest ()
118- Parse (body , params )
119- task_or_message = await self .request_handler .on_message_send (params , context )
120- if isinstance (task_or_message , a2a_pb2 .Task ):
121- response = a2a_pb2 .SendMessageResponse (task = task_or_message )
122- else :
123- response = a2a_pb2 .SendMessageResponse (message = task_or_message )
124- return JSONResponse (content = MessageToDict (response ))
122+ return await handler_func (context )
125123
126- @rest_stream_error_handler
127- @validate_version (constants .PROTOCOL_VERSION_1_0 )
128- @validate (
129- lambda self : self .agent_card .capabilities .streaming ,
130- 'Streaming is not supported by the agent' ,
131- )
132- async def on_message_send_stream (self , request : Request ) -> EventSourceResponse :
133- """Handles the 'message/stream' REST method."""
124+ async def _handle_streaming (
125+ self ,
126+ request : Request ,
127+ handler_func : Callable [[ServerCallContext ], AsyncIterator [Any ]],
128+ ) -> EventSourceResponse :
129+ """Centralized error handling and context management for streaming calls."""
130+ # Pre-consume and cache the request body to prevent deadlock in streaming context
131+ # This is required because Starlette's request.body() can only be consumed once,
132+ # and attempting to consume it after EventSourceResponse starts causes deadlock
134133 try :
135134 await request .body ()
136135 except (ValueError , RuntimeError , OSError ) as e :
@@ -139,139 +138,234 @@ async def on_message_send_stream(self, request: Request) -> EventSourceResponse:
139138 ) from e
140139
141140 context = self ._build_call_context (request )
142- body = await request .body ()
143- params = a2a_pb2 .SendMessageRequest ()
144- Parse (body , params )
145141
146- stream = aiter (self .request_handler .on_message_send_stream (params , context ))
142+ # Eagerly fetch the first item from the stream so that errors raised
143+ # before any event is yielded (e.g. validation, parsing, or handler
144+ # failures) propagate here and are caught by
145+ # @rest_stream_error_handler, which returns a JSONResponse with
146+ # the correct HTTP status code instead of starting an SSE stream.
147+ # Without this, the error would be raised after SSE headers are
148+ # already sent, and the client would see a broken stream instead
149+ stream = aiter (handler_func (context ))
147150 try :
148- first_event = await anext (stream )
151+ first_item = await anext (stream )
149152 except StopAsyncIteration :
150153 return EventSourceResponse (iter ([]))
151154
152155 async def event_generator () -> AsyncIterator [str ]:
153- yield json .dumps (MessageToDict ( proto_utils . to_stream_response ( first_event )) )
154- async for event in stream :
155- yield json .dumps (MessageToDict ( proto_utils . to_stream_response ( event )) )
156+ yield json .dumps (first_item )
157+ async for item in stream :
158+ yield json .dumps (item )
156159
157160 return EventSourceResponse (event_generator ())
158161
159162 @rest_error_handler
160- @validate_version (constants .PROTOCOL_VERSION_1_0 )
163+ async def on_message_send (self , request : Request ) -> Response :
164+ """Handles the 'message/send' REST method."""
165+
166+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
167+ async def _handler (
168+ context : ServerCallContext ,
169+ ) -> a2a_pb2 .SendMessageResponse :
170+ body = await request .body ()
171+ params = a2a_pb2 .SendMessageRequest ()
172+ Parse (body , params )
173+ task_or_message = await self .request_handler .on_message_send (
174+ params , context
175+ )
176+ if isinstance (task_or_message , a2a_pb2 .Task ):
177+ return a2a_pb2 .SendMessageResponse (task = task_or_message )
178+ return a2a_pb2 .SendMessageResponse (message = task_or_message )
179+
180+ response = await self ._handle_non_streaming (request , _handler )
181+ return JSONResponse (content = MessageToDict (response ))
182+
183+ @rest_stream_error_handler
184+ async def on_message_send_stream (
185+ self , request : Request
186+ ) -> EventSourceResponse :
187+ """Handles the 'message/stream' REST method."""
188+
189+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
190+ @validate (
191+ lambda _ : self .agent_card .capabilities .streaming ,
192+ 'Streaming is not supported by the agent' ,
193+ )
194+ async def _handler (
195+ context : ServerCallContext ,
196+ ) -> AsyncIterator [dict [str , Any ]]:
197+ body = await request .body ()
198+ params = a2a_pb2 .SendMessageRequest ()
199+ Parse (body , params )
200+ async for event in self .request_handler .on_message_send_stream (
201+ params , context
202+ ):
203+ response = proto_utils .to_stream_response (event )
204+ yield MessageToDict (response )
205+
206+ return await self ._handle_streaming (request , _handler )
207+
208+ @rest_error_handler
161209 async def on_cancel_task (self , request : Request ) -> Response :
162210 """Handles the 'tasks/cancel' REST method."""
163- context = self ._build_call_context (request )
164- task_id = request .path_params ['id' ]
165- task = await self .request_handler .on_cancel_task (CancelTaskRequest (id = task_id ), context )
166- if task :
167- return JSONResponse (content = MessageToDict (task ))
168- raise TaskNotFoundError
211+
212+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
213+ async def _handler (context : ServerCallContext ) -> a2a_pb2 .Task :
214+ task_id = request .path_params ['id' ]
215+ task = await self .request_handler .on_cancel_task (
216+ CancelTaskRequest (id = task_id ), context
217+ )
218+ if task :
219+ return task
220+ raise TaskNotFoundError
221+
222+ response = await self ._handle_non_streaming (request , _handler )
223+ return JSONResponse (content = MessageToDict (response ))
169224
170225 @rest_stream_error_handler
171- @validate_version (constants .PROTOCOL_VERSION_1_0 )
172- @validate (
173- lambda self : self .agent_card .capabilities .streaming ,
174- 'Streaming is not supported by the agent' ,
175- )
176- async def on_subscribe_to_task (self , request : Request ) -> EventSourceResponse :
226+ async def on_subscribe_to_task (
227+ self , request : Request
228+ ) -> EventSourceResponse :
177229 """Handles the 'SubscribeToTask' REST method."""
178- try :
179- await request .body ()
180- except (ValueError , RuntimeError , OSError ) as e :
181- raise InvalidRequestError (
182- message = f'Failed to pre-consume request body: { e } '
183- ) from e
184-
185- context = self ._build_call_context (request )
186230 task_id = request .path_params ['id' ]
187-
188- stream = aiter (self .request_handler .on_subscribe_to_task (SubscribeToTaskRequest (id = task_id ), context ))
189- try :
190- first_event = await anext (stream )
191- except StopAsyncIteration :
192- return EventSourceResponse (iter ([]))
193231
194- async def event_generator () -> AsyncIterator [str ]:
195- yield json .dumps (MessageToDict (proto_utils .to_stream_response (first_event )))
196- async for event in stream :
197- yield json .dumps (MessageToDict (proto_utils .to_stream_response (event )))
232+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
233+ @validate (
234+ lambda _ : self .agent_card .capabilities .streaming ,
235+ 'Streaming is not supported by the agent' ,
236+ )
237+ async def _handler (
238+ context : ServerCallContext ,
239+ ) -> AsyncIterator [dict [str , Any ]]:
240+ async for event in self .request_handler .on_subscribe_to_task (
241+ SubscribeToTaskRequest (id = task_id ), context
242+ ):
243+ response = proto_utils .to_stream_response (event )
244+ yield MessageToDict (response )
198245
199- return EventSourceResponse ( event_generator () )
246+ return await self . _handle_streaming ( request , _handler )
200247
201248 @rest_error_handler
202- @validate_version (constants .PROTOCOL_VERSION_1_0 )
203249 async def on_get_task (self , request : Request ) -> Response :
204250 """Handles the 'tasks/{id}' REST method."""
205- context = self ._build_call_context (request )
206- params = a2a_pb2 .GetTaskRequest ()
207- proto_utils .parse_params (request .query_params , params )
208- params .id = request .path_params ['id' ]
209- task = await self .request_handler .on_get_task (params , context )
210- if task :
211- return JSONResponse (content = MessageToDict (task ))
212- raise TaskNotFoundError
251+
252+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
253+ async def _handler (context : ServerCallContext ) -> a2a_pb2 .Task :
254+ params = a2a_pb2 .GetTaskRequest ()
255+ proto_utils .parse_params (request .query_params , params )
256+ params .id = request .path_params ['id' ]
257+ task = await self .request_handler .on_get_task (params , context )
258+ if task :
259+ return task
260+ raise TaskNotFoundError
261+
262+ response = await self ._handle_non_streaming (request , _handler )
263+ return JSONResponse (content = MessageToDict (response ))
213264
214265 @rest_error_handler
215- @validate_version (constants .PROTOCOL_VERSION_1_0 )
216266 async def get_push_notification (self , request : Request ) -> Response :
217267 """Handles the 'tasks/pushNotificationConfig/get' REST method."""
218- context = self ._build_call_context (request )
219- task_id = request .path_params ['id' ]
220- push_id = request .path_params ['push_id' ]
221- params = GetTaskPushNotificationConfigRequest (task_id = task_id , id = push_id )
222- config = await self .request_handler .on_get_task_push_notification_config (params , context )
223- return JSONResponse (content = MessageToDict (config ))
268+
269+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
270+ async def _handler (
271+ context : ServerCallContext ,
272+ ) -> a2a_pb2 .TaskPushNotificationConfig :
273+ task_id = request .path_params ['id' ]
274+ push_id = request .path_params ['push_id' ]
275+ params = GetTaskPushNotificationConfigRequest (
276+ task_id = task_id , id = push_id
277+ )
278+ return (
279+ await self .request_handler .on_get_task_push_notification_config (
280+ params , context
281+ )
282+ )
283+
284+ response = await self ._handle_non_streaming (request , _handler )
285+ return JSONResponse (content = MessageToDict (response ))
224286
225287 @rest_error_handler
226- @validate_version (constants .PROTOCOL_VERSION_1_0 )
227288 async def delete_push_notification (self , request : Request ) -> Response :
228289 """Handles the 'tasks/pushNotificationConfig/delete' REST method."""
229- context = self ._build_call_context (request )
230- task_id = request .path_params ['id' ]
231- push_id = request .path_params ['push_id' ]
232- params = a2a_pb2 .DeleteTaskPushNotificationConfigRequest (task_id = task_id , id = push_id )
233- await self .request_handler .on_delete_task_push_notification_config (params , context )
290+
291+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
292+ async def _handler (context : ServerCallContext ) -> None :
293+ task_id = request .path_params ['id' ]
294+ push_id = request .path_params ['push_id' ]
295+ params = a2a_pb2 .DeleteTaskPushNotificationConfigRequest (
296+ task_id = task_id , id = push_id
297+ )
298+ await self .request_handler .on_delete_task_push_notification_config (
299+ params , context
300+ )
301+
302+ await self ._handle_non_streaming (request , _handler )
234303 return JSONResponse (content = {})
235304
236305 @rest_error_handler
237- @validate_version (constants .PROTOCOL_VERSION_1_0 )
238- @validate (
239- lambda self : self .agent_card .capabilities .push_notifications ,
240- 'Push notifications are not supported by the agent' ,
241- )
242306 async def set_push_notification (self , request : Request ) -> Response :
243307 """Handles the 'tasks/pushNotificationConfig/set' REST method."""
244- context = self ._build_call_context (request )
245- body = await request .body ()
246- params = a2a_pb2 .TaskPushNotificationConfig ()
247- Parse (body , params )
248- params .task_id = request .path_params ['id' ]
249- config = await self .request_handler .on_create_task_push_notification_config (params , context )
250- return JSONResponse (content = MessageToDict (config ))
308+
309+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
310+ @validate (
311+ lambda _ : self .agent_card .capabilities .push_notifications ,
312+ 'Push notifications are not supported by the agent' ,
313+ )
314+ async def _handler (
315+ context : ServerCallContext ,
316+ ) -> a2a_pb2 .TaskPushNotificationConfig :
317+ body = await request .body ()
318+ params = a2a_pb2 .TaskPushNotificationConfig ()
319+ Parse (body , params )
320+ params .task_id = request .path_params ['id' ]
321+ return await self .request_handler .on_create_task_push_notification_config (
322+ params , context
323+ )
324+
325+ response = await self ._handle_non_streaming (request , _handler )
326+ return JSONResponse (content = MessageToDict (response ))
251327
252328 @rest_error_handler
253- @validate_version (constants .PROTOCOL_VERSION_1_0 )
254329 async def list_push_notifications (self , request : Request ) -> Response :
255330 """Handles the 'tasks/pushNotificationConfig/list' REST method."""
256- context = self ._build_call_context (request )
257- params = a2a_pb2 .ListTaskPushNotificationConfigsRequest ()
258- proto_utils .parse_params (request .query_params , params )
259- params .task_id = request .path_params ['id' ]
260- result = await self .request_handler .on_list_task_push_notification_configs (params , context )
261- return JSONResponse (content = MessageToDict (result ))
331+
332+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
333+ async def _handler (
334+ context : ServerCallContext ,
335+ ) -> a2a_pb2 .ListTaskPushNotificationConfigsResponse :
336+ params = a2a_pb2 .ListTaskPushNotificationConfigsRequest ()
337+ proto_utils .parse_params (request .query_params , params )
338+ params .task_id = request .path_params ['id' ]
339+ return await self .request_handler .on_list_task_push_notification_configs (
340+ params , context
341+ )
342+
343+ response = await self ._handle_non_streaming (request , _handler )
344+ return JSONResponse (content = MessageToDict (response ))
262345
263346 @rest_error_handler
264- @validate_version (constants .PROTOCOL_VERSION_1_0 )
265347 async def list_tasks (self , request : Request ) -> Response :
266348 """Handles the 'tasks/list' REST method."""
267- context = self ._build_call_context (request )
268- params = a2a_pb2 .ListTasksRequest ()
269- proto_utils .parse_params (request .query_params , params )
270- result = await self .request_handler .on_list_tasks (params , context )
271- return JSONResponse (content = MessageToDict (result , always_print_fields_with_no_presence = True ))
349+
350+ @validate_version (constants .PROTOCOL_VERSION_1_0 )
351+ async def _handler (
352+ context : ServerCallContext ,
353+ ) -> a2a_pb2 .ListTasksResponse :
354+ params = a2a_pb2 .ListTasksRequest ()
355+ proto_utils .parse_params (request .query_params , params )
356+ return await self .request_handler .on_list_tasks (params , context )
357+
358+ response = await self ._handle_non_streaming (request , _handler )
359+ return JSONResponse (
360+ content = MessageToDict (
361+ response , always_print_fields_with_no_presence = True
362+ )
363+ )
272364
273365 @rest_error_handler
274- async def handle_authenticated_agent_card (self , request : Request ) -> Response :
366+ async def handle_authenticated_agent_card (
367+ self , request : Request
368+ ) -> Response :
275369 """Handles the 'extendedAgentCard' REST method."""
276370 if not self .agent_card .capabilities .extended_agent_card :
277371 raise ExtendedAgentCardNotConfiguredError (
@@ -288,5 +382,7 @@ async def handle_authenticated_agent_card(self, request: Request) -> Response:
288382 card_to_serve = await maybe_await (self .card_modifier (card_to_serve ))
289383
290384 return JSONResponse (
291- content = MessageToDict (card_to_serve , preserving_proto_field_name = True )
385+ content = MessageToDict (
386+ card_to_serve , preserving_proto_field_name = True
387+ )
292388 )
0 commit comments