|
8 | 8 | from httpx_sse import EventSource, ServerSentEvent |
9 | 9 |
|
10 | 10 | from a2a.client import create_text_message_object |
11 | | -from a2a.client.errors import A2AClientHTTPError |
| 11 | +from a2a.client.errors import A2AClientHTTPError, A2AClientTimeoutError |
12 | 12 | from a2a.client.transports.rest import RestTransport |
13 | 13 | from a2a.extensions.common import HTTP_EXTENSION_HEADER |
14 | 14 | from a2a.types.a2a_pb2 import ( |
@@ -56,6 +56,40 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): |
56 | 56 | assert actual_extensions == expected_extensions |
57 | 57 |
|
58 | 58 |
|
| 59 | +class TestRestTransport: |
| 60 | + @pytest.mark.asyncio |
| 61 | + @patch('a2a.client.transports.rest.aconnect_sse') |
| 62 | + async def test_send_message_streaming_timeout( |
| 63 | + self, |
| 64 | + mock_aconnect_sse: AsyncMock, |
| 65 | + mock_httpx_client: AsyncMock, |
| 66 | + mock_agent_card: MagicMock, |
| 67 | + ): |
| 68 | + client = RestTransport( |
| 69 | + httpx_client=mock_httpx_client, agent_card=mock_agent_card |
| 70 | + ) |
| 71 | + params = SendMessageRequest( |
| 72 | + message=create_text_message_object(content='Hello stream') |
| 73 | + ) |
| 74 | + mock_event_source = AsyncMock(spec=EventSource) |
| 75 | + mock_event_source.response = MagicMock(spec=httpx.Response) |
| 76 | + mock_event_source.response.raise_for_status.return_value = None |
| 77 | + mock_event_source.aiter_sse.side_effect = httpx.TimeoutException( |
| 78 | + 'Read timed out' |
| 79 | + ) |
| 80 | + mock_aconnect_sse.return_value.__aenter__.return_value = ( |
| 81 | + mock_event_source |
| 82 | + ) |
| 83 | + |
| 84 | + with pytest.raises(A2AClientTimeoutError) as exc_info: |
| 85 | + _ = [ |
| 86 | + item |
| 87 | + async for item in client.send_message_streaming(request=params) |
| 88 | + ] |
| 89 | + |
| 90 | + assert 'Client Request timed out' in str(exc_info.value) |
| 91 | + |
| 92 | + |
59 | 93 | class TestRestTransportExtensions: |
60 | 94 | @pytest.mark.asyncio |
61 | 95 | async def test_send_message_with_default_extensions( |
|
0 commit comments