Skip to content

Commit dce3650

Browse files
authored
test: improve test_end_to_end.py (#738)
# Description - Test artifacts. - Add more assertions for streaming: validate all events. - Fix non-streaming tests which were actually streaming.
1 parent 59b8871 commit dce3650

1 file changed

Lines changed: 66 additions & 49 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 66 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from collections.abc import AsyncGenerator
2-
from typing import NamedTuple, cast
2+
from typing import NamedTuple
33

44
import grpc
55
import httpx
66
import pytest
77
import pytest_asyncio
88

99
from a2a.client.base_client import BaseClient
10-
from a2a.client.client import Client, ClientConfig
10+
from a2a.client.client import ClientConfig
1111
from a2a.client.client_factory import ClientFactory
1212
from a2a.server.agent_execution import AgentExecutor, RequestContext
1313
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
@@ -26,7 +26,6 @@
2626
Part,
2727
Role,
2828
SendMessageConfiguration,
29-
SendMessageRequest,
3029
TaskState,
3130
a2a_pb2_grpc,
3231
)
@@ -42,6 +41,9 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
4241
)
4342
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
4443
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
44+
await task_updater.add_artifact(
45+
parts=[Part(text='artifact content')], name='test-artifact'
46+
)
4547
await task_updater.update_status(
4648
TaskState.TASK_STATE_COMPLETED,
4749
message=task_updater.new_agent_message([Part(text='done')]),
@@ -80,7 +82,7 @@ def agent_card() -> AgentCard:
8082
)
8183

8284

83-
class TransportSetup(NamedTuple):
85+
class ClientSetup(NamedTuple):
8486
"""Holds the client and task_store for a given test."""
8587

8688
client: BaseClient
@@ -99,7 +101,7 @@ def base_e2e_setup():
99101

100102

101103
@pytest.fixture
102-
def rest_setup(agent_card, base_e2e_setup) -> TransportSetup:
104+
def rest_setup(agent_card, base_e2e_setup) -> ClientSetup:
103105
task_store, handler = base_e2e_setup
104106
app_builder = A2ARESTFastAPIApplication(agent_card, handler)
105107
app = app_builder.build()
@@ -112,15 +114,15 @@ def rest_setup(agent_card, base_e2e_setup) -> TransportSetup:
112114
supported_protocol_bindings=[TransportProtocol.HTTP_JSON],
113115
)
114116
)
115-
client = cast(BaseClient, factory.create(agent_card))
116-
return TransportSetup(
117+
client = factory.create(agent_card)
118+
return ClientSetup(
117119
client=client,
118120
task_store=task_store,
119121
)
120122

121123

122124
@pytest.fixture
123-
def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
125+
def jsonrpc_setup(agent_card, base_e2e_setup) -> ClientSetup:
124126
task_store, handler = base_e2e_setup
125127
app_builder = A2AFastAPIApplication(
126128
agent_card, handler, extended_agent_card=agent_card
@@ -135,8 +137,8 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
135137
supported_protocol_bindings=[TransportProtocol.JSONRPC],
136138
)
137139
)
138-
client = cast(BaseClient, factory.create(agent_card))
139-
return TransportSetup(
140+
client = factory.create(agent_card)
141+
return ClientSetup(
140142
client=client,
141143
task_store=task_store,
142144
)
@@ -145,7 +147,7 @@ def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
145147
@pytest_asyncio.fixture
146148
async def grpc_setup(
147149
agent_card: AgentCard, base_e2e_setup
148-
) -> AsyncGenerator[TransportSetup, None]:
150+
) -> AsyncGenerator[ClientSetup, None]:
149151
task_store, handler = base_e2e_setup
150152
server = grpc.aio.server()
151153
port = server.add_insecure_port('[::]:0')
@@ -168,12 +170,12 @@ async def grpc_setup(
168170

169171
factory = ClientFactory(
170172
config=ClientConfig(
171-
grpc_channel_factory=lambda url: grpc.aio.insecure_channel(url),
173+
grpc_channel_factory=grpc.aio.insecure_channel,
172174
supported_protocol_bindings=[TransportProtocol.GRPC],
173175
)
174176
)
175-
client = cast(BaseClient, factory.create(grpc_agent_card))
176-
yield TransportSetup(
177+
client = factory.create(grpc_agent_card)
178+
yield ClientSetup(
177179
client=client,
178180
task_store=task_store,
179181
)
@@ -189,14 +191,15 @@ async def grpc_setup(
189191
pytest.param('grpc_setup', id='gRPC'),
190192
]
191193
)
192-
def transport_setups(request) -> TransportSetup:
194+
def transport_setups(request) -> ClientSetup:
193195
"""Parametrized fixture that runs tests against all supported transports."""
194196
return request.getfixturevalue(request.param)
195197

196198

197199
@pytest.mark.asyncio
198200
async def test_end_to_end_send_message_blocking(transport_setups):
199201
client = transport_setups.client
202+
client._config.streaming = False
200203

201204
message_to_send = Message(
202205
role=Role.ROLE_USER,
@@ -211,16 +214,19 @@ async def test_end_to_end_send_message_blocking(transport_setups):
211214
request=message_to_send, configuration=configuration
212215
)
213216
]
214-
response, task = events[-1]
215-
216-
assert task
217-
assert task.id
218-
assert task.status.state == TaskState.TASK_STATE_COMPLETED
217+
assert len(events) == 1
218+
response, _ = events[0]
219+
assert response.task.id
220+
assert response.task.status.state == TaskState.TASK_STATE_COMPLETED
221+
assert len(response.task.artifacts) == 1
222+
assert response.task.artifacts[0].name == 'test-artifact'
223+
assert response.task.artifacts[0].parts[0].text == 'artifact content'
219224

