Skip to content

Commit 5ad30e8

Browse files
committed
fix: handle REST query params as per 1.0 spec
1 parent 9856054 commit 5ad30e8

6 files changed

Lines changed: 322 additions & 603 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ async def get_task(
109109
params = MessageToDict(request)
110110
if 'id' in params:
111111
del params['id'] # id is part of the URL path
112+
if 'tenant' in params:
113+
del params['tenant']
112114

113115
response_data = await self._execute_request(
114116
'GET',
@@ -127,12 +129,16 @@ async def list_tasks(
127129
context: ClientCallContext | None = None,
128130
) -> ListTasksResponse:
129131
"""Retrieves tasks for an agent."""
132+
params = MessageToDict(request)
133+
if 'tenant' in params:
134+
del params['tenant']
135+
130136
response_data = await self._execute_request(
131137
'GET',
132138
'/tasks',
133139
request.tenant,
134140
context=context,
135-
params=MessageToDict(request),
141+
params=params,
136142
)
137143
response: ListTasksResponse = ParseDict(
138144
response_data, ListTasksResponse()
@@ -185,8 +191,10 @@ async def get_task_push_notification_config(
185191
params = MessageToDict(request)
186192
if 'id' in params:
187193
del params['id']
188-
if 'task_id' in params:
189-
del params['task_id']
194+
if 'taskId' in params:
195+
del params['taskId']
196+
if 'tenant' in params:
197+
del params['tenant']
190198

191199
response_data = await self._execute_request(
192200
'GET',
@@ -208,8 +216,10 @@ async def list_task_push_notification_configs(
208216
) -> ListTaskPushNotificationConfigsResponse:
209217
"""Lists push notification configurations for a specific task."""
210218
params = MessageToDict(request)
211-
if 'task_id' in params:
212-
del params['task_id']
219+
if 'taskId' in params:
220+
del params['taskId']
221+
if 'tenant' in params:
222+
del params['tenant']
213223

214224
response_data = await self._execute_request(
215225
'GET',
@@ -233,8 +243,10 @@ async def delete_task_push_notification_config(
233243
params = MessageToDict(request)
234244
if 'id' in params:
235245
del params['id']
236-
if 'task_id' in params:
237-
del params['task_id']
246+
if 'taskId' in params:
247+
del params['taskId']
248+
if 'tenant' in params:
249+
del params['tenant']
238250

239251
await self._execute_request(
240252
'DELETE',

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
MessageToDict,
88
MessageToJson,
99
Parse,
10-
ParseDict,
1110
)
1211

1312

@@ -27,7 +26,6 @@
2726
AgentCard,
2827
CancelTaskRequest,
2928
GetTaskPushNotificationConfigRequest,
30-
GetTaskRequest,
3129
SubscribeToTaskRequest,
3230
)
3331
from a2a.utils import proto_utils
@@ -39,7 +37,7 @@
3937
logger = logging.getLogger(__name__)
4038

4139

42-
@trace_class(kind=SpanKind.SERVER)
40+
@trace_class(kind=SpanKind.SERVER, exclude_list=['_parse_params'])
4341
class RESTHandler:
4442
"""Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses.
4543
@@ -248,9 +246,9 @@ async def on_get_task(
248246
A `Task` object containing the Task.
249247
"""
250248
task_id = request.path_params['id']
251-
history_length_str = request.query_params.get('historyLength')
252-
history_length = int(history_length_str) if history_length_str else None
253-
params = GetTaskRequest(id=task_id, history_length=history_length)
249+
params = a2a_pb2.GetTaskRequest()
250+
proto_utils.parse_params(request.query_params, params)
251+
params.id = task_id
254252
task = await self.request_handler.on_get_task(params, context)
255253
if task:
256254
return MessageToDict(task)
@@ -295,12 +293,11 @@ async def list_tasks(
295293
A list of `dict` representing the `Task` objects.
296294
"""
297295
params = a2a_pb2.ListTasksRequest()
298-
# Parse query params, keeping arrays/repeated fields in mind if there are any
299-
# Using a simple ParseDict for now, might need more robust query param parsing
300-
# if the request structure contains nested or repeated elements
301-
ParseDict(
302-
dict(request.query_params), params, ignore_unknown_fields=True
303-
)
296+
proto_utils.parse_params(request.query_params, params)
297+
# Ensure tenant is set if provided in context
298+
if context.tenant and not params.tenant:
299+
params.tenant = context.tenant
300+
304301
result = await self.request_handler.on_list_tasks(params, context)
305302
return MessageToDict(result)
306303

@@ -319,12 +316,12 @@ async def list_push_notifications(
319316
A list of `dict` representing the `TaskPushNotificationConfig` objects.
320317
"""
321318
task_id = request.path_params['id']
322-
params = a2a_pb2.ListTaskPushNotificationConfigsRequest(task_id=task_id)
323-
324-
# Parse query params, keeping arrays/repeated fields in mind if there are any
325-
ParseDict(
326-
dict(request.query_params), params, ignore_unknown_fields=True
327-
)
319+
params = a2a_pb2.ListTaskPushNotificationConfigsRequest()
320+
proto_utils.parse_params(request.query_params, params)
321+
params.task_id = task_id
322+
# Ensure tenant is set if provided in context
323+
if context.tenant and not params.tenant:
324+
params.tenant = context.tenant
328325

329326
result = (
330327
await self.request_handler.on_list_task_push_notification_configs(

src/a2a/utils/proto_utils.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,19 @@
1717
This module provides helper functions for common proto type operations.
1818
"""
1919

20-
from typing import Any
20+
from typing import TYPE_CHECKING, Any
21+
22+
from google.protobuf.json_format import ParseDict
23+
from google.protobuf.message import Message as ProtobufMessage
24+
25+
26+
if TYPE_CHECKING:
27+
from starlette.datastructures import QueryParams
28+
else:
29+
try:
30+
from starlette.datastructures import QueryParams
31+
except ImportError:
32+
QueryParams = Any
2133

2234
from a2a.types.a2a_pb2 import (
2335
Message,
@@ -131,3 +143,49 @@ def parse_string_integers_in_dict(value: Any, max_safe_digits: int = 15) -> Any:
131143
if stripped_value.isdigit() and len(stripped_value) > max_safe_digits:
132144
return int(value)
133145
return value
146+
147+
148+
def parse_params(params: QueryParams, message: ProtobufMessage) -> None:
149+
"""Converts REST query parameters back into a Protobuf message.
150+
151+
Handles A2A-specific pre-processing before calling ParseDict:
152+
- Booleans: 'true'/'false' -> True/False
153+
- Repeated: Supports BOTH repeated keys and comma-separated values.
154+
- Others: Handles string->enum/timestamp/number conversion via ParseDict.
155+
156+
See Also:
157+
https://a2a-protocol.org/latest/specification/#115-query-parameter-naming-for-request-parameters
158+
"""
159+
descriptor = message.DESCRIPTOR
160+
fields = {f.camelcase_name: f for f in descriptor.fields}
161+
processed: dict[str, Any] = {}
162+
163+
keys = params.keys()
164+
165+
for k in keys:
166+
if k not in fields:
167+
continue
168+
169+
field = fields[k]
170+
v_list = params.getlist(k)
171+
172+
if field.label == field.LABEL_REPEATED:
173+
accumulated: list[Any] = []
174+
for v in v_list:
175+
if not v:
176+
continue
177+
if isinstance(v, str):
178+
accumulated.extend([x for x in v.split(',') if x])
179+
else:
180+
accumulated.append(v)
181+
processed[k] = accumulated
182+
else:
183+
# For non-repeated fields, the last one wins.
184+
raw_val = v_list[-1]
185+
if raw_val is not None:
186+
parsed_val: Any = raw_val
187+
if field.type == field.TYPE_BOOL and isinstance(raw_val, str):
188+
parsed_val = raw_val.lower() == 'true'
189+
processed[k] = parsed_val
190+
191+
ParseDict(processed, message, ignore_unknown_fields=True)

tests/client/transports/test_rest_client.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from google.protobuf import json_format
8+
from google.protobuf.timestamp_pb2 import Timestamp
89
from httpx_sse import EventSource, ServerSentEvent
910

1011
from a2a.client import create_text_message_object
@@ -16,16 +17,16 @@
1617
AgentCard,
1718
AgentInterface,
1819
CancelTaskRequest,
19-
TaskPushNotificationConfig,
2020
DeleteTaskPushNotificationConfigRequest,
2121
GetExtendedAgentCardRequest,
2222
GetTaskPushNotificationConfigRequest,
2323
GetTaskRequest,
2424
ListTaskPushNotificationConfigsRequest,
2525
ListTasksRequest,
26-
Message,
2726
SendMessageRequest,
2827
SubscribeToTaskRequest,
28+
TaskPushNotificationConfig,
29+
TaskState,
2930
)
3031
from a2a.utils.constants import TransportProtocol
3132
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
@@ -175,6 +176,47 @@ async def test_send_message_with_timeout_context(
175176
assert 'timeout' in kwargs
176177
assert kwargs['timeout'] == httpx.Timeout(10.0)
177178

179+
@pytest.mark.asyncio
180+
async def test_url_serialization(
181+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
182+
):
183+
"""Test that query parameters are correctly serialized to the URL."""
184+
client = RestTransport(
185+
httpx_client=mock_httpx_client,
186+
agent_card=mock_agent_card,
187+
url='http://agent.example.com/api',
188+
)
189+
190+
timestamp = Timestamp()
191+
timestamp.FromJsonString('2024-03-09T16:00:00Z')
192+
193+
request = ListTasksRequest(
194+
tenant='my-tenant',
195+
status=TaskState.TASK_STATE_WORKING,
196+
include_artifacts=True,
197+
status_timestamp_after=timestamp,
198+
)
199+
200+
# Use real build_request to get actual URL serialization
201+
mock_httpx_client.build_request.side_effect = (
202+
httpx.AsyncClient().build_request
203+
)
204+
mock_httpx_client.send.return_value = AsyncMock(
205+
spec=httpx.Response, status_code=200, json=lambda: {'tasks': []}
206+
)
207+
208+
await client.list_tasks(request=request)
209+
210+
mock_httpx_client.send.assert_called_once()
211+
sent_request = mock_httpx_client.send.call_args[0][0]
212+
213+
# Check decoded query parameters for spec compliance
214+
params = sent_request.url.params
215+
assert params['status'] == 'TASK_STATE_WORKING'
216+
assert params['includeArtifacts'] == 'true'
217+
assert params['statusTimestampAfter'] == '2024-03-09T16:00:00Z'
218+
assert 'tenant' not in params
219+
178220

179221
class TestRestTransportExtensions:
180222
@pytest.mark.asyncio
@@ -616,7 +658,7 @@ async def test_rest_get_task_prepend_empty_tenant(
616658

617659
# 3. Verify the URL
618660
args, _ = mock_httpx_client.build_request.call_args
619-
assert args[1] == f'http://agent.example.com/api/tasks/task-123'
661+
assert args[1] == 'http://agent.example.com/api/tasks/task-123'
620662

621663
@pytest.mark.parametrize(
622664
'method_name, request_obj, expected_path',

0 commit comments

Comments
 (0)