Skip to content

Commit 245eca3

Browse files
knapgishymko
andauthored
feat: implement rich gRPC error details per A2A v1.0 spec (#790)
# Description This PR implements standard gRPC rich error handling using `google.rpc.Status` and `google.rpc.ErrorInfo`, bringing the SDK's gRPC transport fully in line with the A2A v1.0 specification. Previously, the gRPC server appended the exception name to the string message (e.g., "TaskNotFoundError: task not found"), and the client relied on string splitting to parse the error back into a domain exception. This approach was brittle and not interoperable with standard gRPC ecosystems (proxies, gateways, etc.). This PR replaces the legacy string-parsing heuristic entirely with strongly-typed binary metadata (`grpc-status-details-bin`). - [X] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [X] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [X] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [X] Appropriate docs were updated (if necessary) Fixes #723 🦕 --------- Co-authored-by: Ivan Shymko <ishymko@google.com>
1 parent a55c97e commit 245eca3

9 files changed

Lines changed: 181 additions & 51 deletions

File tree

.jscpd.json

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
{
2-
"ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/src/a2a/grpc/**", "**/.nox/**", "**/.venv/**"],
2+
"ignore": [
3+
"**/.github/**",
4+
"**/.git/**",
5+
"**/tests/**",
6+
"**/src/a2a/grpc/**",
7+
"**/src/a2a/compat/**",
8+
"**/.nox/**",
9+
"**/.venv/**"
10+
],
311
"threshold": 3,
412
"reporters": ["html", "markdown"]
513
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ classifiers = [
3333
[project.optional-dependencies]
3434
http-server = ["fastapi>=0.115.2", "sse-starlette", "starlette"]
3535
encryption = ["cryptography>=43.0.0"]
36-
grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"]
36+
grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio-status>=1.60", "grpcio_reflection>=1.7.0"]
3737
telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"]
3838
postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
3939
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]

src/a2a/client/transports/grpc.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,29 @@
22

33
from collections.abc import AsyncGenerator, Callable
44
from functools import wraps
5-
from typing import Any, NoReturn
5+
from typing import Any, NoReturn, cast
66

7+
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
78
from a2a.client.middleware import ClientCallContext
8-
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
99

1010

1111
try:
1212
import grpc # type: ignore[reportMissingModuleSource]
13+
14+
from grpc_status import rpc_status
1315
except ImportError as e:
1416
raise ImportError(
15-
'A2AGrpcClient requires grpcio and grpcio-tools to be installed. '
17+
'A2AGrpcClient requires grpcio, grpcio-tools, and grpcio-status to be installed. '
1618
'Install with: '
1719
"'pip install a2a-sdk[grpc]'"
1820
) from e
1921

2022

23+
from google.rpc import ( # type: ignore[reportMissingModuleSource]
24+
error_details_pb2,
25+
)
26+
2127
from a2a.client.client import ClientConfig
22-
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
2328
from a2a.client.middleware import ClientCallInterceptor
2429
from a2a.client.optionals import Channel
2530
from a2a.client.transports.base import ClientTransport
@@ -43,27 +48,32 @@
4348
TaskPushNotificationConfig,
4449
)
4550
from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER
51+
from a2a.utils.errors import A2A_REASON_TO_ERROR
4652
from a2a.utils.telemetry import SpanKind, trace_class
4753

4854

4955
logger = logging.getLogger(__name__)
5056

51-
_A2A_ERROR_NAME_TO_CLS = {
52-
error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP
53-
}
54-
5557

5658
def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
59+
5760
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
5861
raise A2AClientTimeoutError('Client Request timed out') from e
5962

60-
details = e.details()
61-
if isinstance(details, str) and ': ' in details:
62-
error_type_name, error_message = details.split(': ', 1)
63-
# TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723.
64-
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name)
65-
if exception_cls:
66-
raise exception_cls(error_message) from e
63+
# Use grpc_status to cleanly extract the rich Status from the call
64+
status = rpc_status.from_call(cast('grpc.Call', e))
65+
66+
if status is not None:
67+
for detail in status.details:
68+
if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
69+
error_info = error_details_pb2.ErrorInfo()
70+
detail.Unpack(error_info)
71+
72+
if error_info.domain == 'a2a-protocol.org':
73+
exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason)
74+
if exception_cls:
75+
raise exception_cls(status.message) from e
76+
6777
raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e
6878

6979

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@
33
import logging
44

55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterable, Awaitable
6+
from collections.abc import AsyncIterable, Awaitable, Callable
77

88

99
try:
1010
import grpc # type: ignore[reportMissingModuleSource]
1111
import grpc.aio # type: ignore[reportMissingModuleSource]
12+
13+
from grpc_status import rpc_status
1214
except ImportError as e:
1315
raise ImportError(
14-
'GrpcHandler requires grpcio and grpcio-tools to be installed. '
16+
'GrpcHandler requires grpcio, grpcio-tools, and grpcio-status to be installed. '
1517
'Install with: '
1618
"'pip install a2a-sdk[grpc]'"
1719
) from e
1820

