Skip to content

Commit 0623015

Browse files
authored
fix: fix bad "list tasks" merge for JSON-RPC (#698)
Fixes #697 (comment), bad merge in #696. Cover "list tasks" in client-server integration tests which would prevent it. Re #559.
1 parent b5cfb1e commit 0623015

3 files changed

Lines changed: 94 additions & 12 deletions

File tree

src/a2a/server/request_handlers/jsonrpc_handler.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
GetTaskRequest,
2727
ListTaskPushNotificationConfigRequest,
2828
ListTasksRequest,
29-
ListTasksResponse,
3029
Message,
3130
SendMessageRequest,
3231
SendMessageResponse,
@@ -388,25 +387,27 @@ async def list_tasks(
388387
self,
389388
request: ListTasksRequest,
390389
context: ServerCallContext | None = None,
391-
) -> ListTasksResponse:
390+
) -> dict[str, Any]:
392391
"""Handles the 'tasks/list' JSON-RPC method.
393392
394393
Args:
395394
request: The incoming `ListTasksRequest` object.
396395
context: Context provided by the server.
397396
398397
Returns:
399-
A `ListTasksResponse` object containing the Task or a JSON-RPC error.
398+
A dict representing the JSON-RPC response.
400399
"""
400+
request_id = self._get_request_id(context)
401401
try:
402-
result = await self.request_handler.on_list_tasks(request, context)
403-
except ServerError:
404-
return ListTasksResponse(
405-
# This needs to be appropriately handled since error fields on proto messages
406-
# might be different from the old pydantic models
407-
# Ignoring proto error handling for now as it diverges from the current pattern
402+
response = await self.request_handler.on_list_tasks(
403+
request, context
404+
)
405+
result = MessageToDict(response, preserving_proto_field_name=False)
406+
return _build_success_response(request_id, result)
407+
except ServerError as e:
408+
return _build_error_response(
409+
request_id, e.error if e.error else InternalError()
408410
)
409-
return result
410411

411412
async def list_push_notification_config(
412413
self,

tests/integration/test_client_server_integration.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
TaskState,
4848
TaskStatus,
4949
TaskStatusUpdateEvent,
50+
ListTasksRequest,
51+
ListTasksResponse,
5052
)
5153
from cryptography.hazmat.primitives import asymmetric
5254
from cryptography.hazmat.primitives.asymmetric import ec
@@ -91,6 +93,11 @@
9193
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
9294
)
9395

96+
LIST_TASKS_RESPONSE = ListTasksResponse(
97+
tasks=[TASK_FROM_BLOCKING, GET_TASK_RESPONSE],
98+
next_page_token='page-2',
99+
)
100+
94101

95102
def create_key_provider(verification_key: Any):
96103
"""Creates a key provider function for testing."""
@@ -121,6 +128,7 @@ async def stream_side_effect(*args, **kwargs):
121128
# Configure other methods
122129
handler.on_get_task.return_value = GET_TASK_RESPONSE
123130
handler.on_cancel_task.return_value = CANCEL_TASK_RESPONSE
131+
handler.on_list_tasks.return_value = LIST_TASKS_RESPONSE
124132
handler.on_create_task_push_notification_config.return_value = (
125133
CALLBACK_CONFIG
126134
)
@@ -450,6 +458,57 @@ def channel_factory(address: str) -> Channel:
450458
await transport.close()
451459

452460

461+
@pytest.mark.asyncio
462+
@pytest.mark.parametrize(
463+
'transport_setup_fixture',
464+
[
465+
pytest.param('jsonrpc_setup', id='JSON-RPC'),
466+
pytest.param('rest_setup', id='REST'),
467+
],
468+
)
469+
async def test_http_transport_list_tasks(
470+
transport_setup_fixture: str, request
471+
) -> None:
472+
transport_setup: TransportSetup = request.getfixturevalue(
473+
transport_setup_fixture
474+
)
475+
transport = transport_setup.transport
476+
handler = transport_setup.handler
477+
478+
params = ListTasksRequest(page_size=10, page_token='page-1')
479+
result = await transport.list_tasks(request=params)
480+
481+
assert len(result.tasks) == 2
482+
assert result.next_page_token == 'page-2'
483+
handler.on_list_tasks.assert_awaited_once()
484+
485+
if hasattr(transport, 'close'):
486+
await transport.close()
487+
488+
489+
@pytest.mark.asyncio
490+
async def test_grpc_transport_list_tasks(
491+
grpc_server_and_handler: tuple[str, AsyncMock],
492+
agent_card: AgentCard,
493+
) -> None:
494+
server_address, handler = grpc_server_and_handler
495+
496+
def channel_factory(address: str) -> Channel:
497+
return grpc.aio.insecure_channel(address)
498+
499+
channel = channel_factory(server_address)
500+
transport = GrpcTransport(channel=channel, agent_card=agent_card)
501+
502+
params = ListTasksRequest(page_size=10, page_token='page-1')
503+
result = await transport.list_tasks(request=params)
504+
505+
assert len(result.tasks) == 2
506+
assert result.next_page_token == 'page-2'
507+
handler.on_list_tasks.assert_awaited_once()
508+
509+
await transport.close()
510+
511+
453512
@pytest.mark.asyncio
454513
@pytest.mark.parametrize(
455514
'transport_setup_fixture',

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,30 @@ async def test_on_list_tasks_success(self) -> None:
190190
response = await handler.list_tasks(request, call_context)
191191

192192
request_handler.on_list_tasks.assert_awaited_once()
193-
self.assertIsInstance(response, ListTasksResponse)
194-
self.assertEqual(response, mock_result)
193+
self.assertIsInstance(response, dict)
194+
self.assertTrue(is_success_response(response))
195+
self.assertIn('tasks', response['result'])
196+
self.assertEqual(len(response['result']['tasks']), 2)
197+
self.assertEqual(response['result']['nextPageToken'], '123')
198+
199+
async def test_on_list_tasks_error(self) -> None:
200+
request_handler = AsyncMock(spec=DefaultRequestHandler)
201+
handler = JSONRPCHandler(self.mock_agent_card, request_handler)
202+
203+
request_handler.on_list_tasks.side_effect = ServerError(
204+
InternalError(message='DB down')
205+
)
206+
from a2a.types.a2a_pb2 import ListTasksRequest
207+
208+
request = ListTasksRequest(page_size=10)
209+
call_context = ServerCallContext(state={'request_id': '2'})
210+
211+
response = await handler.list_tasks(request, call_context)
212+
213+
request_handler.on_list_tasks.assert_awaited_once()
214+
self.assertIsInstance(response, dict)
215+
self.assertTrue(is_error_response(response))
216+
self.assertEqual(response['error']['message'], 'DB down')
195217

196218
async def test_on_cancel_task_success(self) -> None:
197219
mock_agent_executor = AsyncMock(spec=AgentExecutor)

0 commit comments

Comments
 (0)