|
3 | 3 | import logging |
4 | 4 |
|
5 | 5 | from abc import ABC, abstractmethod |
6 | | -from collections.abc import AsyncIterable, Awaitable |
| 6 | +from collections.abc import AsyncIterable, Awaitable, Callable |
7 | 7 |
|
8 | 8 |
|
9 | 9 | try: |
10 | 10 | import grpc # type: ignore[reportMissingModuleSource] |
11 | 11 | import grpc.aio # type: ignore[reportMissingModuleSource] |
| 12 | + |
| 13 | + from grpc_status import rpc_status |
12 | 14 | except ImportError as e: |
13 | 15 | raise ImportError( |
14 | | - 'GrpcHandler requires grpcio and grpcio-tools to be installed. ' |
| 16 | + 'GrpcHandler requires grpcio, grpcio-tools, and grpcio-status to be installed. ' |
15 | 17 | 'Install with: ' |
16 | 18 | "'pip install a2a-sdk[grpc]'" |
17 | 19 | ) from e |
18 | 20 |
|
19 | | -from collections.abc import Callable |
20 | | - |
21 | | -from google.protobuf import empty_pb2, message |
| 21 | +from google.protobuf import any_pb2, empty_pb2, message |
| 22 | +from google.rpc import error_details_pb2, status_pb2 |
22 | 23 |
|
23 | 24 | import a2a.types.a2a_pb2_grpc as a2a_grpc |
24 | 25 |
|
|
33 | 34 | from a2a.types import a2a_pb2 |
34 | 35 | from a2a.types.a2a_pb2 import AgentCard |
35 | 36 | from a2a.utils import proto_utils |
36 | | -from a2a.utils.errors import A2AError, TaskNotFoundError |
| 37 | +from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError |
37 | 38 | from a2a.utils.helpers import maybe_await, validate, validate_async_generator |
38 | 39 |
|
39 | 40 |
|
@@ -419,11 +420,41 @@ async def abort_context( |
419 | 420 | ) -> None: |
420 | 421 | """Sets the grpc errors appropriately in the context.""" |
421 | 422 | code = _ERROR_CODE_MAP.get(type(error)) |
| 423 | + |
422 | 424 | if code: |
423 | | - await context.abort( |
424 | | - code, |
425 | | - f'{type(error).__name__}: {error.message}', |
| 425 | + reason = A2A_ERROR_REASONS.get(type(error), 'UNKNOWN_ERROR') |
| 426 | + error_info = error_details_pb2.ErrorInfo( |
| 427 | + reason=reason, |
| 428 | + domain='a2a-protocol.org', |
| 429 | + ) |
| 430 | + |
| 431 | + status_code = ( |
| 432 | + code.value[0] if code else grpc.StatusCode.UNKNOWN.value[0] |
426 | 433 | ) |
| 434 | + error_msg = ( |
| 435 | + error.message if hasattr(error, 'message') else str(error) |
| 436 | + ) |
| 437 | + |
| 438 | + # Create standard Status and pack the ErrorInfo |
| 439 | + status = status_pb2.Status(code=status_code, message=error_msg) |
| 440 | + detail = any_pb2.Any() |
| 441 | + detail.Pack(error_info) |
| 442 | + status.details.append(detail) |
| 443 | + |
| 444 | + # Use grpc_status to safely generate standard trailing metadata |
| 445 | + rich_status = rpc_status.to_status(status) |
| 446 | + |
| 447 | + new_metadata: list[tuple[str, str | bytes]] = [] |
| 448 | + trailing = context.trailing_metadata() |
| 449 | + if trailing: |
| 450 | + for k, v in trailing: |
| 451 | + new_metadata.append((str(k), v)) |
| 452 | + |
| 453 | + for k, v in rich_status.trailing_metadata: |
| 454 | + new_metadata.append((str(k), v)) |
| 455 | + |
| 456 | + context.set_trailing_metadata(tuple(new_metadata)) |
| 457 | + await context.abort(rich_status.code, rich_status.details) |
427 | 458 | else: |
428 | 459 | await context.abort( |
429 | 460 | grpc.StatusCode.UNKNOWN, |
|
0 commit comments