Skip to content

Commit 7998a26

Browse files
authored
test: test history and TASK_STATE_INPUT_REQUIRED in test_end_to_end.py (#745)
Add `task.history` assertions and test `TASK_STATE_INPUT_REQUIRED`. **Note:** tests use `get_task` API call in non-blocking tests for assertions as `task` returned from `Client` and maintained by `ClientTaskManager` can't be trusted and handles history in a different way compared to the server (see #734).
1 parent dce3650 commit 7998a26

1 file changed

Lines changed: 206 additions & 29 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 206 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,90 @@
3232
from a2a.utils import TransportProtocol
3333

3434

35+
def assert_message_matches(message, expected_role, expected_text):
36+
assert message.role == expected_role
37+
assert message.parts[0].text == expected_text
38+
39+
40+
def assert_history_matches(history, expected_history):
41+
assert len(history) == len(expected_history)
42+
for msg, (expected_role, expected_text) in zip(
43+
history, expected_history, strict=True
44+
):
45+
assert_message_matches(msg, expected_role, expected_text)
46+
47+
48+
def assert_artifacts_match(artifacts, expected_artifacts):
49+
assert len(artifacts) == len(expected_artifacts)
50+
for artifact, (expected_name, expected_text) in zip(
51+
artifacts, expected_artifacts, strict=True
52+
):
53+
assert artifact.name == expected_name
54+
assert artifact.parts[0].text == expected_text
55+
56+
57+
def assert_events_match(events, expected_events):
58+
assert len(events) == len(expected_events)
59+
for (event, _), (expected_type, expected_val) in zip(
60+
events, expected_events, strict=True
61+
):
62+
assert event.HasField(expected_type)
63+
if expected_type == 'status_update':
64+
assert event.status_update.status.state == expected_val
65+
elif expected_type == 'artifact_update':
66+
if expected_val is not None:
67+
assert_artifacts_match(
68+
[event.artifact_update.artifact],
69+
expected_val,
70+
)
71+
else:
72+
raise ValueError(f'Unexpected event type: {expected_type}')
73+
74+
3575
class MockAgentExecutor(AgentExecutor):
3676
async def execute(self, context: RequestContext, event_queue: EventQueue):
3777
task_updater = TaskUpdater(
3878
event_queue,
3979
context.task_id,
4080
context.context_id,
4181
)
42-
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
43-
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
44-
await task_updater.add_artifact(
45-
parts=[Part(text='artifact content')], name='test-artifact'
82+
user_input = context.get_user_input()
83+
84+
is_input_required_resumption = (
85+
context.current_task is not None
86+
and context.current_task.status.state
87+
== TaskState.TASK_STATE_INPUT_REQUIRED
4688
)
89+
90+
if not is_input_required_resumption:
91+
await task_updater.update_status(
92+
TaskState.TASK_STATE_SUBMITTED,
93+
message=task_updater.new_agent_message(
94+
[Part(text='task submitted')]
95+
),
96+
)
97+
4798
await task_updater.update_status(
48-
TaskState.TASK_STATE_COMPLETED,
49-
message=task_updater.new_agent_message([Part(text='done')]),
99+
TaskState.TASK_STATE_WORKING,
100+
message=task_updater.new_agent_message([Part(text='task working')]),
50101
)
51102

103+
if user_input == 'Need input':
104+
await task_updater.update_status(
105+
TaskState.TASK_STATE_INPUT_REQUIRED,
106+
message=task_updater.new_agent_message(
107+
[Part(text='Please provide input')]
108+
),
109+
)
110+
else:
111+
await task_updater.add_artifact(
112+
parts=[Part(text='artifact content')], name='test-artifact'
113+
)
114+
await task_updater.update_status(
115+
TaskState.TASK_STATE_COMPLETED,
116+
message=task_updater.new_agent_message([Part(text='done')]),
117+
)
118+
52119
async def cancel(self, context: RequestContext, event_queue: EventQueue):
53120
raise NotImplementedError('Cancellation is not supported')
54121

@@ -218,9 +285,18 @@ async def test_end_to_end_send_message_blocking(transport_setups):
218285
response, _ = events[0]
219286
assert response.task.id
220287
assert response.task.status.state == TaskState.TASK_STATE_COMPLETED
221-
assert len(response.task.artifacts) == 1
222-
assert response.task.artifacts[0].name == 'test-artifact'
223-
assert response.task.artifacts[0].parts[0].text == 'artifact content'
288+
assert_artifacts_match(
289+
response.task.artifacts,
290+
[('test-artifact', 'artifact content')],
291+
)
292+
assert_history_matches(
293+
response.task.history,
294+
[
295+
(Role.ROLE_USER, 'Run dummy agent!'),
296+
(Role.ROLE_AGENT, 'task submitted'),
297+
(Role.ROLE_AGENT, 'task working'),
298+
],
299+
)
224300

225301

226302
@pytest.mark.asyncio
@@ -245,6 +321,12 @@ async def test_end_to_end_send_message_non_blocking(transport_setups):
245321
response, _ = events[0]
246322
assert response.task.id
247323
assert response.task.status.state == TaskState.TASK_STATE_SUBMITTED
324+
assert_history_matches(
325+
response.task.history,
326+
[
327+
(Role.ROLE_USER, 'Run dummy agent!'),
328+
],
329+
)
248330

249331

250332
@pytest.mark.asyncio
@@ -258,29 +340,30 @@ async def test_end_to_end_send_message_streaming(transport_setups):
258340
)
259341

260342
events = [
261-
event async for event, _ in client.send_message(request=message_to_send)
343+
event async for event in client.send_message(request=message_to_send)
262344
]
263345

264-
expected_events = [
265-
('status_update', TaskState.TASK_STATE_SUBMITTED),
266-
('status_update', TaskState.TASK_STATE_WORKING),
267-
('artifact_update', None),
268-
('status_update', TaskState.TASK_STATE_COMPLETED),
269-
]
346+
assert_events_match(
347+
events,
348+
[
349+
('status_update', TaskState.TASK_STATE_SUBMITTED),
350+
('status_update', TaskState.TASK_STATE_WORKING),
351+
('artifact_update', [('test-artifact', 'artifact content')]),
352+
('status_update', TaskState.TASK_STATE_COMPLETED),
353+
],
354+
)
270355

271-
assert len(events) == len(expected_events)
272-
for event, (expected_type, expected_state) in zip(
273-
events, expected_events, strict=True
274-
):
275-
assert event.HasField(expected_type)
276-
if expected_type == 'status_update':
277-
assert event.status_update.status.state == expected_state
278-
elif expected_type == 'artifact_update':
279-
assert event.artifact_update.artifact.name == 'test-artifact'
280-
assert (
281-
event.artifact_update.artifact.parts[0].text
282-
== 'artifact content'
283-
)
356+
task = await client.get_task(request=GetTaskRequest(id=events[0][1].id))
357+
assert_history_matches(
358+
task.history,
359+
[
360+
(Role.ROLE_USER, 'Run dummy agent!'),
361+
(Role.ROLE_AGENT, 'task submitted'),
362+
(Role.ROLE_AGENT, 'task working'),
363+
],
364+
)
365+
assert task.status.state == TaskState.TASK_STATE_COMPLETED
366+
assert_message_matches(task.status.message, Role.ROLE_AGENT, 'done')
284367

285368

286369
@pytest.mark.asyncio
@@ -307,6 +390,14 @@ async def test_end_to_end_get_task(transport_setups):
307390
TaskState.TASK_STATE_WORKING,
308391
TaskState.TASK_STATE_COMPLETED,
309392
}
393+
assert_history_matches(
394+
retrieved_task.history,
395+
[
396+
(Role.ROLE_USER, 'Test Get Task'),
397+
(Role.ROLE_AGENT, 'task submitted'),
398+
(Role.ROLE_AGENT, 'task working'),
399+
],
400+
)
310401

