Skip to content

Commit 8730c51

Browse files
committed
fix: handle SSE errors occurred after stream started
Currently it'd close the connection.
1 parent 8c65e84 commit 8730c51

15 files changed

Lines changed: 422 additions & 98 deletions

File tree

src/a2a/client/transports/http_helpers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,23 @@ async def send_http_stream_request(
6969
httpx_client: httpx.AsyncClient,
7070
method: str,
7171
url: str,
72-
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
73-
| None = None,
72+
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn],
73+
sse_error_handler: Callable[[str], NoReturn],
7474
**kwargs: Any,
7575
) -> AsyncGenerator[str]:
76-
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
76+
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.
77+
78+
Args:
79+
httpx_client: The async HTTP client.
80+
method: The HTTP method (e.g. 'POST', 'GET').
81+
url: The URL to send the request to.
82+
status_error_handler: Handler for HTTP status errors. Should raise an
83+
appropriate domain-specific exception.
84+
sse_error_handler: Handler for SSE error events. Called with the
85+
raw SSE data string when an ``event: error`` SSE event is received.
86+
Should raise an appropriate domain-specific exception.
87+
**kwargs: Additional keyword arguments forwarded to ``aconnect_sse``.
88+
"""
7789
with handle_http_exceptions(status_error_handler):
7890
async with aconnect_sse(
7991
httpx_client, method, url, **kwargs
@@ -97,4 +109,6 @@ async def send_http_stream_request(
97109
async for sse in event_source.aiter_sse():
98110
if not sse.data:
99111
continue
112+
if sse.event == 'error':
113+
sse_error_handler(sse.data)
100114
yield sse.data

src/a2a/client/transports/jsonrpc.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from collections.abc import AsyncGenerator
4-
from typing import Any
4+
from typing import Any, NoReturn
55
from uuid import uuid4
66

77
import httpx
@@ -349,6 +349,7 @@ async def _send_stream_request(
349349
'POST',
350350
self.url,
351351
None,
352+
self._handle_sse_error,
352353
json=rpc_request_payload,
353354
**http_kwargs,
354355
):
@@ -359,3 +360,10 @@ async def _send_stream_request(
359360
json_rpc_response.result, StreamResponse()
360361
)
361362
yield response
363+
364+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
365+
"""Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error."""
366+
json_rpc_response = JSONRPC20Response.from_json(sse_data)
367+
if json_rpc_response.error:
368+
raise self._create_jsonrpc_error(json_rpc_response.error)
369+
raise A2AClientError(f'SSE stream error: {sse_data}')

src/a2a/client/transports/rest.py

Lines changed: 53 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,47 @@
4141
logger = logging.getLogger(__name__)
4242

4343

44+
def _parse_rest_error(
45+
error_payload: dict[str, Any],
46+
fallback_message: str,
47+
) -> Exception | None:
48+
"""Parses a REST error payload and returns the appropriate A2AError.
49+
50+
Args:
51+
error_payload: The parsed JSON error payload.
52+
fallback_message: Message to use if the payload has no ``message``.
53+
54+
Returns:
55+
The mapped A2AError if a known reason was found, otherwise ``None``.
56+
"""
57+
error_data = error_payload.get('error', {})
58+
message = error_data.get('message', fallback_message)
59+
details = error_data.get('details', [])
60+
if not isinstance(details, list):
61+
return None
62+
63+
# The `details` array can contain multiple different error objects.
64+
# We extract the first `ErrorInfo` object because it contains the
65+
# specific `reason` code needed to map this back to a Python A2AError.
66+
for d in details:
67+
if (
68+
isinstance(d, dict)
69+
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
70+
):
71+
reason = d.get('reason')
72+
metadata = d.get('metadata') or {}
73+
if isinstance(reason, str):
74+
exception_cls = A2A_REASON_TO_ERROR.get(reason)
75+
if exception_cls:
76+
exc = exception_cls(message)
77+
if metadata:
78+
exc.data = metadata
79+
return exc
80+
break
81+
82+
return None
83+
84+
4485
@trace_class(kind=SpanKind.CLIENT)
4586
class RestTransport(ClientTransport):
4687
"""A REST transport for the A2A client."""
@@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
294335
"""Handles HTTP status errors and raises the appropriate A2AError."""
295336
try:
296337
error_payload = e.response.json()
297-
error_data = error_payload.get('error', {})
298-
299-
message = error_data.get('message', str(e))
300-
details = error_data.get('details', [])
301-
if not isinstance(details, list):
302-
details = []
303-
304-
# The `details` array can contain multiple different error objects.
305-
# We extract the first `ErrorInfo` object because it contains the
306-
# specific `reason` code needed to map this back to a Python A2AError.
307-
error_info = {}
308-
for d in details:
309-
if (
310-
isinstance(d, dict)
311-
and d.get('@type')
312-
== 'type.googleapis.com/google.rpc.ErrorInfo'
313-
):
314-
error_info = d
315-
break
316-
reason = error_info.get('reason')
317-
metadata = error_info.get('metadata') or {}
318-
319-
if isinstance(reason, str):
320-
exception_cls = A2A_REASON_TO_ERROR.get(reason)
321-
if exception_cls:
322-
exc = exception_cls(message)
323-
if metadata:
324-
exc.data = metadata
325-
raise exc from e
338+
mapped = _parse_rest_error(error_payload, str(e))
339+
if mapped:
340+
raise mapped from e
326341
except (json.JSONDecodeError, ValueError):
327342
pass
328343

329-
# Fallback mappings for status codes if 'type' is missing or unknown
330344
status_code = e.response.status_code
331345
if status_code == httpx.codes.NOT_FOUND:
332346
raise MethodNotFoundError(
@@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
335349

336350
raise A2AClientError(f'HTTP Error {status_code}: {e}') from e
337351

352+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
353+
"""Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError."""
354+
error_payload = json.loads(sse_data)
355+
mapped = _parse_rest_error(error_payload, sse_data)
356+
if mapped:
357+
raise mapped
358+
raise A2AClientError(sse_data)
359+
338360
async def _send_stream_request(
339361
self,
340362
method: str,
@@ -352,6 +374,7 @@ async def _send_stream_request(
352374
method,
353375
f'{self.url}{path}',
354376
self._handle_http_error,
377+
self._handle_sse_error,
355378
json=json,
356379
**http_kwargs,
357380
):

src/a2a/compat/v0_3/jsonrpc_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,10 @@ async def event_generator(
306306
)
307307
)
308308
yield {
309+
'event': 'error',
309310
'data': err_resp.model_dump_json(
310311
by_alias=True, exclude_none=True
311-
)
312+
),
312313
}
313314

314315
return EventSourceResponse(event_generator(stream_gen))

src/a2a/compat/v0_3/jsonrpc_transport.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,13 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
415415
"""Handles HTTP errors for standard requests."""
416416
raise A2AClientError(f'HTTP Error: {e.response.status_code}') from e
417417

418+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
419+
"""Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error."""
420+
data = json.loads(sse_data)
421+
if 'error' in data:
422+
raise self._create_jsonrpc_error(data['error'])
423+
raise A2AClientError(f'SSE stream error: {sse_data}')
424+
418425
async def _send_stream_request(
419426
self,
420427
json_data: dict[str, Any],
@@ -430,6 +437,7 @@ async def _send_stream_request(
430437
'POST',
431438
self.url,
432439
self._handle_http_error,
440+
self._handle_sse_error,
433441
json=json_data,
434442
**http_kwargs,
435443
):

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88

99
if TYPE_CHECKING:
10+
from sse_starlette.event import ServerSentEvent
1011
from sse_starlette.sse import EventSourceResponse
1112
from starlette.requests import Request
1213
from starlette.responses import JSONResponse, Response
@@ -17,6 +18,7 @@
1718
_package_starlette_installed = True
1819
else:
1920
try:
21+
from sse_starlette.event import ServerSentEvent
2022
from sse_starlette.sse import EventSourceResponse
2123
from starlette.requests import Request
2224
from starlette.responses import JSONResponse, Response
@@ -27,6 +29,7 @@
2729
Request = Any
2830
JSONResponse = Any
2931
Response = Any
32+
ServerSentEvent = Any
3033

3134
_package_starlette_installed = False
3235

@@ -37,6 +40,7 @@
3740
from a2a.server.context import ServerCallContext
3841
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
3942
from a2a.utils.error_handlers import (
43+
build_rest_error_payload,
4044
rest_error_handler,
4145
rest_stream_error_handler,
4246
)
@@ -101,9 +105,16 @@ async def _handle_streaming_request(
101105

102106
async def event_generator(
103107
stream: AsyncIterable[Any],
104-
) -> AsyncIterator[str]:
105-
async for item in stream:
106-
yield json.dumps(item)
108+
) -> AsyncIterator[str | ServerSentEvent]:
109+
try:
110+
async for item in stream:
111+
yield json.dumps(item)
112+
except Exception as e:
113+
logger.exception('Error during v0.3 REST SSE stream')
114+
yield ServerSentEvent(
115+
data=json.dumps(build_rest_error_payload(e)),
116+
event='error',
117+
)
107118

108119
return EventSourceResponse(
109120
event_generator(method(request, call_context))

src/a2a/compat/v0_3/rest_transport.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@
4444
TaskPushNotificationConfig,
4545
)
4646
from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER
47-
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, MethodNotFoundError
47+
from a2a.utils.errors import (
48+
A2A_REASON_TO_ERROR,
49+
JSON_RPC_ERROR_CODE_MAP,
50+
MethodNotFoundError,
51+
)
4852
from a2a.utils.telemetry import SpanKind, trace_class
4953

5054

@@ -369,6 +373,30 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
369373

370374
raise A2AClientError(f'HTTP Error {status_code}: {e}') from e
371375

376+
def _handle_sse_error(self, sse_data: str) -> NoReturn:
377+
"""Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError."""
378+
error_payload = json.loads(sse_data)
379+
error_data = error_payload.get('error', {})
380+
381+
message = error_data.get('message', sse_data)
382+
details = error_data.get('details', [])
383+
if not isinstance(details, list):
384+
details = []
385+
386+
for d in details:
387+
if (
388+
isinstance(d, dict)
389+
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
390+
):
391+
reason = d.get('reason')
392+
if isinstance(reason, str):
393+
exception_cls = A2A_REASON_TO_ERROR.get(reason)
394+
if exception_cls:
395+
raise exception_cls(message)
396+
break
397+
398+
raise A2AClientError(message)
399+
372400
async def _send_stream_request(
373401
self,
374402
method: str,
@@ -386,6 +414,7 @@ async def _send_stream_request(
386414
method,
387415
f'{self.url}{path}',
388416
self._handle_http_error,
417+
self._handle_sse_error,
389418
json=json,
390419
**http_kwargs,
391420
):

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313

1414
if TYPE_CHECKING:
15+
from sse_starlette.event import ServerSentEvent
1516
from sse_starlette.sse import EventSourceResponse
1617
from starlette.requests import Request
1718
from starlette.responses import JSONResponse, Response
@@ -20,6 +21,7 @@
2021

2122
else:
2223
try:
24+
from sse_starlette.event import ServerSentEvent
2325
from sse_starlette.sse import EventSourceResponse
2426
from starlette.requests import Request
2527
from starlette.responses import JSONResponse, Response
@@ -30,6 +32,7 @@
3032
Request = Any
3133
JSONResponse = Any
3234
Response = Any
35+
ServerSentEvent = Any
3336

3437
_package_starlette_installed = False
3538

@@ -42,6 +45,7 @@
4245
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
4346
from a2a.types.a2a_pb2 import AgentCard
4447
from a2a.utils.error_handlers import (
48+
build_rest_error_payload,
4549
rest_error_handler,
4650
rest_stream_error_handler,
4751
)
@@ -163,10 +167,17 @@ async def _handle_streaming_request(
163167
except StopAsyncIteration:
164168
return EventSourceResponse(iter([]))
165169

166-
async def event_generator() -> AsyncIterator[str]:
170+
async def event_generator() -> AsyncIterator[str | ServerSentEvent]:
167171
yield json.dumps(first_item)
168-
async for item in stream:
169-
yield json.dumps(item)
172+
try:
173+
async for item in stream:
174+
yield json.dumps(item)
175+
except Exception as e:
176+
logger.exception('Error during REST SSE stream')
177+
yield ServerSentEvent(
178+
data=json.dumps(build_rest_error_payload(e)),
179+
event='error',
180+
)
170181

171182
return EventSourceResponse(event_generator())
172183

src/a2a/server/routes/jsonrpc_dispatcher.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,30 @@ def _create_response(
559559
async def event_generator(
560560
stream: AsyncGenerator[dict[str, Any]],
561561
) -> AsyncGenerator[dict[str, str]]:
562-
async for item in stream:
563-
yield {'data': json.dumps(item)}
562+
try:
563+
async for item in stream:
564+
event: dict[str, str] = {
565+
'data': json.dumps(item),
566+
}
567+
if 'error' in item:
568+
event['event'] = 'error'
569+
yield event
570+
except Exception as e:
571+
logger.exception(
572+
'Unhandled error during JSON-RPC SSE stream'
573+
)
574+
rpc_error: A2AError | JSONRPCError = (
575+
e
576+
if isinstance(e, A2AError | JSONRPCError)
577+
else InternalError(message=str(e))
578+
)
579+
error_response = build_error_response(
580+
context.state.get('request_id'), rpc_error
581+
)
582+
yield {
583+
'event': 'error',
584+
'data': json.dumps(error_response),
585+
}
564586

565587
return EventSourceResponse(
566588
event_generator(handler_result), headers=headers

0 commit comments

Comments
 (0)