Skip to content

Commit 0b60cc6

Browse files
committed
Add tests
1 parent 763a013 commit 0b60cc6

5 files changed

Lines changed: 93 additions & 37 deletions

File tree

src/a2a/client/transports/grpc.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from functools import wraps
55
from typing import Any, NoReturn
66

7-
import a2a.utils.errors
8-
97
from a2a.client.errors import A2AClientError
8+
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
109

1110

1211
try:
@@ -48,18 +47,18 @@
4847

4948
logger = logging.getLogger(__name__)
5049

50+
_A2A_ERROR_NAME_TO_CLS = {
51+
error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP
52+
}
53+
5154

5255
def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
5356
details = e.details()
5457
if isinstance(details, str) and ': ' in details:
5558
error_type_name, error_message = details.split(': ', 1)
5659
# TODO(#723): Resolving imports by name is a temporary hack until proper error handling structure is added in #723.
57-
exception_cls = getattr(a2a.utils.errors, error_type_name, None)
58-
if (
59-
exception_cls
60-
and isinstance(exception_cls, type)
61-
and issubclass(exception_cls, a2a.utils.errors.A2AError)
62-
):
60+
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name)
61+
if exception_cls:
6362
raise exception_cls(error_message) from e
6463
raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e
6564

src/a2a/client/transports/rest.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from google.protobuf.message import Message
1111
from httpx_sse import SSEError, aconnect_sse
1212

13-
import a2a.utils.errors
14-
1513
from a2a.client.errors import A2AClientError
1614
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1715
from a2a.client.transports.base import ClientTransport
@@ -34,12 +32,16 @@
3432
Task,
3533
TaskPushNotificationConfig,
3634
)
37-
from a2a.utils.errors import MethodNotFoundError
35+
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, MethodNotFoundError
3836
from a2a.utils.telemetry import SpanKind, trace_class
3937

4038

4139
logger = logging.getLogger(__name__)
4240

41+
_A2A_ERROR_NAME_TO_CLS = {
42+
error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP
43+
}
44+
4345

4446
@trace_class(kind=SpanKind.CLIENT)
4547
class RestTransport(ClientTransport):
@@ -103,12 +105,9 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
103105
message = error_data.get('message', str(e))
104106

105107
if isinstance(error_type, str):
106-
exception_cls = getattr(a2a.utils.errors, error_type, None)
107-
if (
108-
exception_cls
109-
and isinstance(exception_cls, type)
110-
and issubclass(exception_cls, a2a.utils.errors.A2AError)
111-
):
108+
# TODO(#723): Resolving imports by name is a temporary hack until proper error handling structure is added in #723.
109+
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type)
110+
if exception_cls:
112111
raise exception_cls(message) from e
113112
except (json.JSONDecodeError, ValueError):
114113
pass

tests/client/transports/test_grpc_client.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55

66
from a2a.client.transports.grpc import GrpcTransport
77
from a2a.extensions.common import HTTP_EXTENSION_HEADER
8-
from a2a.types import a2a_pb2, a2a_pb2_grpc
8+
from a2a.types import a2a_pb2
99
from a2a.types.a2a_pb2 import (
1010
AgentCapabilities,
11-
AgentInterface,
1211
AgentCard,
12+
AgentInterface,
1313
Artifact,
1414
AuthenticationInfo,
1515
CreateTaskPushNotificationConfigRequest,
1616
DeleteTaskPushNotificationConfigRequest,
1717
GetTaskPushNotificationConfigRequest,
18-
ListTaskPushNotificationConfigsRequest,
19-
ListTaskPushNotificationConfigsResponse,
2018
GetTaskRequest,
19+
ListTaskPushNotificationConfigsRequest,
2120
Message,
2221
Part,
2322
PushNotificationConfig,
@@ -30,7 +29,8 @@
3029
TaskStatus,
3130
TaskStatusUpdateEvent,
3231
)
33-
from a2a.utils import get_text_parts, proto_utils
32+
from a2a.utils import get_text_parts
33+
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
3434

3535

