Skip to content

Commit c87dac2

Browse files
committed
Updates
1 parent e05ea15 commit c87dac2

6 files changed

Lines changed: 47 additions & 51 deletions

File tree

src/a2a/client/transports/grpc.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747
TaskPushNotificationConfig,
4848
)
4949
from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER
50-
from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError, InvalidParamsError
50+
from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError
51+
from a2a.utils.proto_utils import bad_request_to_validation_errors
5152
from a2a.utils.telemetry import SpanKind, trace_class
5253

5354

@@ -66,22 +67,15 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
6667
if status is not None:
6768
exception_cls: type[A2AError] | None = None
6869
for detail in status.details:
69-
if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR):
70-
bad_request = error_details_pb2.BadRequest()
71-
detail.Unpack(bad_request)
72-
errors = [
73-
{'field': v.field, 'message': v.description}
74-
for v in bad_request.field_violations
75-
]
76-
data = {'errors': errors}
77-
exception_cls = InvalidParamsError
78-
break
7970
if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
8071
error_info = error_details_pb2.ErrorInfo()
8172
detail.Unpack(error_info)
8273
if error_info.domain == 'a2a-protocol.org':
8374
exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason)
84-
break
75+
elif detail.Is(error_details_pb2.BadRequest.DESCRIPTOR):
76+
bad_request = error_details_pb2.BadRequest()
77+
detail.Unpack(bad_request)
78+
data = {'errors': bad_request_to_validation_errors(bad_request)}
8579

8680
if exception_cls:
8781
raise exception_cls(status.message, data=data) from e

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
TaskNotFoundError,
4242
)
4343
from a2a.utils.helpers import maybe_await, validate
44+
from a2a.utils.proto_utils import validation_errors_to_bad_request
4445

4546

4647
logger = logging.getLogger(__name__)
@@ -403,26 +404,23 @@ async def abort_context(
403404
error.message if hasattr(error, 'message') else str(error)
404405
)
405406

406-
# Create standard Status
407+
# Create standard Status with ErrorInfo for all A2A errors
407408
status = status_pb2.Status(code=status_code, message=error_msg)
409+
error_info_detail = any_pb2.Any()
410+
error_info_detail.Pack(error_info)
411+
status.details.append(error_info_detail)
408412

413+
# Append structured field violations for validation errors
409414
if (
410415
isinstance(error, types.InvalidParamsError)
411416
and error.data
412417
and error.data.get('errors')
413418
):
414-
bad_request = error_details_pb2.BadRequest()
415-
for err_dict in error.data['errors']:
416-
violation = bad_request.field_violations.add()
417-
violation.field = err_dict.get('field', '')
418-
violation.description = err_dict.get('message', '')
419-
any_bad_request = any_pb2.Any()
420-
any_bad_request.Pack(bad_request)
421-
status.details.append(any_bad_request)
422-
else:
423-
detail = any_pb2.Any()
424-
detail.Pack(error_info)
425-
status.details.append(detail)
419+
bad_request_detail = any_pb2.Any()
420+
bad_request_detail.Pack(
421+
validation_errors_to_bad_request(error.data['errors'])
422+
)
423+
status.details.append(bad_request_detail)
426424

427425
# Use grpc_status to safely generate standard trailing metadata
428426
rich_status = rpc_status.to_status(status)

src/a2a/server/request_handlers/request_handler.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -246,24 +246,8 @@ async def async_gen_wrapper(
246246

247247
return async_gen_wrapper
248248

249-
if inspect.iscoroutinefunction(method):
250-
251-
@functools.wraps(method)
252-
async def async_wrapper(
253-
self: RequestHandler,
254-
params: ProtoMessage,
255-
context: ServerCallContext,
256-
*args: Any,
257-
**kwargs: Any,
258-
) -> Any:
259-
if params is not None:
260-
validate_proto_required_fields(params)
261-
return await method(self, params, context, *args, **kwargs)
262-
263-
return async_wrapper
264-
265249
@functools.wraps(method)
266-
def sync_wrapper(
250+
async def async_wrapper(
267251
self: RequestHandler,
268252
params: ProtoMessage,
269253
context: ServerCallContext,
@@ -272,6 +256,6 @@ def sync_wrapper(
272256
) -> Any:
273257
if params is not None:
274258
validate_proto_required_fields(params)
275-
return method(self, params, context, *args, **kwargs)
259+
return await method(self, params, context, *args, **kwargs)
276260

277-
return sync_wrapper
261+
return async_wrapper

src/a2a/utils/proto_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from google.protobuf.descriptor import FieldDescriptor
2424
from google.protobuf.json_format import ParseDict
2525
from google.protobuf.message import Message as ProtobufMessage
26+
from google.rpc import error_details_pb2
2627

2728
from a2a.utils.errors import InvalidParamsError
2829

@@ -296,3 +297,25 @@ def validate_proto_required_fields(msg: ProtobufMessage) -> None:
296297
raise InvalidParamsError(
297298
message='Validation failed', data={'errors': errors}
298299
)
300+
301+
302+
def validation_errors_to_bad_request(
303+
errors: list[ValidationDetail],
304+
) -> error_details_pb2.BadRequest:
305+
"""Convert validation error details to a gRPC BadRequest proto."""
306+
bad_request = error_details_pb2.BadRequest()
307+
for err in errors:
308+
violation = bad_request.field_violations.add()
309+
violation.field = err['field']
310+
violation.description = err['message']
311+
return bad_request
312+
313+
314+
def bad_request_to_validation_errors(
315+
bad_request: error_details_pb2.BadRequest,
316+
) -> list[ValidationDetail]:
317+
"""Convert a gRPC BadRequest proto to validation error details."""
318+
return [
319+
ValidationDetail(field=v.field, message=v.description)
320+
for v in bad_request.field_violations
321+
]

tests/integration/test_end_to_end.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import AsyncGenerator
2-
from typing import Any, NamedTuple
2+
from typing import NamedTuple
33

44
import grpc
55
import httpx

tests/utils/test_proto_utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
import httpx
77
import pytest
8+
89
from google.protobuf.json_format import MessageToDict, Parse
910
from google.protobuf.message import Message as ProtobufMessage
1011
from google.protobuf.timestamp_pb2 import Timestamp
12+
from starlette.datastructures import QueryParams
1113

1214
from a2a.types.a2a_pb2 import (
13-
AgentCard,
1415
AgentSkill,
1516
ListTasksRequest,
1617
Message,
@@ -23,8 +24,8 @@
2324
TaskStatus,
2425
TaskStatusUpdateEvent,
2526
)
26-
from starlette.datastructures import QueryParams
2727
from a2a.utils import proto_utils
28+
from a2a.utils.errors import InvalidParamsError
2829

2930

3031
class TestToStreamResponse:
@@ -255,8 +256,6 @@ def test_valid_required_fields(self):
255256

256257
def test_missing_required_fields(self):
257258
"""Test with empty message raising InvalidParamsError containing all errors."""
258-
from a2a.utils.errors import InvalidParamsError
259-
260259
msg = Message()
261260
with pytest.raises(InvalidParamsError) as exc_info:
262261
proto_utils.validate_proto_required_fields(msg)
@@ -268,8 +267,6 @@ def test_missing_required_fields(self):
268267

269268
def test_nested_required_fields(self):
270269
"""Test nested required fields inside TaskStatus."""
271-
from a2a.utils.errors import InvalidParamsError
272-
273270
# Task Status requires 'state'
274271
task = Task(id='task-1', status=TaskStatus())
275272
with pytest.raises(InvalidParamsError) as exc_info:

0 commit comments

Comments
 (0)