|
1 | 1 | import asyncio |
2 | | - |
3 | 2 | from collections.abc import AsyncGenerator |
4 | 3 | from typing import Any, NamedTuple |
5 | 4 | from unittest.mock import ANY, AsyncMock, patch |
|
8 | 7 | import httpx |
9 | 8 | import pytest |
10 | 9 | import pytest_asyncio |
11 | | - |
12 | 10 | from cryptography.hazmat.primitives.asymmetric import ec |
13 | 11 | from google.protobuf.json_format import MessageToDict |
14 | 12 | from google.protobuf.timestamp_pb2 import Timestamp |
15 | 13 |
|
16 | 14 | from a2a.client import Client, ClientConfig |
17 | 15 | from a2a.client.base_client import BaseClient |
18 | 16 | from a2a.client.card_resolver import A2ACardResolver |
19 | | -from a2a.client.client_factory import ClientFactory |
20 | 17 | from a2a.client.client import ClientCallContext |
| 18 | +from a2a.client.client_factory import ClientFactory |
21 | 19 | from a2a.client.service_parameters import ( |
22 | 20 | ServiceParametersFactory, |
23 | 21 | with_a2a_extensions, |
24 | 22 | ) |
25 | 23 | from a2a.client.transports import JsonRpcTransport, RestTransport |
| 24 | + |
| 25 | +# Compat v0.3 imports for dedicated tests |
| 26 | +from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc |
| 27 | +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler |
26 | 28 | from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication |
27 | 29 | from a2a.server.request_handlers import GrpcHandler, RequestHandler |
28 | 30 | from a2a.types import a2a_pb2_grpc |
|
50 | 52 | TaskStatus, |
51 | 53 | TaskStatusUpdateEvent, |
52 | 54 | ) |
53 | | -from a2a.utils.constants import ( |
54 | | - TransportProtocol, |
55 | | -) |
| 55 | +from a2a.utils.constants import TransportProtocol |
56 | 56 | from a2a.utils.errors import ( |
57 | | - ExtendedAgentCardNotConfiguredError, |
58 | 57 | ContentTypeNotSupportedError, |
| 58 | + ExtendedAgentCardNotConfiguredError, |
59 | 59 | ExtensionSupportRequiredError, |
60 | 60 | InternalError, |
61 | 61 | InvalidAgentResponseError, |
|
73 | 73 | create_signature_verifier, |
74 | 74 | ) |
75 | 75 |
|
76 | | -# Compat v0.3 imports for dedicated tests |
77 | | -from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc |
78 | | -from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler |
79 | | - |
80 | | - |
81 | 76 | # --- Test Constants --- |
82 | 77 |
|
83 | 78 | TASK_FROM_STREAM = Task( |
@@ -360,9 +355,9 @@ def grpc_03_setup( |
360 | 355 | ) -> TransportSetup: |
361 | 356 | """Sets up the CompatGrpcTransport and in-process 0.3 server.""" |
362 | 357 | server_address, handler = grpc_03_server_and_handler |
363 | | - from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport |
364 | 358 | from a2a.client.base_client import BaseClient |
365 | 359 | from a2a.client.client import ClientConfig |
| 360 | + from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport |
366 | 361 |
|
367 | 362 | channel = grpc.aio.insecure_channel(server_address) |
368 | 363 | transport = CompatGrpcTransport(channel=channel, agent_card=agent_card) |
@@ -909,6 +904,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None: |
909 | 904 | await client.close() |
910 | 905 |
|
911 | 906 |
|
| 907 | +@pytest.mark.asyncio |
| 908 | +@pytest.mark.parametrize( |
| 909 | + 'error_cls', |
| 910 | + [ |
| 911 | + TaskNotFoundError, |
| 912 | + TaskNotCancelableError, |
| 913 | + PushNotificationNotSupportedError, |
| 914 | + UnsupportedOperationError, |
| 915 | + ContentTypeNotSupportedError, |
| 916 | + InvalidAgentResponseError, |
| 917 | + ExtendedAgentCardNotConfiguredError, |
| 918 | + ExtensionSupportRequiredError, |
| 919 | + VersionNotSupportedError, |
| 920 | + ], |
| 921 | +) |
| 922 | +@pytest.mark.parametrize( |
| 923 | + 'handler_attr, client_method, request_params', |
| 924 | + [ |
| 925 | + pytest.param( |
| 926 | + 'on_message_send_stream', |
| 927 | + 'send_message', |
| 928 | + SendMessageRequest( |
| 929 | + message=Message( |
| 930 | + role=Role.ROLE_USER, |
| 931 | + message_id='msg-integration-test', |
| 932 | + parts=[Part(text='Hello, integration test!')], |
| 933 | + ) |
| 934 | + ), |
| 935 | + id='stream', |
| 936 | + ), |
| 937 | + pytest.param( |
| 938 | + 'on_subscribe_to_task', |
| 939 | + 'subscribe', |
| 940 | + SubscribeToTaskRequest(id='some-id'), |
| 941 | + id='subscribe', |
| 942 | + ), |
| 943 | + ], |
| 944 | +) |
| 945 | +async def test_client_handles_a2a_errors_streaming( |
| 946 | + transport_setups, error_cls, handler_attr, client_method, request_params |
| 947 | +) -> None: |
| 948 | + """Integration test to verify error propagation from streaming handlers to client. |
| 949 | +
|
| 950 | + The handler raises an A2AError before yielding any events. All transports |
| 951 | + must propagate this as the exact error_cls, not wrapped in an ExceptionGroup |
| 952 | + or converted to a generic client error. |
| 953 | + """ |
| 954 | + client = transport_setups.client |
| 955 | + handler = transport_setups.handler |
| 956 | + |
| 957 | + async def mock_generator(*args, **kwargs): |
| 958 | + raise error_cls('Test error message') |
| 959 | + yield |
| 960 | + |
| 961 | + getattr(handler, handler_attr).side_effect = mock_generator |
| 962 | + |
| 963 | + with pytest.raises(error_cls) as exc_info: |
| 964 | + async for _ in getattr(client, client_method)(request=request_params): |
| 965 | + pass |
| 966 | + |
| 967 | + assert 'Test error message' in str(exc_info.value) |
| 968 | + |
| 969 | + getattr(handler, handler_attr).side_effect = None |
| 970 | + |
| 971 | + await client.close() |
| 972 | + |
| 973 | + |
912 | 974 | @pytest.mark.asyncio |
913 | 975 | @pytest.mark.parametrize( |
914 | 976 | 'request_kwargs, expected_error_code', |
|
0 commit comments