3636
@pytest.fixture
@@ -226,6 +226,29 @@ async def test_send_message_task_response(
226226
assert response.task.id == sample_task.id
227227

228228

229+
@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
230+
@pytest.mark.asyncio
231+
async def test_grpc_mapped_errors(
232+
grpc_transport: GrpcTransport,
233+
mock_grpc_stub: AsyncMock,
234+
sample_message_send_params: SendMessageRequest,
235+
error_cls,
236+
) -> None:
237+
"""Test handling of mapped gRPC error responses."""
238+
error_details = f'{error_cls.__name__}: Mapped Error'
239+
240+
# We must trigger it from a standard transport method call, for example `send_message`.
241+
mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError(
242+
code=grpc.StatusCode.INTERNAL,
243+
initial_metadata=grpc.aio.Metadata(),
244+
trailing_metadata=grpc.aio.Metadata(),
245+
details=error_details,
246+
)
247+
248+
with pytest.raises(error_cls):
249+
await grpc_transport.send_message(sample_message_send_params)
250+
251+
229252
@pytest.mark.asyncio
230253
async def test_send_message_message_response(
231254
grpc_transport: GrpcTransport,

tests/client/transports/test_jsonrpc_client.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
"""Tests for the JSON-RPC client transport."""
22

33
import json
4-
from google.protobuf import json_format
5-
from unittest import mock
4+
65
from unittest.mock import AsyncMock, MagicMock, patch
76
from uuid import uuid4
87

98
import httpx
109
import pytest
11-
import respx
10+
11+
from google.protobuf import json_format
1212
from httpx_sse import EventSource, SSEError
1313

1414
from a2a.client.errors import A2AClientError
15-
from a2a.utils.errors import InvalidRequestError
1615
from a2a.client.transports.jsonrpc import JsonRpcTransport
1716
from a2a.types.a2a_pb2 import (
1817
AgentCapabilities,
19-
AgentInterface,
2018
AgentCard,
19+
AgentInterface,
2120
CancelTaskRequest,
22-
CreateTaskPushNotificationConfigRequest,
2321
DeleteTaskPushNotificationConfigRequest,
2422
GetTaskPushNotificationConfigRequest,
25-
ListTaskPushNotificationConfigsRequest,
26-
ListTaskPushNotificationConfigsResponse,
2723
GetTaskRequest,
24+
ListTaskPushNotificationConfigsRequest,
2825
Message,
2926
Part,
3027
SendMessageConfiguration,
@@ -33,8 +30,8 @@
3330
Task,
3431
TaskPushNotificationConfig,
3532
TaskState,
36-
TaskStatus,
3733
)
34+
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
3835

3936

4037
@pytest.fixture
@@ -174,16 +171,19 @@ async def test_send_message_success(self, transport, mock_httpx_client):
174171
payload = call_args[1]['json']
175172
assert payload['method'] == 'SendMessage'
176173

174+
@pytest.mark.parametrize(
175+
'error_cls, error_code', JSON_RPC_ERROR_CODE_MAP.items()
176+
)
177177
@pytest.mark.asyncio
178178
async def test_send_message_jsonrpc_error(
179-
self, transport, mock_httpx_client
179+
self, transport, mock_httpx_client, error_cls, error_code
180180
):
181-
"""Test handling of JSON-RPC error response."""
181+
"""Test handling of JSON-RPC mapped error response."""
182182
mock_response = MagicMock()
183183
mock_response.json.return_value = {
184184
'jsonrpc': '2.0',
185185
'id': '1',
186-
'error': {'code': -32600, 'message': 'Invalid Request'},
186+
'error': {'code': error_code, 'message': 'Mapped Error'},
187187
'result': None,
188188
}
189189
mock_response.raise_for_status = MagicMock()
@@ -192,7 +192,7 @@ async def test_send_message_jsonrpc_error(
192192
request = create_send_message_request()
193193

194194
# The transport raises the specific A2AError mapped from code
195-
with pytest.raises(InvalidRequestError):
195+
with pytest.raises(error_cls):
196196
await transport.send_message(request)
197197

198198
@pytest.mark.asyncio

tests/client/transports/test_rest_client.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
AgentInterface,
1818
DeleteTaskPushNotificationConfigRequest,
1919
ListTaskPushNotificationConfigsRequest,
20-
ListTaskPushNotificationConfigsResponse,
2120
SendMessageRequest,
22-
TaskPushNotificationConfig,
2321
)
2422
from a2a.utils.constants import TransportProtocol
23+
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
2524

2625

2726
@pytest.fixture
@@ -95,6 +94,42 @@ async def test_send_message_streaming_timeout(
9594
assert 'Client Request timed out' in str(exc_info.value)
9695

9796

97+
@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
98+
@pytest.mark.asyncio
99+
async def test_rest_mapped_errors(
100+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, error_cls
101+
):
102+
"""Test handling of mapped REST HTTP error responses."""
103+
client = RestTransport(
104+
httpx_client=mock_httpx_client,
105+
agent_card=mock_agent_card,
106+
url='http://agent.example.com/api',
107+
)
108+
params = SendMessageRequest(
109+
message=create_text_message_object(content='Hello')
110+
)
111+
112+
mock_build_request = MagicMock(
113+
return_value=AsyncMock(spec=httpx.Request)
114+
)
115+
mock_httpx_client.build_request = mock_build_request
116+
117+
mock_response = AsyncMock(spec=httpx.Response)
118+
mock_response.status_code = 500
119+
mock_response.json.return_value = {'type': error_cls.__name__, 'message': 'Mapped Error'}
120+
121+
error = httpx.HTTPStatusError(
122+
'Server Error',
123+
request=httpx.Request('POST', 'http://test.url'),
124+
response=mock_response,
125+
)
126+
127+
mock_httpx_client.send.side_effect = error
128+
129+
with pytest.raises(error_cls):
130+
await client.send_message(request=params)
131+
132+
98133
class TestRestTransportExtensions:
99134
@pytest.mark.asyncio
100135
async def test_send_message_with_default_extensions(

0 commit comments

Comments
 (0)