|
| 1 | +"""Test that streaming SSE responses clean up without athrow() errors. |
| 2 | +
|
| 3 | +Reproduces https://github.com/a2aproject/a2a-python/issues/911 — |
| 4 | +``RuntimeError: athrow(): asynchronous generator is already running`` |
| 5 | +during event-loop shutdown after consuming a streaming response. |
| 6 | +""" |
| 7 | + |
| 8 | +import asyncio |
| 9 | +import gc |
| 10 | + |
| 11 | +from typing import Any |
| 12 | +from uuid import uuid4 |
| 13 | + |
| 14 | +import httpx |
| 15 | +import pytest |
| 16 | + |
| 17 | +from starlette.applications import Starlette |
| 18 | + |
| 19 | +from a2a.client.base_client import BaseClient |
| 20 | +from a2a.client.client import ClientConfig |
| 21 | +from a2a.client.client_factory import ClientFactory |
| 22 | +from a2a.server.agent_execution import AgentExecutor, RequestContext |
| 23 | +from a2a.server.events import EventQueue |
| 24 | +from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager |
| 25 | +from a2a.server.request_handlers import DefaultRequestHandler |
| 26 | +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes |
| 27 | +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore |
| 28 | +from a2a.types import ( |
| 29 | + AgentCapabilities, |
| 30 | + AgentCard, |
| 31 | + AgentInterface, |
| 32 | + Message, |
| 33 | + Part, |
| 34 | + Role, |
| 35 | + SendMessageRequest, |
| 36 | +) |
| 37 | +from a2a.utils import TransportProtocol |
| 38 | + |
| 39 | + |
| 40 | +class _MessageExecutor(AgentExecutor): |
| 41 | + """Responds with a single Message event.""" |
| 42 | + |
| 43 | + async def execute(self, ctx: RequestContext, eq: EventQueue) -> None: |
| 44 | + await eq.enqueue_event( |
| 45 | + Message( |
| 46 | + role=Role.ROLE_AGENT, |
| 47 | + message_id=str(uuid4()), |
| 48 | + parts=[Part(text='Hello')], |
| 49 | + context_id=ctx.context_id, |
| 50 | + task_id=ctx.task_id, |
| 51 | + ) |
| 52 | + ) |
| 53 | + |
| 54 | + async def cancel(self, ctx: RequestContext, eq: EventQueue) -> None: |
| 55 | + pass |
| 56 | + |
| 57 | + |
| 58 | +@pytest.fixture |
| 59 | +def client(): |
| 60 | + """Creates a JSON-RPC client backed by an in-process ASGI server.""" |
| 61 | + card = AgentCard( |
| 62 | + name='T', |
| 63 | + description='T', |
| 64 | + version='1', |
| 65 | + capabilities=AgentCapabilities(streaming=True), |
| 66 | + default_input_modes=['text/plain'], |
| 67 | + default_output_modes=['text/plain'], |
| 68 | + supported_interfaces=[ |
| 69 | + AgentInterface( |
| 70 | + protocol_binding=TransportProtocol.JSONRPC, |
| 71 | + url='http://test', |
| 72 | + ), |
| 73 | + ], |
| 74 | + ) |
| 75 | + handler = DefaultRequestHandler( |
| 76 | + agent_executor=_MessageExecutor(), |
| 77 | + task_store=InMemoryTaskStore(), |
| 78 | + queue_manager=InMemoryQueueManager(), |
| 79 | + ) |
| 80 | + app = Starlette( |
| 81 | + routes=[ |
| 82 | + *create_agent_card_routes(agent_card=card, card_url='/card'), |
| 83 | + *create_jsonrpc_routes( |
| 84 | + agent_card=card, |
| 85 | + request_handler=handler, |
| 86 | + extended_agent_card=card, |
| 87 | + rpc_url='/', |
| 88 | + ), |
| 89 | + ] |
| 90 | + ) |
| 91 | + return ClientFactory( |
| 92 | + config=ClientConfig( |
| 93 | + httpx_client=httpx.AsyncClient( |
| 94 | + transport=httpx.ASGITransport(app=app), |
| 95 | + base_url='http://test', |
| 96 | + ) |
| 97 | + ) |
| 98 | + ).create(card) |
| 99 | + |
| 100 | + |
| 101 | +@pytest.mark.asyncio |
| 102 | +async def test_stream_message_no_athrow(client: BaseClient) -> None: |
| 103 | + """Consuming a streamed Message must not leave broken async generators.""" |
| 104 | + errors: list[dict[str, Any]] = [] |
| 105 | + loop = asyncio.get_event_loop() |
| 106 | + orig = loop.get_exception_handler() |
| 107 | + loop.set_exception_handler(lambda _l, ctx: errors.append(ctx)) |
| 108 | + |
| 109 | + try: |
| 110 | + msg = Message( |
| 111 | + role=Role.ROLE_USER, |
| 112 | + message_id=f'msg-{uuid4()}', |
| 113 | + parts=[Part(text='hi')], |
| 114 | + ) |
| 115 | + events = [ |
| 116 | + e |
| 117 | + async for e in client.send_message( |
| 118 | + request=SendMessageRequest(message=msg) |
| 119 | + ) |
| 120 | + ] |
| 121 | + assert events |
| 122 | + assert events[0][0].HasField('message') |
| 123 | + |
| 124 | + gc.collect() |
| 125 | + await loop.shutdown_asyncgens() |
| 126 | + |
| 127 | + bad = [ |
| 128 | + e |
| 129 | + for e in errors |
| 130 | + if 'asynchronous generator' in str(e.get('message', '')) |
| 131 | + ] |
| 132 | + assert not bad, '\n'.join(str(e.get('message', '')) for e in bad) |
| 133 | + finally: |
| 134 | + loop.set_exception_handler(orig) |
| 135 | + await client.close() |
0 commit comments