11from collections .abc import AsyncGenerator
2- from typing import NamedTuple , cast
2+ from typing import NamedTuple
33
44import grpc
55import httpx
66import pytest
77import pytest_asyncio
88
99from a2a .client .base_client import BaseClient
10- from a2a .client .client import Client , ClientConfig
10+ from a2a .client .client import ClientConfig
1111from a2a .client .client_factory import ClientFactory
1212from a2a .server .agent_execution import AgentExecutor , RequestContext
1313from a2a .server .apps import A2AFastAPIApplication , A2ARESTFastAPIApplication
2626 Part ,
2727 Role ,
2828 SendMessageConfiguration ,
29- SendMessageRequest ,
3029 TaskState ,
3130 a2a_pb2_grpc ,
3231)
@@ -42,6 +41,9 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
4241 )
4342 await task_updater .update_status (TaskState .TASK_STATE_SUBMITTED )
4443 await task_updater .update_status (TaskState .TASK_STATE_WORKING )
44+ await task_updater .add_artifact (
45+ parts = [Part (text = 'artifact content' )], name = 'test-artifact'
46+ )
4547 await task_updater .update_status (
4648 TaskState .TASK_STATE_COMPLETED ,
4749 message = task_updater .new_agent_message ([Part (text = 'done' )]),
@@ -80,7 +82,7 @@ def agent_card() -> AgentCard:
8082 )
8183
8284
83- class TransportSetup (NamedTuple ):
85+ class ClientSetup (NamedTuple ):
8486 """Holds the client and task_store for a given test."""
8587
8688 client : BaseClient
@@ -99,7 +101,7 @@ def base_e2e_setup():
99101
100102
101103@pytest .fixture
102- def rest_setup (agent_card , base_e2e_setup ) -> TransportSetup :
104+ def rest_setup (agent_card , base_e2e_setup ) -> ClientSetup :
103105 task_store , handler = base_e2e_setup
104106 app_builder = A2ARESTFastAPIApplication (agent_card , handler )
105107 app = app_builder .build ()
@@ -112,15 +114,15 @@ def rest_setup(agent_card, base_e2e_setup) -> TransportSetup:
112114 supported_protocol_bindings = [TransportProtocol .HTTP_JSON ],
113115 )
114116 )
115- client = cast ( BaseClient , factory .create (agent_card ) )
116- return TransportSetup (
117+ client = factory .create (agent_card )
118+ return ClientSetup (
117119 client = client ,
118120 task_store = task_store ,
119121 )
120122
121123
122124@pytest .fixture
123- def jsonrpc_setup (agent_card , base_e2e_setup ) -> TransportSetup :
125+ def jsonrpc_setup (agent_card , base_e2e_setup ) -> ClientSetup :
124126 task_store , handler = base_e2e_setup
125127 app_builder = A2AFastAPIApplication (
126128 agent_card , handler , extended_agent_card = agent_card
@@ -135,8 +137,8 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
135137 supported_protocol_bindings = [TransportProtocol .JSONRPC ],
136138 )
137139 )
138- client = cast ( BaseClient , factory .create (agent_card ) )
139- return TransportSetup (
140+ client = factory .create (agent_card )
141+ return ClientSetup (
140142 client = client ,
141143 task_store = task_store ,
142144 )
@@ -145,7 +147,7 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
145147@pytest_asyncio .fixture
146148async def grpc_setup (
147149 agent_card : AgentCard , base_e2e_setup
148- ) -> AsyncGenerator [TransportSetup , None ]:
150+ ) -> AsyncGenerator [ClientSetup , None ]:
149151 task_store , handler = base_e2e_setup
150152 server = grpc .aio .server ()
151153 port = server .add_insecure_port ('[::]:0' )
@@ -168,12 +170,12 @@ async def grpc_setup(
168170
169171 factory = ClientFactory (
170172 config = ClientConfig (
171- grpc_channel_factory = lambda url : grpc .aio .insecure_channel ( url ) ,
173+ grpc_channel_factory = grpc .aio .insecure_channel ,
172174 supported_protocol_bindings = [TransportProtocol .GRPC ],
173175 )
174176 )
175- client = cast ( BaseClient , factory .create (grpc_agent_card ) )
176- yield TransportSetup (
177+ client = factory .create (grpc_agent_card )
178+ yield ClientSetup (
177179 client = client ,
178180 task_store = task_store ,
179181 )
@@ -189,14 +191,15 @@ async def grpc_setup(
189191 pytest .param ('grpc_setup' , id = 'gRPC' ),
190192 ]
191193)
192- def transport_setups (request ) -> TransportSetup :
194+ def transport_setups (request ) -> ClientSetup :
193195 """Parametrized fixture that runs tests against all supported transports."""
194196 return request .getfixturevalue (request .param )
195197
196198
197199@pytest .mark .asyncio
198200async def test_end_to_end_send_message_blocking (transport_setups ):
199201 client = transport_setups .client
202+ client ._config .streaming = False
200203
201204 message_to_send = Message (
202205 role = Role .ROLE_USER ,
@@ -211,16 +214,19 @@ async def test_end_to_end_send_message_blocking(transport_setups):
211214 request = message_to_send , configuration = configuration
212215 )
213216 ]
214- response , task = events [- 1 ]
215-
216- assert task
217- assert task .id
218- assert task .status .state == TaskState .TASK_STATE_COMPLETED
217+ assert len (events ) == 1
218+ response , _ = events [0 ]
219+ assert response .task .id
220+ assert response .task .status .state == TaskState .TASK_STATE_COMPLETED
221+ assert len (response .task .artifacts ) == 1
222+ assert response .task .artifacts [0 ].name == 'test-artifact'
223+ assert response .task .artifacts [0 ].parts [0 ].text == 'artifact content'
219224
220225
221226@pytest .mark .asyncio
222227async def test_end_to_end_send_message_non_blocking (transport_setups ):
223228 client = transport_setups .client
229+ client ._config .streaming = False
224230
225231 message_to_send = Message (
226232 role = Role .ROLE_USER ,
@@ -235,10 +241,10 @@ async def test_end_to_end_send_message_non_blocking(transport_setups):
235241 request = message_to_send , configuration = configuration
236242 )
237243 ]
238- response , task = events [ - 1 ]
239-
240- assert task
241- assert task .id
244+ assert len ( events ) == 1
245+ response , _ = events [ 0 ]
246+ assert response . task . id
247+ assert response . task .status . state == TaskState . TASK_STATE_SUBMITTED
242248
243249
244250@pytest .mark .asyncio
@@ -252,20 +258,29 @@ async def test_end_to_end_send_message_streaming(transport_setups):
252258 )
253259
254260 events = [
255- event async for event in client .send_message (request = message_to_send )
261+ event async for event , _ in client .send_message (request = message_to_send )
256262 ]
257263
258- assert len (events ) > 0
259- stream_response , task = events [- 1 ]
264+ expected_events = [
265+ ('status_update' , TaskState .TASK_STATE_SUBMITTED ),
266+ ('status_update' , TaskState .TASK_STATE_WORKING ),
267+ ('artifact_update' , None ),
268+ ('status_update' , TaskState .TASK_STATE_COMPLETED ),
269+ ]
260270
261- assert stream_response .HasField ('status_update' )
262- assert stream_response .status_update .task_id
263- assert (
264- stream_response .status_update .status .state
265- == TaskState .TASK_STATE_COMPLETED
266- )
267- assert task
268- assert task .status .state == TaskState .TASK_STATE_COMPLETED
271+ assert len (events ) == len (expected_events )
272+ for event , (expected_type , expected_state ) in zip (
273+ events , expected_events , strict = True
274+ ):
275+ assert event .HasField (expected_type )
276+ if expected_type == 'status_update' :
277+ assert event .status_update .status .state == expected_state
278+ elif expected_type == 'artifact_update' :
279+ assert event .artifact_update .artifact .name == 'test-artifact'
280+ assert (
281+ event .artifact_update .artifact .parts [0 ].text
282+ == 'artifact content'
283+ )
269284
270285
271286@pytest .mark .asyncio
@@ -301,21 +316,23 @@ async def test_end_to_end_list_tasks(transport_setups):
301316 total_items = 6
302317 page_size = 2
303318
319+ expected_task_ids = []
304320 for i in range (total_items ):
305- # We need to await the iterator to ensure request completes
306- async for _ in client .send_message (
307- request = Message (
308- role = Role .ROLE_USER ,
309- message_id = f'msg-e2e-list-{ i } ' ,
310- parts = [Part (text = f'Test List Tasks { i } ' )],
311- ),
312- configuration = SendMessageConfiguration (blocking = False ),
313- ):
314- pass
321+ # One event is enough to get the task ID
322+ _ , task = await anext (
323+ client .send_message (
324+ request = Message (
325+ role = Role .ROLE_USER ,
326+ message_id = f'msg-e2e-list-{ i } ' ,
327+ parts = [Part (text = f'Test List Tasks { i } ' )],
328+ )
329+ )
330+ )
331+ expected_task_ids .append (task .id )
315332
316333 list_request = ListTasksRequest (page_size = page_size )
317334
318- unique_task_ids = set ()
335+ actual_task_ids = []
319336 token = None
320337
321338 while token != '' :
@@ -327,9 +344,9 @@ async def test_end_to_end_list_tasks(transport_setups):
327344 assert list_response .total_size == total_items
328345 assert list_response .page_size == page_size
329346
330- for task in list_response .tasks :
331- unique_task_ids .add (task .id )
347+ actual_task_ids .extend ([task .id for task in list_response .tasks ])
332348
333349 token = list_response .next_page_token
334350
335- assert len (unique_task_ids ) == total_items
351+ assert len (actual_task_ids ) == total_items
352+ assert sorted (actual_task_ids ) == sorted (expected_task_ids )
0 commit comments