Skip to content

Commit d60e259

Browse files
committed
fix
1 parent ffd0de1 commit d60e259

4 files changed

Lines changed: 514 additions & 140 deletions

File tree

Lines changed: 207 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import json
22
import logging
3+
34
from collections.abc import AsyncIterator, Awaitable, Callable
4-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Any, TypeVar
56

67
from google.protobuf.json_format import MessageToDict, Parse
78

@@ -52,6 +53,9 @@
5253

5354
logger = logging.getLogger(__name__)
5455

56+
TResponse = TypeVar('TResponse')
57+
58+
5559
@trace_class(kind=SpanKind.SERVER)
5660
class RestDispatcher:
5761
"""Dispatches incoming REST requests to the appropriate handler methods.
@@ -108,29 +112,24 @@ def _build_call_context(self, request: Request) -> ServerCallContext:
108112
call_context.tenant = request.path_params['tenant']
109113
return call_context
110114

111-
@rest_error_handler
112-
@validate_version(constants.PROTOCOL_VERSION_1_0)
113-
async def on_message_send(self, request: Request) -> Response:
114-
"""Handles the 'message/send' REST method."""
115+
async def _handle_non_streaming(
116+
self,
117+
request: Request,
118+
handler_func: Callable[[ServerCallContext], Awaitable[TResponse]],
119+
) -> TResponse:
120+
"""Centralized error handling and context management for unary calls."""
115121
context = self._build_call_context(request)
116-
body = await request.body()
117-
params = a2a_pb2.SendMessageRequest()
118-
Parse(body, params)
119-
task_or_message = await self.request_handler.on_message_send(params, context)
120-
if isinstance(task_or_message, a2a_pb2.Task):
121-
response = a2a_pb2.SendMessageResponse(task=task_or_message)
122-
else:
123-
response = a2a_pb2.SendMessageResponse(message=task_or_message)
124-
return JSONResponse(content=MessageToDict(response))
122+
return await handler_func(context)
125123

126-
@rest_stream_error_handler
127-
@validate_version(constants.PROTOCOL_VERSION_1_0)
128-
@validate(
129-
lambda self: self.agent_card.capabilities.streaming,
130-
'Streaming is not supported by the agent',
131-
)
132-
async def on_message_send_stream(self, request: Request) -> EventSourceResponse:
133-
"""Handles the 'message/stream' REST method."""
124+
async def _handle_streaming(
125+
self,
126+
request: Request,
127+
handler_func: Callable[[ServerCallContext], AsyncIterator[Any]],
128+
) -> EventSourceResponse:
129+
"""Centralized error handling and context management for streaming calls."""
130+
# Pre-consume and cache the request body to prevent deadlock in streaming context
131+
# This is required because Starlette's request.body() can only be consumed once,
132+
# and attempting to consume it after EventSourceResponse starts causes deadlock
134133
try:
135134
await request.body()
136135
except (ValueError, RuntimeError, OSError) as e:
@@ -139,139 +138,234 @@ async def on_message_send_stream(self, request: Request) -> EventSourceResponse:
139138
) from e
140139

141140
context = self._build_call_context(request)
142-
body = await request.body()
143-
params = a2a_pb2.SendMessageRequest()
144-
Parse(body, params)
145141

146-
stream = aiter(self.request_handler.on_message_send_stream(params, context))
142+
# Eagerly fetch the first item from the stream so that errors raised
143+
# before any event is yielded (e.g. validation, parsing, or handler
144+
# failures) propagate here and are caught by
145+
# @rest_stream_error_handler, which returns a JSONResponse with
146+
# the correct HTTP status code instead of starting an SSE stream.
147+
# Without this, the error would be raised after SSE headers are
148+
# already sent, and the client would see a broken stream instead
149+
stream = aiter(handler_func(context))
147150
try:
148-
first_event = await anext(stream)
151+
first_item = await anext(stream)
149152
except StopAsyncIteration:
150153
return EventSourceResponse(iter([]))
151154

152155
async def event_generator() -> AsyncIterator[str]:
153-
yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event)))
154-
async for event in stream:
155-
yield json.dumps(MessageToDict(proto_utils.to_stream_response(event)))
156+
yield json.dumps(first_item)
157+
async for item in stream:
158+
yield json.dumps(item)
156159

157160
return EventSourceResponse(event_generator())
158161

159162
@rest_error_handler
160-
@validate_version(constants.PROTOCOL_VERSION_1_0)
163+
async def on_message_send(self, request: Request) -> Response:
164+
"""Handles the 'message/send' REST method."""
165+
166+
@validate_version(constants.PROTOCOL_VERSION_1_0)
167+
async def _handler(
168+
context: ServerCallContext,
169+
) -> a2a_pb2.SendMessageResponse:
170+
body = await request.body()
171+
params = a2a_pb2.SendMessageRequest()
172+
Parse(body, params)
173+
task_or_message = await self.request_handler.on_message_send(
174+
params, context
175+
)
176+
if isinstance(task_or_message, a2a_pb2.Task):
177+
return a2a_pb2.SendMessageResponse(task=task_or_message)
178+
return a2a_pb2.SendMessageResponse(message=task_or_message)
179+
180+
response = await self._handle_non_streaming(request, _handler)
181+
return JSONResponse(content=MessageToDict(response))
182+
183+
@rest_stream_error_handler
184+
async def on_message_send_stream(
185+
self, request: Request
186+
) -> EventSourceResponse:
187+
"""Handles the 'message/stream' REST method."""
188+
189+
@validate_version(constants.PROTOCOL_VERSION_1_0)
190+
@validate(
191+
lambda _: self.agent_card.capabilities.streaming,
192+
'Streaming is not supported by the agent',
193+
)
194+
async def _handler(
195+
context: ServerCallContext,
196+
) -> AsyncIterator[dict[str, Any]]:
197+
body = await request.body()
198+
params = a2a_pb2.SendMessageRequest()
199+
Parse(body, params)
200+
async for event in self.request_handler.on_message_send_stream(
201+
params, context
202+
):
203+
response = proto_utils.to_stream_response(event)
204+
yield MessageToDict(response)
205+
206+
return await self._handle_streaming(request, _handler)
207+
208+
@rest_error_handler
161209
async def on_cancel_task(self, request: Request) -> Response:
162210
"""Handles the 'tasks/cancel' REST method."""
163-
context = self._build_call_context(request)
164-
task_id = request.path_params['id']
165-
task = await self.request_handler.on_cancel_task(CancelTaskRequest(id=task_id), context)
166-
if task:
167-
return JSONResponse(content=MessageToDict(task))
168-
raise TaskNotFoundError
211+
212+
@validate_version(constants.PROTOCOL_VERSION_1_0)
213+
async def _handler(context: ServerCallContext) -> a2a_pb2.Task:
214+
task_id = request.path_params['id']
215+
task = await self.request_handler.on_cancel_task(
216+
CancelTaskRequest(id=task_id), context
217+
)
218+
if task:
219+
return task
220+
raise TaskNotFoundError
221+
222+
response = await self._handle_non_streaming(request, _handler)
223+
return JSONResponse(content=MessageToDict(response))
169224

170225
@rest_stream_error_handler
171-
@validate_version(constants.PROTOCOL_VERSION_1_0)
172-
@validate(
173-
lambda self: self.agent_card.capabilities.streaming,
174-
'Streaming is not supported by the agent',
175-
)
176-
async def on_subscribe_to_task(self, request: Request) -> EventSourceResponse:
226+
async def on_subscribe_to_task(
227+
self, request: Request
228+
) -> EventSourceResponse:
177229
"""Handles the 'SubscribeToTask' REST method."""
178-
try:
179-
await request.body()
180-
except (ValueError, RuntimeError, OSError) as e:
181-
raise InvalidRequestError(
182-
message=f'Failed to pre-consume request body: {e}'
183-
) from e
184-
185-
context = self._build_call_context(request)
186230
task_id = request.path_params['id']
187-
188-
stream = aiter(self.request_handler.on_subscribe_to_task(SubscribeToTaskRequest(id=task_id), context))
189-
try:
190-
first_event = await anext(stream)
191-
except StopAsyncIteration:
192-
return EventSourceResponse(iter([]))
193231

194-
async def event_generator() -> AsyncIterator[str]:
195-
yield json.dumps(MessageToDict(proto_utils.to_stream_response(first_event)))
196-
async for event in stream:
197-
yield json.dumps(MessageToDict(proto_utils.to_stream_response(event)))
232+
@validate_version(constants.PROTOCOL_VERSION_1_0)
233+
@validate(
234+
lambda _: self.agent_card.capabilities.streaming,
235+
'Streaming is not supported by the agent',
236+
)
237+
async def _handler(
238+
context: ServerCallContext,
239+
) -> AsyncIterator[dict[str, Any]]:
240+
async for event in self.request_handler.on_subscribe_to_task(
241+
SubscribeToTaskRequest(id=task_id), context
242+
):
243+
response = proto_utils.to_stream_response(event)
244+
yield MessageToDict(response)
198245

199-
return EventSourceResponse(event_generator())
246+
return await self._handle_streaming(request, _handler)
200247

201248
@rest_error_handler
202-
@validate_version(constants.PROTOCOL_VERSION_1_0)
203249
async def on_get_task(self, request: Request) -> Response:
204250
"""Handles the 'tasks/{id}' REST method."""
205-
context = self._build_call_context(request)
206-
params = a2a_pb2.GetTaskRequest()
207-
proto_utils.parse_params(request.query_params, params)
208-
params.id = request.path_params['id']
209-
task = await self.request_handler.on_get_task(params, context)
210-
if task:
211-
return JSONResponse(content=MessageToDict(task))
212-
raise TaskNotFoundError
251+
252+
@validate_version(constants.PROTOCOL_VERSION_1_0)
253+
async def _handler(context: ServerCallContext) -> a2a_pb2.Task:
254+
params = a2a_pb2.GetTaskRequest()
255+
proto_utils.parse_params(request.query_params, params)
256+
params.id = request.path_params['id']
257+
task = await self.request_handler.on_get_task(params, context)
258+
if task:
259+
return task
260+
raise TaskNotFoundError
261+
262+
response = await self._handle_non_streaming(request, _handler)
263+
return JSONResponse(content=MessageToDict(response))
213264

214265
@rest_error_handler
215-
@validate_version(constants.PROTOCOL_VERSION_1_0)
216266
async def get_push_notification(self, request: Request) -> Response:
217267
"""Handles the 'tasks/pushNotificationConfig/get' REST method."""
218-
context = self._build_call_context(request)
219-
task_id = request.path_params['id']
220-
push_id = request.path_params['push_id']
221-
params = GetTaskPushNotificationConfigRequest(task_id=task_id, id=push_id)
222-
config = await self.request_handler.on_get_task_push_notification_config(params, context)
223-
return JSONResponse(content=MessageToDict(config))
268+
269+
@validate_version(constants.PROTOCOL_VERSION_1_0)
270+
async def _handler(
271+
context: ServerCallContext,
272+
) -> a2a_pb2.TaskPushNotificationConfig:
273+
task_id = request.path_params['id']
274+
push_id = request.path_params['push_id']
275+
params = GetTaskPushNotificationConfigRequest(
276+
task_id=task_id, id=push_id
277+
)
278+
return (
279+
await self.request_handler.on_get_task_push_notification_config(
280+
params, context
281+
)
282+
)
283+
284+
response = await self._handle_non_streaming(request, _handler)
285+
return JSONResponse(content=MessageToDict(response))
224286

225287
@rest_error_handler
226-
@validate_version(constants.PROTOCOL_VERSION_1_0)
227288
async def delete_push_notification(self, request: Request) -> Response:
228289
"""Handles the 'tasks/pushNotificationConfig/delete' REST method."""
229-
context = self._build_call_context(request)
230-
task_id = request.path_params['id']
231-
push_id = request.path_params['push_id']
232-
params = a2a_pb2.DeleteTaskPushNotificationConfigRequest(task_id=task_id, id=push_id)
233-
await self.request_handler.on_delete_task_push_notification_config(params, context)
290+
291+
@validate_version(constants.PROTOCOL_VERSION_1_0)
292+
async def _handler(context: ServerCallContext) -> None:
293+
task_id = request.path_params['id']
294+
push_id = request.path_params['push_id']
295+
params = a2a_pb2.DeleteTaskPushNotificationConfigRequest(
296+
task_id=task_id, id=push_id
297+
)
298+
await self.request_handler.on_delete_task_push_notification_config(
299+
params, context
300+
)
301+
302+
await self._handle_non_streaming(request, _handler)
234303
return JSONResponse(content={})
235304

236305
@rest_error_handler
237-
@validate_version(constants.PROTOCOL_VERSION_1_0)
238-
@validate(
239-
lambda self: self.agent_card.capabilities.push_notifications,
240-
'Push notifications are not supported by the agent',
241-
)
242306
async def set_push_notification(self, request: Request) -> Response:
243307
"""Handles the 'tasks/pushNotificationConfig/set' REST method."""
244-
context = self._build_call_context(request)
245-
body = await request.body()
246-
params = a2a_pb2.TaskPushNotificationConfig()
247-
Parse(body, params)
248-
params.task_id = request.path_params['id']
249-
config = await self.request_handler.on_create_task_push_notification_config(params, context)
250-
return JSONResponse(content=MessageToDict(config))
308+
309+
@validate_version(constants.PROTOCOL_VERSION_1_0)
310+
@validate(
311+
lambda _: self.agent_card.capabilities.push_notifications,
312+
'Push notifications are not supported by the agent',
313+
)
314+
async def _handler(
315+
context: ServerCallContext,
316+
) -> a2a_pb2.TaskPushNotificationConfig:
317+
body = await request.body()
318+
params = a2a_pb2.TaskPushNotificationConfig()
319+
Parse(body, params)
320+
params.task_id = request.path_params['id']
321+
return await self.request_handler.on_create_task_push_notification_config(
322+
params, context
323+
)
324+
325+
response = await self._handle_non_streaming(request, _handler)
326+
return JSONResponse(content=MessageToDict(response))
251327

252328
@rest_error_handler
253-
@validate_version(constants.PROTOCOL_VERSION_1_0)
254329
async def list_push_notifications(self, request: Request) -> Response:
255330
"""Handles the 'tasks/pushNotificationConfig/list' REST method."""
256-
context = self._build_call_context(request)
257-
params = a2a_pb2.ListTaskPushNotificationConfigsRequest()
258-
proto_utils.parse_params(request.query_params, params)
259-
params.task_id = request.path_params['id']
260-
result = await self.request_handler.on_list_task_push_notification_configs(params, context)
261-
return JSONResponse(content=MessageToDict(result))
331+
332+
@validate_version(constants.PROTOCOL_VERSION_1_0)
333+
async def _handler(
334+
context: ServerCallContext,
335+
) -> a2a_pb2.ListTaskPushNotificationConfigsResponse:
336+
params = a2a_pb2.ListTaskPushNotificationConfigsRequest()
337+
proto_utils.parse_params(request.query_params, params)
338+
params.task_id = request.path_params['id']
339+
return await self.request_handler.on_list_task_push_notification_configs(
340+
params, context
341+
)
342+
343+
response = await self._handle_non_streaming(request, _handler)
344+
return JSONResponse(content=MessageToDict(response))
262345

263346
@rest_error_handler
264-
@validate_version(constants.PROTOCOL_VERSION_1_0)
265347
async def list_tasks(self, request: Request) -> Response:
266348
"""Handles the 'tasks/list' REST method."""
267-
context = self._build_call_context(request)
268-
params = a2a_pb2.ListTasksRequest()
269-
proto_utils.parse_params(request.query_params, params)
270-
result = await self.request_handler.on_list_tasks(params, context)
271-
return JSONResponse(content=MessageToDict(result, always_print_fields_with_no_presence=True))
349+
350+
@validate_version(constants.PROTOCOL_VERSION_1_0)
351+
async def _handler(
352+
context: ServerCallContext,
353+
) -> a2a_pb2.ListTasksResponse:
354+
params = a2a_pb2.ListTasksRequest()
355+
proto_utils.parse_params(request.query_params, params)
356+
return await self.request_handler.on_list_tasks(params, context)
357+
358+
response = await self._handle_non_streaming(request, _handler)
359+
return JSONResponse(
360+
content=MessageToDict(
361+
response, always_print_fields_with_no_presence=True
362+
)
363+
)
272364

273365
@rest_error_handler
274-
async def handle_authenticated_agent_card(self, request: Request) -> Response:
366+
async def handle_authenticated_agent_card(
367+
self, request: Request
368+
) -> Response:
275369
"""Handles the 'extendedAgentCard' REST method."""
276370
if not self.agent_card.capabilities.extended_agent_card:
277371
raise ExtendedAgentCardNotConfiguredError(
@@ -288,5 +382,7 @@ async def handle_authenticated_agent_card(self, request: Request) -> Response:
288382
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))
289383

290384
return JSONResponse(
291-
content=MessageToDict(card_to_serve, preserving_proto_field_name=True)
385+
content=MessageToDict(
386+
card_to_serve, preserving_proto_field_name=True
387+
)
292388
)

0 commit comments

Comments
 (0)