Skip to content

Commit d77cd68

Browse files
authored
fix: rely on agent executor implementation for stream termination (#988)
`active_task.py` already contains agent executor behavior validation, do not terminate the stream so that those errors can be raised, tests are updated to cover invalid behavior conditions.
1 parent 25e2a7d commit d77cd68

2 files changed

Lines changed: 180 additions & 4 deletions

File tree

src/a2a/server/request_handlers/default_request_handler_v2.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,17 @@ async def on_message_send( # noqa: D102
271271
):
272272
self._validate_task_id_match(task_id, event.id)
273273
result = event
274+
# DO break here as it's "return_immediately".
275+
# AgentExecutor will continue to run in the background.
274276
break
275277

276278
if isinstance(event, Message):
277279
result = event
278-
break
280+
# Do NOT break here as Message is supposed to be the only
281+
# event in "Message-only" interaction.
282+
# ActiveTask consumer (see active_task.py) validates the event
283+
# stream and raises InvalidAgentResponseError if more events are
284+
# pushed after a Message.
279285

280286
if result is None:
281287
logger.debug('Missing result for task %s', request_context.task_id)
@@ -311,15 +317,18 @@ async def on_message_send_stream( # noqa: D102
311317
request=request_context,
312318
include_initial_task=False,
313319
):
320+
# Do NOT break here as we rely on AgentExecutor to yield control.
321+
# ActiveTask consumer (see active_task.py) validates the event
322+
# stream and raises InvalidAgentResponseError on misbehaving agents:
323+
# - an event after a Message
324+
# - Message after entering task mode
325+
# - an event after a terminal state
314326
if isinstance(event, Task):
315327
self._validate_task_id_match(task_id, event.id)
316328
yield apply_history_length(event, params.configuration)
317329
else:
318330
yield event
319331

320-
if isinstance(event, Message):
321-
break
322-
323332
@validate_request_params
324333
@validate(
325334
lambda self: self._agent_card.capabilities.push_notifications,

tests/server/request_handlers/test_default_request_handler_v2.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from a2a.types import (
3030
InternalError,
31+
InvalidAgentResponseError,
3132
InvalidParamsError,
3233
TaskNotFoundError,
3334
PushNotificationNotSupportedError,
@@ -1244,3 +1245,169 @@ async def test_on_message_send_with_push_notification():
12441245
push_store.set_info.assert_awaited_once_with(
12451246
result.id, push_config, context
12461247
)
1248+
1249+
1250+
class MultipleMessagesAgentExecutor(AgentExecutor):
1251+
"""Misbehaving agent that yields more than one Message."""
1252+
1253+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1254+
await event_queue.enqueue_event(
1255+
new_text_message('first', role=Role.ROLE_AGENT)
1256+
)
1257+
await event_queue.enqueue_event(
1258+
new_text_message('second', role=Role.ROLE_AGENT)
1259+
)
1260+
1261+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1262+
pass
1263+
1264+
1265+
class MessageAfterTaskEventAgentExecutor(AgentExecutor):
1266+
"""Misbehaving agent that yields a task-mode event then a Message."""
1267+
1268+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1269+
task = new_task_from_user_message(context.message)
1270+
await event_queue.enqueue_event(task)
1271+
updater = TaskUpdater(event_queue, task.id, task.context_id)
1272+
await updater.update_status(TaskState.TASK_STATE_WORKING)
1273+
await event_queue.enqueue_event(
1274+
new_text_message('stray message', role=Role.ROLE_AGENT)
1275+
)
1276+
1277+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1278+
pass
1279+
1280+
1281+
class TaskEventAfterMessageAgentExecutor(AgentExecutor):
1282+
"""Misbehaving agent that yields a Message and then a task-mode event."""
1283+
1284+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1285+
await event_queue.enqueue_event(
1286+
new_text_message('only message', role=Role.ROLE_AGENT)
1287+
)
1288+
await event_queue.enqueue_event(
1289+
TaskStatusUpdateEvent(
1290+
task_id=str(context.task_id or ''),
1291+
context_id=str(context.context_id or ''),
1292+
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
1293+
)
1294+
)
1295+
1296+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1297+
pass
1298+
1299+
1300+
class EventAfterTerminalStateAgentExecutor(AgentExecutor):
1301+
"""Misbehaving agent that yields an event after reaching a terminal state."""
1302+
1303+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1304+
task = new_task_from_user_message(context.message)
1305+
await event_queue.enqueue_event(task)
1306+
updater = TaskUpdater(event_queue, task.id, task.context_id)
1307+
await updater.complete()
1308+
await event_queue.enqueue_event(
1309+
new_text_message('after terminal', role=Role.ROLE_AGENT)
1310+
)
1311+
1312+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1313+
pass
1314+
1315+
1316+
@pytest.mark.asyncio
1317+
@pytest.mark.timeout(1)
1318+
async def test_on_message_send_stream_rejects_multiple_messages():
1319+
"""Stream surfaces InvalidAgentResponseError when the agent yields a
1320+
second Message after the first one (see comment in on_message_send_stream)."""
1321+
request_handler = DefaultRequestHandlerV2(
1322+
agent_executor=MultipleMessagesAgentExecutor(),
1323+
task_store=InMemoryTaskStore(),
1324+
agent_card=create_default_agent_card(),
1325+
)
1326+
params = SendMessageRequest(
1327+
message=Message(
1328+
role=Role.ROLE_USER,
1329+
message_id='msg_multi_stream',
1330+
parts=[Part(text='Hi')],
1331+
)
1332+
)
1333+
with pytest.raises(InvalidAgentResponseError, match='Multiple Message'):
1334+
async for _ in request_handler.on_message_send_stream(
1335+
params, create_server_call_context()
1336+
):
1337+
pass
1338+
1339+
1340+
@pytest.mark.asyncio
1341+
@pytest.mark.timeout(1)
1342+
async def test_on_message_send_stream_rejects_message_after_task_event():
1343+
"""Stream surfaces InvalidAgentResponseError when the agent yields a
1344+
Message after entering task mode (see comment in on_message_send_stream)."""
1345+
request_handler = DefaultRequestHandlerV2(
1346+
agent_executor=MessageAfterTaskEventAgentExecutor(),
1347+
task_store=InMemoryTaskStore(),
1348+
agent_card=create_default_agent_card(),
1349+
)
1350+
params = SendMessageRequest(
1351+
message=Message(
1352+
role=Role.ROLE_USER,
1353+
message_id='msg_after_task_stream',
1354+
parts=[Part(text='Hi')],
1355+
)
1356+
)
1357+
with pytest.raises(
1358+
InvalidAgentResponseError, match='Message object in task mode'
1359+
):
1360+
async for _ in request_handler.on_message_send_stream(
1361+
params, create_server_call_context()
1362+
):
1363+
pass
1364+
1365+
1366+
@pytest.mark.asyncio
1367+
@pytest.mark.timeout(1)
1368+
async def test_on_message_send_stream_rejects_task_event_after_message():
1369+
"""Stream surfaces InvalidAgentResponseError when the agent yields a
1370+
task-mode event after a Message (see comment in on_message_send_stream)."""
1371+
request_handler = DefaultRequestHandlerV2(
1372+
agent_executor=TaskEventAfterMessageAgentExecutor(),
1373+
task_store=InMemoryTaskStore(),
1374+
agent_card=create_default_agent_card(),
1375+
)
1376+
params = SendMessageRequest(
1377+
message=Message(
1378+
role=Role.ROLE_USER,
1379+
message_id='msg_then_task_stream',
1380+
parts=[Part(text='Hi')],
1381+
)
1382+
)
1383+
with pytest.raises(InvalidAgentResponseError, match='in message mode'):
1384+
async for _ in request_handler.on_message_send_stream(
1385+
params, create_server_call_context()
1386+
):
1387+
pass
1388+
1389+
1390+
@pytest.mark.asyncio
1391+
@pytest.mark.timeout(1)
1392+
async def test_on_message_send_stream_rejects_event_after_terminal_state():
1393+
"""Stream surfaces InvalidAgentResponseError when the agent yields an event
1394+
after reaching a terminal state (see comment in on_message_send_stream)."""
1395+
request_handler = DefaultRequestHandlerV2(
1396+
agent_executor=EventAfterTerminalStateAgentExecutor(),
1397+
task_store=InMemoryTaskStore(),
1398+
agent_card=create_default_agent_card(),
1399+
)
1400+
params = SendMessageRequest(
1401+
message=Message(
1402+
role=Role.ROLE_USER,
1403+
message_id='msg_after_terminal_stream',
1404+
parts=[Part(text='Hi')],
1405+
)
1406+
)
1407+
with pytest.raises(
1408+
InvalidAgentResponseError, match='Message object in task mode'
1409+
):
1410+
async for _ in request_handler.on_message_send_stream(
1411+
params, create_server_call_context()
1412+
):
1413+
pass

0 commit comments

Comments
 (0)