Skip to content

Commit d1ee922

Browse files
committed
test: improve test_end_to_end.py
- Test artifacts. - Add more assertions for streaming: validate all events. - Fix non-streaming tests which were actually streaming.
1 parent 8807957 commit d1ee922

1 file changed

Lines changed: 26 additions & 11 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
4141
)
4242
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
4343
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
44+
await task_updater.add_artifact(
45+
parts=[Part(text='artifact content')], name='test-artifact'
46+
)
4447
await task_updater.update_status(
4548
TaskState.TASK_STATE_COMPLETED,
4649
message=task_updater.new_agent_message([Part(text='done')]),
@@ -167,7 +170,7 @@ async def grpc_setup(
167170

168171
factory = ClientFactory(
169172
config=ClientConfig(
170-
grpc_channel_factory=lambda url: grpc.aio.insecure_channel(url),
173+
grpc_channel_factory=grpc.aio.insecure_channel,
171174
supported_protocol_bindings=[TransportProtocol.GRPC],
172175
)
173176
)
@@ -215,6 +218,9 @@ async def test_end_to_end_send_message_blocking(transport_setups):
215218
response, _ = events[0]
216219
assert response.task.id
217220
assert response.task.status.state == TaskState.TASK_STATE_COMPLETED
221+
assert len(response.task.artifacts) == 1
222+
assert response.task.artifacts[0].name == 'test-artifact'
223+
assert response.task.artifacts[0].parts[0].text == 'artifact content'
218224

219225

220226
@pytest.mark.asyncio
@@ -255,16 +261,26 @@ async def test_end_to_end_send_message_streaming(transport_setups):
255261
event async for event, _ in client.send_message(request=message_to_send)
256262
]
257263

258-
expected_states = [
259-
TaskState.TASK_STATE_SUBMITTED,
260-
TaskState.TASK_STATE_WORKING,
261-
TaskState.TASK_STATE_COMPLETED,
264+
expected_events = [
265+
('status_update', TaskState.TASK_STATE_SUBMITTED),
266+
('status_update', TaskState.TASK_STATE_WORKING),
267+
('artifact_update', None),
268+
('status_update', TaskState.TASK_STATE_COMPLETED),
262269
]
263270

264-
assert len(events) == len(expected_states)
265-
for event, expected_state in zip(events, expected_states):
266-
assert event.HasField('status_update')
267-
assert event.status_update.status.state == expected_state
271+
assert len(events) == len(expected_events)
272+
for event, (expected_type, expected_state) in zip(
273+
events, expected_events, strict=True
274+
):
275+
assert event.HasField(expected_type)
276+
if expected_type == 'status_update':
277+
assert event.status_update.status.state == expected_state
278+
elif expected_type == 'artifact_update':
279+
assert event.artifact_update.artifact.name == 'test-artifact'
280+
assert (
281+
event.artifact_update.artifact.parts[0].text
282+
== 'artifact content'
283+
)
268284

269285

270286
@pytest.mark.asyncio
@@ -328,8 +344,7 @@ async def test_end_to_end_list_tasks(transport_setups):
328344
assert list_response.total_size == total_items
329345
assert list_response.page_size == page_size
330346

331-
for task in list_response.tasks:
332-
actual_task_ids.append(task.id)
347+
actual_task_ids.extend([task.id for task in list_response.tasks])
333348

334349
token = list_response.next_page_token
335350

0 commit comments

Comments
 (0)