Skip to content

Commit a4bc9e6

Browse files
committed
fix: fix REST error handling
Do one iteration to catch exceptions occurred beforehand to return an error instead of sending headers for SSE. Error handling during the execution is not defined in the spec: a2aproject/A2A#1262.
1 parent c18fb60 commit a4bc9e6

2 files changed

Lines changed: 92 additions & 19 deletions

File tree

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,26 @@ async def _handle_streaming_request(
152152

153153
call_context = self._build_call_context(request)
154154

155-
async def event_generator(
156-
stream: AsyncIterable[Any],
157-
) -> AsyncIterator[str]:
155+
# Eagerly fetch the first item from the stream so that errors raised
156+
# before any event is yielded (e.g. validation, parsing, or handler
157+
# failures) propagate here and are caught by
158+
# @rest_stream_error_handler, which returns a JSONResponse with
159+
# the correct HTTP status code instead of starting an SSE stream.
160+
# Without this, the error would be raised after SSE headers are
161+
# already sent, and the client would see a broken stream instead
162+
# of a proper error response.
163+
stream = aiter(method(request, call_context))
164+
try:
165+
first_item = await anext(stream)
166+
except StopAsyncIteration:
167+
return EventSourceResponse(iter([]))
168+
169+
async def event_generator() -> AsyncIterator[str]:
170+
yield json.dumps(first_item)
158171
async for item in stream:
159172
yield json.dumps(item)
160173

161-
return EventSourceResponse(
162-
event_generator(method(request, call_context))
163-
)
174+
return EventSourceResponse(event_generator())
164175

165176
async def handle_get_agent_card(
166177
self, request: Request, call_context: ServerCallContext | None = None

tests/integration/test_client_server_integration.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
32
from collections.abc import AsyncGenerator
43
from typing import Any, NamedTuple
54
from unittest.mock import ANY, AsyncMock, patch
@@ -8,21 +7,24 @@
87
import httpx
98
import pytest
109
import pytest_asyncio
11-
1210
from cryptography.hazmat.primitives.asymmetric import ec
1311
from google.protobuf.json_format import MessageToDict
1412
from google.protobuf.timestamp_pb2 import Timestamp
1513

1614
from a2a.client import Client, ClientConfig
1715
from a2a.client.base_client import BaseClient
1816
from a2a.client.card_resolver import A2ACardResolver
19-
from a2a.client.client_factory import ClientFactory
2017
from a2a.client.client import ClientCallContext
18+
from a2a.client.client_factory import ClientFactory
2119
from a2a.client.service_parameters import (
2220
ServiceParametersFactory,
2321
with_a2a_extensions,
2422
)
2523
from a2a.client.transports import JsonRpcTransport, RestTransport
24+
25+
# Compat v0.3 imports for dedicated tests
26+
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
27+
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
2628
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
2729
from a2a.server.request_handlers import GrpcHandler, RequestHandler
2830
from a2a.types import a2a_pb2_grpc
@@ -50,12 +52,10 @@
5052
TaskStatus,
5153
TaskStatusUpdateEvent,
5254
)
53-
from a2a.utils.constants import (
54-
TransportProtocol,
55-
)
55+
from a2a.utils.constants import TransportProtocol
5656
from a2a.utils.errors import (
57-
ExtendedAgentCardNotConfiguredError,
5857
ContentTypeNotSupportedError,
58+
ExtendedAgentCardNotConfiguredError,
5959
ExtensionSupportRequiredError,
6060
InternalError,
6161
InvalidAgentResponseError,
@@ -73,11 +73,6 @@
7373
create_signature_verifier,
7474
)
7575

76-
# Compat v0.3 imports for dedicated tests
77-
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
78-
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
79-
80-
8176
# --- Test Constants ---
8277

8378
TASK_FROM_STREAM = Task(
@@ -360,9 +355,9 @@ def grpc_03_setup(
360355
) -> TransportSetup:
361356
"""Sets up the CompatGrpcTransport and in-process 0.3 server."""
362357
server_address, handler = grpc_03_server_and_handler
363-
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
364358
from a2a.client.base_client import BaseClient
365359
from a2a.client.client import ClientConfig
360+
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
366361

367362
channel = grpc.aio.insecure_channel(server_address)
368363
transport = CompatGrpcTransport(channel=channel, agent_card=agent_card)
@@ -909,6 +904,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None:
909904
await client.close()
910905

911906

907+
@pytest.mark.asyncio
908+
@pytest.mark.parametrize(
909+
'error_cls',
910+
[
911+
TaskNotFoundError,
912+
TaskNotCancelableError,
913+
PushNotificationNotSupportedError,
914+
UnsupportedOperationError,
915+
ContentTypeNotSupportedError,
916+
InvalidAgentResponseError,
917+
ExtendedAgentCardNotConfiguredError,
918+
ExtensionSupportRequiredError,
919+
VersionNotSupportedError,
920+
],
921+
)
922+
@pytest.mark.parametrize(
923+
'handler_attr, client_method, request_params',
924+
[
925+
pytest.param(
926+
'on_message_send_stream',
927+
'send_message',
928+
SendMessageRequest(
929+
message=Message(
930+
role=Role.ROLE_USER,
931+
message_id='msg-integration-test',
932+
parts=[Part(text='Hello, integration test!')],
933+
)
934+
),
935+
id='stream',
936+
),
937+
pytest.param(
938+
'on_subscribe_to_task',
939+
'subscribe',
940+
SubscribeToTaskRequest(id='some-id'),
941+
id='subscribe',
942+
),
943+
],
944+
)
945+
async def test_client_handles_a2a_errors_streaming(
946+
transport_setups, error_cls, handler_attr, client_method, request_params
947+
) -> None:
948+
"""Integration test to verify error propagation from streaming handlers to client.
949+
950+
The handler raises an A2AError before yielding any events. All transports
951+
must propagate this as the exact error_cls, not wrapped in an ExceptionGroup
952+
or converted to a generic client error.
953+
"""
954+
client = transport_setups.client
955+
handler = transport_setups.handler
956+
957+
async def mock_generator(*args, **kwargs):
958+
raise error_cls('Test error message')
959+
yield
960+
961+
getattr(handler, handler_attr).side_effect = mock_generator
962+
963+
with pytest.raises(error_cls) as exc_info:
964+
async for _ in getattr(client, client_method)(request=request_params):
965+
pass
966+
967+
assert 'Test error message' in str(exc_info.value)
968+
969+
getattr(handler, handler_attr).side_effect = None
970+
971+
await client.close()
972+
973+
912974
@pytest.mark.asyncio
913975
@pytest.mark.parametrize(
914976
'request_kwargs, expected_error_code',

0 commit comments

Comments
 (0)