19-
from collections.abc import Callable
20-
21-
from google.protobuf import empty_pb2, message
21+
from google.protobuf import any_pb2, empty_pb2, message
22+
from google.rpc import error_details_pb2, status_pb2
2223

2324
import a2a.types.a2a_pb2_grpc as a2a_grpc
2425

@@ -33,7 +34,7 @@
3334
from a2a.types import a2a_pb2
3435
from a2a.types.a2a_pb2 import AgentCard
3536
from a2a.utils import proto_utils
36-
from a2a.utils.errors import A2AError, TaskNotFoundError
37+
from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError
3738
from a2a.utils.helpers import maybe_await, validate, validate_async_generator
3839

3940

@@ -419,11 +420,41 @@ async def abort_context(
419420
) -> None:
420421
"""Sets the grpc errors appropriately in the context."""
421422
code = _ERROR_CODE_MAP.get(type(error))
423+
422424
if code:
423-
await context.abort(
424-
code,
425-
f'{type(error).__name__}: {error.message}',
425+
reason = A2A_ERROR_REASONS.get(type(error), 'UNKNOWN_ERROR')
426+
error_info = error_details_pb2.ErrorInfo(
427+
reason=reason,
428+
domain='a2a-protocol.org',
429+
)
430+
431+
status_code = (
432+
code.value[0] if code else grpc.StatusCode.UNKNOWN.value[0]
426433
)
434+
error_msg = (
435+
error.message if hasattr(error, 'message') else str(error)
436+
)
437+
438+
# Create standard Status and pack the ErrorInfo
439+
status = status_pb2.Status(code=status_code, message=error_msg)
440+
detail = any_pb2.Any()
441+
detail.Pack(error_info)
442+
status.details.append(detail)
443+
444+
# Use grpc_status to safely generate standard trailing metadata
445+
rich_status = rpc_status.to_status(status)
446+
447+
new_metadata: list[tuple[str, str | bytes]] = []
448+
trailing = context.trailing_metadata()
449+
if trailing:
450+
for k, v in trailing:
451+
new_metadata.append((str(k), v))
452+
453+
for k, v in rich_status.trailing_metadata:
454+
new_metadata.append((str(k), v))
455+
456+
context.set_trailing_metadata(tuple(new_metadata))
457+
await context.abort(rich_status.code, rich_status.details)
427458
else:
428459
await context.abort(
429460
grpc.StatusCode.UNKNOWN,

src/a2a/server/request_handlers/response_helpers.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
SendMessageResponse as SendMessageResponseProto,
2828
)
2929
from a2a.utils.errors import (
30+
JSON_RPC_ERROR_CODE_MAP,
3031
A2AError,
3132
AuthenticatedExtendedCardNotConfiguredError,
3233
ContentTypeNotSupportedError,
@@ -56,19 +57,6 @@
5657
InternalError: JSONRPCInternalError,
5758
}
5859

59-
ERROR_CODE_MAP: dict[type[A2AError], int] = {
60-
TaskNotFoundError: -32001,
61-
TaskNotCancelableError: -32002,
62-
PushNotificationNotSupportedError: -32003,
63-
UnsupportedOperationError: -32004,
64-
ContentTypeNotSupportedError: -32005,
65-
InvalidAgentResponseError: -32006,
66-
AuthenticatedExtendedCardNotConfiguredError: -32007,
67-
InvalidParamsError: -32602,
68-
InvalidRequestError: -32600,
69-
MethodNotFoundError: -32601,
70-
}
71-
7260

7361
# Tuple of all A2AError types for isinstance checks
7462
_A2A_ERROR_TYPES: tuple[type, ...] = (A2AError,)
@@ -136,7 +124,7 @@ def build_error_response(
136124
elif isinstance(error, A2AError):
137125
error_type = type(error)
138126
model_class = EXCEPTION_MAP.get(error_type, JSONRPCInternalError)
139-
code = ERROR_CODE_MAP.get(error_type, -32603)
127+
code = JSON_RPC_ERROR_CODE_MAP.get(error_type, -32603)
140128
jsonrpc_error = model_class(
141129
code=code,
142130
message=str(error),

src/a2a/utils/errors.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,26 @@ class MethodNotFoundError(A2AError):
8282
message = 'Method not found'
8383

8484

85+
class ExtensionSupportRequiredError(A2AError):
86+
"""Exception raised when extension support is required but not present."""
87+
88+
message = 'Extension support required'
89+
90+
91+
class VersionNotSupportedError(A2AError):
92+
"""Exception raised when the requested version is not supported."""
93+
94+
message = 'Version not supported'
95+
96+
8597
# For backward compatibility if needed, or just aliases for clean refactor
8698
# We remove the Pydantic models here.
8799

88100
__all__ = [
101+
'A2A_ERROR_REASONS',
102+
'A2A_REASON_TO_ERROR',
89103
'JSON_RPC_ERROR_CODE_MAP',
104+
'ExtensionSupportRequiredError',
90105
'InternalError',
91106
'InvalidAgentResponseError',
92107
'InvalidParamsError',
@@ -96,6 +111,7 @@ class MethodNotFoundError(A2AError):
96111
'TaskNotCancelableError',
97112
'TaskNotFoundError',
98113
'UnsupportedOperationError',
114+
'VersionNotSupportedError',
99115
]
100116

101117

@@ -112,3 +128,18 @@ class MethodNotFoundError(A2AError):
112128
MethodNotFoundError: -32601,
113129
InternalError: -32603,
114130
}
131+
132+
133+
A2A_ERROR_REASONS = {
134+
TaskNotFoundError: 'TASK_NOT_FOUND',
135+
TaskNotCancelableError: 'TASK_NOT_CANCELABLE',
136+
PushNotificationNotSupportedError: 'PUSH_NOTIFICATION_NOT_SUPPORTED',
137+
UnsupportedOperationError: 'UNSUPPORTED_OPERATION',
138+
ContentTypeNotSupportedError: 'CONTENT_TYPE_NOT_SUPPORTED',
139+
InvalidAgentResponseError: 'INVALID_AGENT_RESPONSE',
140+
AuthenticatedExtendedCardNotConfiguredError: 'EXTENDED_AGENT_CARD_NOT_CONFIGURED',
141+
ExtensionSupportRequiredError: 'EXTENSION_SUPPORT_REQUIRED',
142+
VersionNotSupportedError: 'VERSION_NOT_SUPPORTED',
143+
}
144+
145+
A2A_REASON_TO_ERROR = {reason: cls for cls, reason in A2A_ERROR_REASONS.items()}

tests/client/transports/test_grpc_client.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
import grpc
44
import pytest
55

6+
from google.protobuf import any_pb2
7+
from google.rpc import error_details_pb2, status_pb2
8+
69
from a2a.client.middleware import ClientCallContext
710
from a2a.client.transports.grpc import GrpcTransport
811
from a2a.extensions.common import HTTP_EXTENSION_HEADER
912
from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT
13+
from a2a.utils.errors import A2A_ERROR_REASONS
1014
from a2a.types import a2a_pb2
1115
from a2a.types.a2a_pb2 import (
1216
AgentCapabilities,
@@ -32,7 +36,6 @@
3236
TaskStatusUpdateEvent,
3337
)
3438
from a2a.utils import get_text_parts
35-
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
3639

3740

3841
@pytest.fixture
@@ -245,28 +248,45 @@ async def test_send_message_with_timeout_context(
245248
assert kwargs['timeout'] == 12.5
246249

247250

248-
@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
251+
@pytest.mark.parametrize('error_cls', list(A2A_ERROR_REASONS.keys()))
249252
@pytest.mark.asyncio
250-
async def test_grpc_mapped_errors(
253+
async def test_grpc_mapped_errors_rich(
251254
grpc_transport: GrpcTransport,
252255
mock_grpc_stub: AsyncMock,
253256
sample_message_send_params: SendMessageRequest,
254257
error_cls,
255258
) -> None:
256-
"""Test handling of mapped gRPC error responses."""
259+
"""Test handling of rich gRPC error responses with Status metadata."""
260+
261+
reason = A2A_ERROR_REASONS.get(error_cls, 'UNKNOWN_ERROR')
262+
263+
error_info = error_details_pb2.ErrorInfo(
264+
reason=reason,
265+
domain='a2a-protocol.org',
266+
)
267+
257268
error_details = f'{error_cls.__name__}: Mapped Error'
269+
status = status_pb2.Status(
270+
code=grpc.StatusCode.INTERNAL.value[0], message=error_details
271+
)
272+
detail = any_pb2.Any()
273+
detail.Pack(error_info)
274+
status.details.append(detail)
258275

259-
# We must trigger it from a standard transport method call, for example `send_message`.
260276
mock_grpc_stub.SendMessage.side_effect = grpc.aio.AioRpcError(
261277
code=grpc.StatusCode.INTERNAL,
262278
initial_metadata=grpc.aio.Metadata(),
263-
trailing_metadata=grpc.aio.Metadata(),
279+
trailing_metadata=grpc.aio.Metadata(
280+
('grpc-status-details-bin', status.SerializeToString()),
281+
),
264282
details=error_details,
265283
)
266284

267-
with pytest.raises(error_cls):
285+
with pytest.raises(error_cls) as excinfo:
268286
await grpc_transport.send_message(sample_message_send_params)
269287

288+
assert str(excinfo.value) == error_details
289+
270290

271291
@pytest.mark.asyncio
272292
async def test_send_message_message_response(

0 commit comments

Comments
 (0)