Skip to content

Commit 1c648a2

Browse files
committed
WIP
1 parent 9562522 commit 1c648a2

6 files changed

Lines changed: 110 additions & 10 deletions

File tree

src/a2a/client/transports/grpc.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,29 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
6161

6262
# Use grpc_status to cleanly extract the rich Status from the call
6363
status = rpc_status.from_call(cast('grpc.Call', e))
64+
data = None
6465

6566
if status is not None:
67+
exception_cls = None
6668
for detail in status.details:
67-
if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
69+
if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR):
70+
bad_request = error_details_pb2.BadRequest()
71+
detail.Unpack(bad_request)
72+
errors = [
73+
{'field': v.field, 'message': v.description}
74+
for v in bad_request.field_violations
75+
]
76+
data = {'errors': errors}
77+
# Infer InvalidParamsError from BadRequest details
78+
exception_cls = A2A_REASON_TO_ERROR.get('INVALID_PARAMS')
79+
elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
6880
error_info = error_details_pb2.ErrorInfo()
6981
detail.Unpack(error_info)
70-
7182
if error_info.domain == 'a2a-protocol.org':
7283
exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason)
73-
if exception_cls:
74-
raise exception_cls(status.message) from e
84+
85+
if exception_cls:
86+
raise exception_cls(status.message, data=data) from e
7587

7688
raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e
7789

src/a2a/client/transports/jsonrpc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,10 @@ def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception:
318318
"""Creates the appropriate A2AError from a JSON-RPC error dictionary."""
319319
code = error_dict.get('code')
320320
message = error_dict.get('message', str(error_dict))
321+
data = error_dict.get('data')
321322

322323
if isinstance(code, int) and code in _JSON_RPC_ERROR_CODE_TO_A2A_ERROR:
323-
return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message)
324+
return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message, data=data)
324325

325326
# Fallback to general A2AClientError
326327
return A2AClientError(f'JSON-RPC Error {code}: {message}')

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,16 +438,29 @@ async def abort_context(
438438
error.message if hasattr(error, 'message') else str(error)
439439
)
440440

441-
# Create standard Status and pack the ErrorInfo
441+
# Create standard Status
442442
status = status_pb2.Status(code=status_code, message=error_msg)
443-
detail = any_pb2.Any()
444-
detail.Pack(error_info)
445-
status.details.append(detail)
443+
444+
# Exclusive details based on error type:
445+
if error.data and error.data.get('errors'):
446+
bad_request = error_details_pb2.BadRequest()
447+
for err_dict in error.data['errors']:
448+
violation = bad_request.field_violations.add()
449+
violation.field = err_dict.get('field', '')
450+
violation.description = err_dict.get('message', '')
451+
any_bad_request = any_pb2.Any()
452+
any_bad_request.Pack(bad_request)
453+
status.details.append(any_bad_request)
454+
else:
455+
detail = any_pb2.Any()
456+
detail.Pack(error_info)
457+
status.details.append(detail)
446458

447459
# Use grpc_status to safely generate standard trailing metadata
448460
rich_status = rpc_status.to_status(status)
449461

450462
new_metadata: list[tuple[str, str | bytes]] = []
463+
451464
trailing = context.trailing_metadata()
452465
if trailing:
453466
for k, v in trailing:

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def _build_error_response(
9292
jsonrpc_error = model_class(
9393
code=code,
9494
message=str(error),
95+
data=error.data,
9596
)
9697
else:
9798
jsonrpc_error = JSONRPCInternalError(message=str(error))

tests/integration/test_end_to_end.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import AsyncGenerator
2-
from typing import NamedTuple
2+
from typing import Any, NamedTuple
33

44
import grpc
55
import httpx
@@ -31,6 +31,7 @@
3131
a2a_pb2_grpc,
3232
)
3333
from a2a.utils import TransportProtocol
34+
from a2a.utils.errors import InvalidParamsError
3435

3536

3637
def assert_message_matches(message, expected_role, expected_text):
@@ -546,3 +547,34 @@ async def test_end_to_end_input_required(transport_setups):
546547
],
547548
)
548549
assert_message_matches(task.status.message, Role.ROLE_AGENT, 'done')
550+
551+
552+
@pytest.mark.asyncio
553+
@pytest.mark.parametrize(
554+
'empty_request, expected_fields',
555+
[
556+
(
557+
SendMessageRequest(),
558+
['message'],
559+
),
560+
(
561+
SendMessageRequest(message=Message()),
562+
['message.message_id', 'message.role', 'message.parts'],
563+
),
564+
],
565+
)
566+
async def test_end_to_end_validation_errors(
567+
transport_setups,
568+
empty_request: SendMessageRequest,
569+
expected_fields: list[str],
570+
) -> None:
571+
client = transport_setups.client
572+
573+
with pytest.raises(InvalidParamsError) as exc_info:
574+
async for _ in client.send_message(request=empty_request):
575+
pass
576+
577+
errors = exc_info.value.data.get('errors', [])
578+
assert {e['field'] for e in errors} == set(expected_fields)
579+
580+
await client.close()

tests/utils/test_proto_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,44 @@ def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams:
239239
return httpx.Request(
240240
'GET', 'http://api.example.com', params=rest_dict
241241
).url.params
242+
243+
244+
class TestValidateProtoRequiredFields:
245+
"""Tests for validate_proto_required_fields function."""
246+
247+
def test_valid_required_fields(self):
248+
"""Test with all required fields present."""
249+
msg = Message(
250+
message_id='msg-1',
251+
role=Role.ROLE_USER,
252+
parts=[Part(text='hello')],
253+
)
254+
proto_utils.validate_proto_required_fields(msg)
255+
256+
def test_missing_required_fields(self):
257+
"""Test with empty message raising InvalidParamsError containing all errors."""
258+
from a2a.utils.errors import InvalidParamsError
259+
260+
msg = Message()
261+
with pytest.raises(InvalidParamsError) as exc_info:
262+
proto_utils.validate_proto_required_fields(msg)
263+
264+
err = exc_info.value
265+
errors = err.data.get('errors', []) if err.data else []
266+
267+
assert {e['field'] for e in errors} == {'message_id', 'role', 'parts'}
268+
269+
def test_nested_required_fields(self):
270+
"""Test nested required fields inside TaskStatus."""
271+
from a2a.utils.errors import InvalidParamsError
272+
273+
# Task Status requires 'state'
274+
task = Task(id='task-1', status=TaskStatus())
275+
with pytest.raises(InvalidParamsError) as exc_info:
276+
proto_utils.validate_proto_required_fields(task)
277+
278+
err = exc_info.value
279+
errors = err.data.get('errors', []) if err.data else []
280+
281+
fields = [e['field'] for e in errors]
282+
assert 'status.state' in fields

0 commit comments

Comments
 (0)