220225

221226
@pytest.mark.asyncio
222227
async def test_end_to_end_send_message_non_blocking(transport_setups):
223228
client = transport_setups.client
229+
client._config.streaming = False
224230

225231
message_to_send = Message(
226232
role=Role.ROLE_USER,
@@ -235,10 +241,10 @@ async def test_end_to_end_send_message_non_blocking(transport_setups):
235241
request=message_to_send, configuration=configuration
236242
)
237243
]
238-
response, task = events[-1]
239-
240-
assert task
241-
assert task.id
244+
assert len(events) == 1
245+
response, _ = events[0]
246+
assert response.task.id
247+
assert response.task.status.state == TaskState.TASK_STATE_SUBMITTED
242248

243249

244250
@pytest.mark.asyncio
@@ -252,20 +258,29 @@ async def test_end_to_end_send_message_streaming(transport_setups):
252258
)
253259

254260
events = [
255-
event async for event in client.send_message(request=message_to_send)
261+
event async for event, _ in client.send_message(request=message_to_send)
256262
]
257263

258-
assert len(events) > 0
259-
stream_response, task = events[-1]
264+
expected_events = [
265+
('status_update', TaskState.TASK_STATE_SUBMITTED),
266+
('status_update', TaskState.TASK_STATE_WORKING),
267+
('artifact_update', None),
268+
('status_update', TaskState.TASK_STATE_COMPLETED),
269+
]
260270

261-
assert stream_response.HasField('status_update')
262-
assert stream_response.status_update.task_id
263-
assert (
264-
stream_response.status_update.status.state
265-
== TaskState.TASK_STATE_COMPLETED
266-
)
267-
assert task
268-
assert task.status.state == TaskState.TASK_STATE_COMPLETED
271+
assert len(events) == len(expected_events)
272+
for event, (expected_type, expected_state) in zip(
273+
events, expected_events, strict=True
274+
):
275+
assert event.HasField(expected_type)
276+
if expected_type == 'status_update':
277+
assert event.status_update.status.state == expected_state
278+
elif expected_type == 'artifact_update':
279+
assert event.artifact_update.artifact.name == 'test-artifact'
280+
assert (
281+
event.artifact_update.artifact.parts[0].text
282+
== 'artifact content'
283+
)
269284

270285

271286
@pytest.mark.asyncio
@@ -301,21 +316,23 @@ async def test_end_to_end_list_tasks(transport_setups):
301316
total_items = 6
302317
page_size = 2
303318

319+
expected_task_ids = []
304320
for i in range(total_items):
305-
# We need to await the iterator to ensure request completes
306-
async for _ in client.send_message(
307-
request=Message(
308-
role=Role.ROLE_USER,
309-
message_id=f'msg-e2e-list-{i}',
310-
parts=[Part(text=f'Test List Tasks {i}')],
311-
),
312-
configuration=SendMessageConfiguration(blocking=False),
313-
):
314-
pass
321+
# One event is enough to get the task ID
322+
_, task = await anext(
323+
client.send_message(
324+
request=Message(
325+
role=Role.ROLE_USER,
326+
message_id=f'msg-e2e-list-{i}',
327+
parts=[Part(text=f'Test List Tasks {i}')],
328+
)
329+
)
330+
)
331+
expected_task_ids.append(task.id)
315332

316333
list_request = ListTasksRequest(page_size=page_size)
317334

318-
unique_task_ids = set()
335+
actual_task_ids = []
319336
token = None
320337

321338
while token != '':
@@ -327,9 +344,9 @@ async def test_end_to_end_list_tasks(transport_setups):
327344
assert list_response.total_size == total_items
328345
assert list_response.page_size == page_size
329346

330-
for task in list_response.tasks:
331-
unique_task_ids.add(task.id)
347+
actual_task_ids.extend([task.id for task in list_response.tasks])
332348

333349
token = list_response.next_page_token
334350

335-
assert len(unique_task_ids) == total_items
351+
assert len(actual_task_ids) == total_items
352+
assert sorted(actual_task_ids) == sorted(expected_task_ids)

0 commit comments

Comments
 (0)