3232from a2a .utils import TransportProtocol
3333
3434
35+ def assert_message_matches (message , expected_role , expected_text ):
36+ assert message .role == expected_role
37+ assert message .parts [0 ].text == expected_text
38+
39+
40+ def assert_history_matches (history , expected_history ):
41+ assert len (history ) == len (expected_history )
42+ for msg , (expected_role , expected_text ) in zip (
43+ history , expected_history , strict = True
44+ ):
45+ assert_message_matches (msg , expected_role , expected_text )
46+
47+
48+ def assert_artifacts_match (artifacts , expected_artifacts ):
49+ assert len (artifacts ) == len (expected_artifacts )
50+ for artifact , (expected_name , expected_text ) in zip (
51+ artifacts , expected_artifacts , strict = True
52+ ):
53+ assert artifact .name == expected_name
54+ assert artifact .parts [0 ].text == expected_text
55+
56+
57+ def assert_events_match (events , expected_events ):
58+ assert len (events ) == len (expected_events )
59+ for (event , _ ), (expected_type , expected_val ) in zip (
60+ events , expected_events , strict = True
61+ ):
62+ assert event .HasField (expected_type )
63+ if expected_type == 'status_update' :
64+ assert event .status_update .status .state == expected_val
65+ elif expected_type == 'artifact_update' :
66+ if expected_val is not None :
67+ assert_artifacts_match (
68+ [event .artifact_update .artifact ],
69+ expected_val ,
70+ )
71+ else :
72+ raise ValueError (f'Unexpected event type: { expected_type } ' )
73+
74+
3575class MockAgentExecutor (AgentExecutor ):
3676 async def execute (self , context : RequestContext , event_queue : EventQueue ):
3777 task_updater = TaskUpdater (
3878 event_queue ,
3979 context .task_id ,
4080 context .context_id ,
4181 )
42- await task_updater .update_status (TaskState .TASK_STATE_SUBMITTED )
43- await task_updater .update_status (TaskState .TASK_STATE_WORKING )
44- await task_updater .add_artifact (
45- parts = [Part (text = 'artifact content' )], name = 'test-artifact'
82+ user_input = context .get_user_input ()
83+
84+ is_input_required_resumption = (
85+ context .current_task is not None
86+ and context .current_task .status .state
87+ == TaskState .TASK_STATE_INPUT_REQUIRED
4688 )
89+
90+ if not is_input_required_resumption :
91+ await task_updater .update_status (
92+ TaskState .TASK_STATE_SUBMITTED ,
93+ message = task_updater .new_agent_message (
94+ [Part (text = 'task submitted' )]
95+ ),
96+ )
97+
4798 await task_updater .update_status (
48- TaskState .TASK_STATE_COMPLETED ,
49- message = task_updater .new_agent_message ([Part (text = 'done ' )]),
99+ TaskState .TASK_STATE_WORKING ,
100+ message = task_updater .new_agent_message ([Part (text = 'task working ' )]),
50101 )
51102
103+ if user_input == 'Need input' :
104+ await task_updater .update_status (
105+ TaskState .TASK_STATE_INPUT_REQUIRED ,
106+ message = task_updater .new_agent_message (
107+ [Part (text = 'Please provide input' )]
108+ ),
109+ )
110+ else :
111+ await task_updater .add_artifact (
112+ parts = [Part (text = 'artifact content' )], name = 'test-artifact'
113+ )
114+ await task_updater .update_status (
115+ TaskState .TASK_STATE_COMPLETED ,
116+ message = task_updater .new_agent_message ([Part (text = 'done' )]),
117+ )
118+
52119 async def cancel (self , context : RequestContext , event_queue : EventQueue ):
53120 raise NotImplementedError ('Cancellation is not supported' )
54121
@@ -218,9 +285,18 @@ async def test_end_to_end_send_message_blocking(transport_setups):
218285 response , _ = events [0 ]
219286 assert response .task .id
220287 assert response .task .status .state == TaskState .TASK_STATE_COMPLETED
221- assert len (response .task .artifacts ) == 1
222- assert response .task .artifacts [0 ].name == 'test-artifact'
223- assert response .task .artifacts [0 ].parts [0 ].text == 'artifact content'
288+ assert_artifacts_match (
289+ response .task .artifacts ,
290+ [('test-artifact' , 'artifact content' )],
291+ )
292+ assert_history_matches (
293+ response .task .history ,
294+ [
295+ (Role .ROLE_USER , 'Run dummy agent!' ),
296+ (Role .ROLE_AGENT , 'task submitted' ),
297+ (Role .ROLE_AGENT , 'task working' ),
298+ ],
299+ )
224300
225301
226302@pytest .mark .asyncio
@@ -245,6 +321,12 @@ async def test_end_to_end_send_message_non_blocking(transport_setups):
245321 response , _ = events [0 ]
246322 assert response .task .id
247323 assert response .task .status .state == TaskState .TASK_STATE_SUBMITTED
324+ assert_history_matches (
325+ response .task .history ,
326+ [
327+ (Role .ROLE_USER , 'Run dummy agent!' ),
328+ ],
329+ )
248330
249331
250332@pytest .mark .asyncio
@@ -258,29 +340,30 @@ async def test_end_to_end_send_message_streaming(transport_setups):
258340 )
259341
260342 events = [
261- event async for event , _ in client .send_message (request = message_to_send )
343+ event async for event in client .send_message (request = message_to_send )
262344 ]
263345
264- expected_events = [
265- ('status_update' , TaskState .TASK_STATE_SUBMITTED ),
266- ('status_update' , TaskState .TASK_STATE_WORKING ),
267- ('artifact_update' , None ),
268- ('status_update' , TaskState .TASK_STATE_COMPLETED ),
269- ]
346+ assert_events_match (
347+ events ,
348+ [
349+ ('status_update' , TaskState .TASK_STATE_SUBMITTED ),
350+ ('status_update' , TaskState .TASK_STATE_WORKING ),
351+ ('artifact_update' , [('test-artifact' , 'artifact content' )]),
352+ ('status_update' , TaskState .TASK_STATE_COMPLETED ),
353+ ],
354+ )
270355
271- assert len (events ) == len (expected_events )
272- for event , (expected_type , expected_state ) in zip (
273- events , expected_events , strict = True
274- ):
275- assert event .HasField (expected_type )
276- if expected_type == 'status_update' :
277- assert event .status_update .status .state == expected_state
278- elif expected_type == 'artifact_update' :
279- assert event .artifact_update .artifact .name == 'test-artifact'
280- assert (
281- event .artifact_update .artifact .parts [0 ].text
282- == 'artifact content'
283- )
356+ task = await client .get_task (request = GetTaskRequest (id = events [0 ][1 ].id ))
357+ assert_history_matches (
358+ task .history ,
359+ [
360+ (Role .ROLE_USER , 'Run dummy agent!' ),
361+ (Role .ROLE_AGENT , 'task submitted' ),
362+ (Role .ROLE_AGENT , 'task working' ),
363+ ],
364+ )
365+ assert task .status .state == TaskState .TASK_STATE_COMPLETED
366+ assert_message_matches (task .status .message , Role .ROLE_AGENT , 'done' )
284367
285368
286369@pytest .mark .asyncio
@@ -307,6 +390,14 @@ async def test_end_to_end_get_task(transport_setups):
307390 TaskState .TASK_STATE_WORKING ,
308391 TaskState .TASK_STATE_COMPLETED ,
309392 }
393+ assert_history_matches (
394+ retrieved_task .history ,
395+ [
396+ (Role .ROLE_USER , 'Test Get Task' ),
397+ (Role .ROLE_AGENT , 'task submitted' ),
398+ (Role .ROLE_AGENT , 'task working' ),
399+ ],
400+ )
310401
311402
312403@pytest .mark .asyncio
@@ -346,7 +437,93 @@ async def test_end_to_end_list_tasks(transport_setups):
346437
347438 actual_task_ids .extend ([task .id for task in list_response .tasks ])
348439
440+ for task in list_response .tasks :
441+ assert len (task .history ) >= 1
442+ assert task .history [0 ].role == Role .ROLE_USER
443+ assert task .history [0 ].parts [0 ].text .startswith ('Test List Tasks ' )
444+
349445 token = list_response .next_page_token
350446
351447 assert len (actual_task_ids ) == total_items
352448 assert sorted (actual_task_ids ) == sorted (expected_task_ids )
449+
450+
451+ @pytest .mark .asyncio
452+ async def test_end_to_end_input_required (transport_setups ):
453+ client = transport_setups .client
454+
455+ message_to_send = Message (
456+ role = Role .ROLE_USER ,
457+ message_id = 'msg-e2e-input-req-1' ,
458+ parts = [Part (text = 'Need input' )],
459+ )
460+
461+ events = [
462+ event async for event in client .send_message (request = message_to_send )
463+ ]
464+
465+ assert_events_match (
466+ events ,
467+ [
468+ ('status_update' , TaskState .TASK_STATE_SUBMITTED ),
469+ ('status_update' , TaskState .TASK_STATE_WORKING ),
470+ ('status_update' , TaskState .TASK_STATE_INPUT_REQUIRED ),
471+ ],
472+ )
473+
474+ task = await client .get_task (request = GetTaskRequest (id = events [0 ][1 ].id ))
475+
476+ assert task .status .state == TaskState .TASK_STATE_INPUT_REQUIRED
477+ assert_history_matches (
478+ task .history ,
479+ [
480+ (Role .ROLE_USER , 'Need input' ),
481+ (Role .ROLE_AGENT , 'task submitted' ),
482+ (Role .ROLE_AGENT , 'task working' ),
483+ ],
484+ )
485+ assert_message_matches (
486+ task .status .message , Role .ROLE_AGENT , 'Please provide input'
487+ )
488+
489+ # Follow-up message
490+ follow_up_message = Message (
491+ task_id = task .id ,
492+ role = Role .ROLE_USER ,
493+ message_id = 'msg-e2e-input-req-2' ,
494+ parts = [Part (text = 'Here is the input' )],
495+ )
496+
497+ follow_up_events = [
498+ event async for event in client .send_message (request = follow_up_message )
499+ ]
500+
501+ assert_events_match (
502+ follow_up_events ,
503+ [
504+ ('status_update' , TaskState .TASK_STATE_WORKING ),
505+ ('artifact_update' , [('test-artifact' , 'artifact content' )]),
506+ ('status_update' , TaskState .TASK_STATE_COMPLETED ),
507+ ],
508+ )
509+
510+ task = await client .get_task (request = GetTaskRequest (id = task .id ))
511+
512+ assert task .status .state == TaskState .TASK_STATE_COMPLETED
513+ assert_artifacts_match (
514+ task .artifacts ,
515+ [('test-artifact' , 'artifact content' )],
516+ )
517+
518+ assert_history_matches (
519+ task .history ,
520+ [
521+ (Role .ROLE_USER , 'Need input' ),
522+ (Role .ROLE_AGENT , 'task submitted' ),
523+ (Role .ROLE_AGENT , 'task working' ),
524+ (Role .ROLE_AGENT , 'Please provide input' ),
525+ (Role .ROLE_USER , 'Here is the input' ),
526+ (Role .ROLE_AGENT , 'task working' ),
527+ ],
528+ )
529+ assert_message_matches (task .status .message , Role .ROLE_AGENT , 'done' )
0 commit comments