|
41 | 41 | TaskNotFoundError, |
42 | 42 | ) |
43 | 43 | from a2a.utils.helpers import maybe_await, validate |
| 44 | +from a2a.utils.proto_utils import validation_errors_to_bad_request |
44 | 45 |
|
45 | 46 |
|
46 | 47 | logger = logging.getLogger(__name__) |
@@ -403,26 +404,23 @@ async def abort_context( |
403 | 404 | error.message if hasattr(error, 'message') else str(error) |
404 | 405 | ) |
405 | 406 |
|
406 | | - # Create standard Status |
| 407 | + # Create standard Status with ErrorInfo for all A2A errors |
407 | 408 | status = status_pb2.Status(code=status_code, message=error_msg) |
| 409 | + error_info_detail = any_pb2.Any() |
| 410 | + error_info_detail.Pack(error_info) |
| 411 | + status.details.append(error_info_detail) |
408 | 412 |
|
| 413 | + # Append structured field violations for validation errors |
409 | 414 | if ( |
410 | 415 | isinstance(error, types.InvalidParamsError) |
411 | 416 | and error.data |
412 | 417 | and error.data.get('errors') |
413 | 418 | ): |
414 | | - bad_request = error_details_pb2.BadRequest() |
415 | | - for err_dict in error.data['errors']: |
416 | | - violation = bad_request.field_violations.add() |
417 | | - violation.field = err_dict.get('field', '') |
418 | | - violation.description = err_dict.get('message', '') |
419 | | - any_bad_request = any_pb2.Any() |
420 | | - any_bad_request.Pack(bad_request) |
421 | | - status.details.append(any_bad_request) |
422 | | - else: |
423 | | - detail = any_pb2.Any() |
424 | | - detail.Pack(error_info) |
425 | | - status.details.append(detail) |
| 419 | + bad_request_detail = any_pb2.Any() |
| 420 | + bad_request_detail.Pack( |
| 421 | + validation_errors_to_bad_request(error.data['errors']) |
| 422 | + ) |
| 423 | + status.details.append(bad_request_detail) |
426 | 424 |
|
427 | 425 | # Use grpc_status to safely generate standard trailing metadata |
428 | 426 | rich_status = rpc_status.to_status(status) |
|
0 commit comments