|
28 | 28 | ) |
29 | 29 | from a2a.types import ( |
30 | 30 | InternalError, |
| 31 | + InvalidAgentResponseError, |
31 | 32 | InvalidParamsError, |
32 | 33 | TaskNotFoundError, |
33 | 34 | PushNotificationNotSupportedError, |
@@ -1244,3 +1245,169 @@ async def test_on_message_send_with_push_notification(): |
1244 | 1245 | push_store.set_info.assert_awaited_once_with( |
1245 | 1246 | result.id, push_config, context |
1246 | 1247 | ) |
| 1248 | + |
| 1249 | + |
| 1250 | +class MultipleMessagesAgentExecutor(AgentExecutor): |
| 1251 | + """Misbehaving agent that yields more than one Message.""" |
| 1252 | + |
| 1253 | + async def execute(self, context: RequestContext, event_queue: EventQueue): |
| 1254 | + await event_queue.enqueue_event( |
| 1255 | + new_text_message('first', role=Role.ROLE_AGENT) |
| 1256 | + ) |
| 1257 | + await event_queue.enqueue_event( |
| 1258 | + new_text_message('second', role=Role.ROLE_AGENT) |
| 1259 | + ) |
| 1260 | + |
| 1261 | + async def cancel(self, context: RequestContext, event_queue: EventQueue): |
| 1262 | + pass |
| 1263 | + |
| 1264 | + |
| 1265 | +class MessageAfterTaskEventAgentExecutor(AgentExecutor): |
| 1266 | + """Misbehaving agent that yields a task-mode event then a Message.""" |
| 1267 | + |
| 1268 | + async def execute(self, context: RequestContext, event_queue: EventQueue): |
| 1269 | + task = new_task_from_user_message(context.message) |
| 1270 | + await event_queue.enqueue_event(task) |
| 1271 | + updater = TaskUpdater(event_queue, task.id, task.context_id) |
| 1272 | + await updater.update_status(TaskState.TASK_STATE_WORKING) |
| 1273 | + await event_queue.enqueue_event( |
| 1274 | + new_text_message('stray message', role=Role.ROLE_AGENT) |
| 1275 | + ) |
| 1276 | + |
| 1277 | + async def cancel(self, context: RequestContext, event_queue: EventQueue): |
| 1278 | + pass |
| 1279 | + |
| 1280 | + |
| 1281 | +class TaskEventAfterMessageAgentExecutor(AgentExecutor): |
| 1282 | + """Misbehaving agent that yields a Message and then a task-mode event.""" |
| 1283 | + |
| 1284 | + async def execute(self, context: RequestContext, event_queue: EventQueue): |
| 1285 | + await event_queue.enqueue_event( |
| 1286 | + new_text_message('only message', role=Role.ROLE_AGENT) |
| 1287 | + ) |
| 1288 | + await event_queue.enqueue_event( |
| 1289 | + TaskStatusUpdateEvent( |
| 1290 | + task_id=str(context.task_id or ''), |
| 1291 | + context_id=str(context.context_id or ''), |
| 1292 | + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), |
| 1293 | + ) |
| 1294 | + ) |
| 1295 | + |
| 1296 | + async def cancel(self, context: RequestContext, event_queue: EventQueue): |
| 1297 | + pass |
| 1298 | + |
| 1299 | + |
| 1300 | +class EventAfterTerminalStateAgentExecutor(AgentExecutor): |
| 1301 | + """Misbehaving agent that yields an event after reaching a terminal state.""" |
| 1302 | + |
| 1303 | + async def execute(self, context: RequestContext, event_queue: EventQueue): |
| 1304 | + task = new_task_from_user_message(context.message) |
| 1305 | + await event_queue.enqueue_event(task) |
| 1306 | + updater = TaskUpdater(event_queue, task.id, task.context_id) |
| 1307 | + await updater.complete() |
| 1308 | + await event_queue.enqueue_event( |
| 1309 | + new_text_message('after terminal', role=Role.ROLE_AGENT) |
| 1310 | + ) |
| 1311 | + |
| 1312 | + async def cancel(self, context: RequestContext, event_queue: EventQueue): |
| 1313 | + pass |
| 1314 | + |
| 1315 | + |
| 1316 | +@pytest.mark.asyncio |
| 1317 | +@pytest.mark.timeout(1) |
| 1318 | +async def test_on_message_send_stream_rejects_multiple_messages(): |
| 1319 | + """Stream surfaces InvalidAgentResponseError when the agent yields a |
| 1320 | + second Message after the first one (see comment in on_message_send_stream).""" |
| 1321 | + request_handler = DefaultRequestHandlerV2( |
| 1322 | + agent_executor=MultipleMessagesAgentExecutor(), |
| 1323 | + task_store=InMemoryTaskStore(), |
| 1324 | + agent_card=create_default_agent_card(), |
| 1325 | + ) |
| 1326 | + params = SendMessageRequest( |
| 1327 | + message=Message( |
| 1328 | + role=Role.ROLE_USER, |
| 1329 | + message_id='msg_multi_stream', |
| 1330 | + parts=[Part(text='Hi')], |
| 1331 | + ) |
| 1332 | + ) |
| 1333 | + with pytest.raises(InvalidAgentResponseError, match='Multiple Message'): |
| 1334 | + async for _ in request_handler.on_message_send_stream( |
| 1335 | + params, create_server_call_context() |
| 1336 | + ): |
| 1337 | + pass |
| 1338 | + |
| 1339 | + |
| 1340 | +@pytest.mark.asyncio |
| 1341 | +@pytest.mark.timeout(1) |
| 1342 | +async def test_on_message_send_stream_rejects_message_after_task_event(): |
| 1343 | + """Stream surfaces InvalidAgentResponseError when the agent yields a |
| 1344 | + Message after entering task mode (see comment in on_message_send_stream).""" |
| 1345 | + request_handler = DefaultRequestHandlerV2( |
| 1346 | + agent_executor=MessageAfterTaskEventAgentExecutor(), |
| 1347 | + task_store=InMemoryTaskStore(), |
| 1348 | + agent_card=create_default_agent_card(), |
| 1349 | + ) |
| 1350 | + params = SendMessageRequest( |
| 1351 | + message=Message( |
| 1352 | + role=Role.ROLE_USER, |
| 1353 | + message_id='msg_after_task_stream', |
| 1354 | + parts=[Part(text='Hi')], |
| 1355 | + ) |
| 1356 | + ) |
| 1357 | + with pytest.raises( |
| 1358 | + InvalidAgentResponseError, match='Message object in task mode' |
| 1359 | + ): |
| 1360 | + async for _ in request_handler.on_message_send_stream( |
| 1361 | + params, create_server_call_context() |
| 1362 | + ): |
| 1363 | + pass |
| 1364 | + |
| 1365 | + |
| 1366 | +@pytest.mark.asyncio |
| 1367 | +@pytest.mark.timeout(1) |
| 1368 | +async def test_on_message_send_stream_rejects_task_event_after_message(): |
| 1369 | + """Stream surfaces InvalidAgentResponseError when the agent yields a |
| 1370 | + task-mode event after a Message (see comment in on_message_send_stream).""" |
| 1371 | + request_handler = DefaultRequestHandlerV2( |
| 1372 | + agent_executor=TaskEventAfterMessageAgentExecutor(), |
| 1373 | + task_store=InMemoryTaskStore(), |
| 1374 | + agent_card=create_default_agent_card(), |
| 1375 | + ) |
| 1376 | + params = SendMessageRequest( |
| 1377 | + message=Message( |
| 1378 | + role=Role.ROLE_USER, |
| 1379 | + message_id='msg_then_task_stream', |
| 1380 | + parts=[Part(text='Hi')], |
| 1381 | + ) |
| 1382 | + ) |
| 1383 | + with pytest.raises(InvalidAgentResponseError, match='in message mode'): |
| 1384 | + async for _ in request_handler.on_message_send_stream( |
| 1385 | + params, create_server_call_context() |
| 1386 | + ): |
| 1387 | + pass |
| 1388 | + |
| 1389 | + |
| 1390 | +@pytest.mark.asyncio |
| 1391 | +@pytest.mark.timeout(1) |
| 1392 | +async def test_on_message_send_stream_rejects_event_after_terminal_state(): |
| 1393 | + """Stream surfaces InvalidAgentResponseError when the agent yields an event |
| 1394 | + after reaching a terminal state (see comment in on_message_send_stream).""" |
| 1395 | + request_handler = DefaultRequestHandlerV2( |
| 1396 | + agent_executor=EventAfterTerminalStateAgentExecutor(), |
| 1397 | + task_store=InMemoryTaskStore(), |
| 1398 | + agent_card=create_default_agent_card(), |
| 1399 | + ) |
| 1400 | + params = SendMessageRequest( |
| 1401 | + message=Message( |
| 1402 | + role=Role.ROLE_USER, |
| 1403 | + message_id='msg_after_terminal_stream', |
| 1404 | + parts=[Part(text='Hi')], |
| 1405 | + ) |
| 1406 | + ) |
| 1407 | + with pytest.raises( |
| 1408 | + InvalidAgentResponseError, match='Message object in task mode' |
| 1409 | + ): |
| 1410 | + async for _ in request_handler.on_message_send_stream( |
| 1411 | + params, create_server_call_context() |
| 1412 | + ): |
| 1413 | + pass |
0 commit comments