311402

312403
@pytest.mark.asyncio
@@ -346,7 +437,93 @@ async def test_end_to_end_list_tasks(transport_setups):
346437

347438
actual_task_ids.extend([task.id for task in list_response.tasks])
348439

440+
for task in list_response.tasks:
441+
assert len(task.history) >= 1
442+
assert task.history[0].role == Role.ROLE_USER
443+
assert task.history[0].parts[0].text.startswith('Test List Tasks ')
444+
349445
token = list_response.next_page_token
350446

351447
assert len(actual_task_ids) == total_items
352448
assert sorted(actual_task_ids) == sorted(expected_task_ids)
449+
450+
451+
@pytest.mark.asyncio
452+
async def test_end_to_end_input_required(transport_setups):
453+
client = transport_setups.client
454+
455+
message_to_send = Message(
456+
role=Role.ROLE_USER,
457+
message_id='msg-e2e-input-req-1',
458+
parts=[Part(text='Need input')],
459+
)
460+
461+
events = [
462+
event async for event in client.send_message(request=message_to_send)
463+
]
464+
465+
assert_events_match(
466+
events,
467+
[
468+
('status_update', TaskState.TASK_STATE_SUBMITTED),
469+
('status_update', TaskState.TASK_STATE_WORKING),
470+
('status_update', TaskState.TASK_STATE_INPUT_REQUIRED),
471+
],
472+
)
473+
474+
task = await client.get_task(request=GetTaskRequest(id=events[0][1].id))
475+
476+
assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
477+
assert_history_matches(
478+
task.history,
479+
[
480+
(Role.ROLE_USER, 'Need input'),
481+
(Role.ROLE_AGENT, 'task submitted'),
482+
(Role.ROLE_AGENT, 'task working'),
483+
],
484+
)
485+
assert_message_matches(
486+
task.status.message, Role.ROLE_AGENT, 'Please provide input'
487+
)
488+
489+
# Follow-up message
490+
follow_up_message = Message(
491+
task_id=task.id,
492+
role=Role.ROLE_USER,
493+
message_id='msg-e2e-input-req-2',
494+
parts=[Part(text='Here is the input')],
495+
)
496+
497+
follow_up_events = [
498+
event async for event in client.send_message(request=follow_up_message)
499+
]
500+
501+
assert_events_match(
502+
follow_up_events,
503+
[
504+
('status_update', TaskState.TASK_STATE_WORKING),
505+
('artifact_update', [('test-artifact', 'artifact content')]),
506+
('status_update', TaskState.TASK_STATE_COMPLETED),
507+
],
508+
)
509+
510+
task = await client.get_task(request=GetTaskRequest(id=task.id))
511+
512+
assert task.status.state == TaskState.TASK_STATE_COMPLETED
513+
assert_artifacts_match(
514+
task.artifacts,
515+
[('test-artifact', 'artifact content')],
516+
)
517+
518+
assert_history_matches(
519+
task.history,
520+
[
521+
(Role.ROLE_USER, 'Need input'),
522+
(Role.ROLE_AGENT, 'task submitted'),
523+
(Role.ROLE_AGENT, 'task working'),
524+
(Role.ROLE_AGENT, 'Please provide input'),
525+
(Role.ROLE_USER, 'Here is the input'),
526+
(Role.ROLE_AGENT, 'task working'),
527+
],
528+
)
529+
assert_message_matches(task.status.message, Role.ROLE_AGENT, 'done')

0 commit comments

Comments
 (0)