|
1 | | -from unittest.mock import AsyncMock, MagicMock |
| 1 | +from unittest.mock import AsyncMock, MagicMock, patch |
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 |
|
@@ -61,6 +61,32 @@ def base_client( |
61 | 61 | ) |
62 | 62 |
|
63 | 63 |
|
| 64 | +@pytest.mark.asyncio |
| 65 | +async def test_transport_async_context_manager() -> None: |
| 66 | + with ( |
| 67 | + patch.object(ClientTransport, '__abstractmethods__', set()), |
| 68 | + patch.object(ClientTransport, 'close', new_callable=AsyncMock), |
| 69 | + ): |
| 70 | + transport = ClientTransport() |
| 71 | + async with transport as t: |
| 72 | + assert t is transport |
| 73 | + transport.close.assert_not_awaited() |
| 74 | + transport.close.assert_awaited_once() |
| 75 | + |
| 76 | + |
| 77 | +@pytest.mark.asyncio |
| 78 | +async def test_transport_async_context_manager_on_exception() -> None: |
| 79 | + with ( |
| 80 | + patch.object(ClientTransport, '__abstractmethods__', set()), |
| 81 | + patch.object(ClientTransport, 'close', new_callable=AsyncMock), |
| 82 | + ): |
| 83 | + transport = ClientTransport() |
| 84 | + with pytest.raises(RuntimeError, match='boom'): |
| 85 | + async with transport: |
| 86 | + raise RuntimeError('boom') |
| 87 | + transport.close.assert_awaited_once() |
| 88 | + |
| 89 | + |
64 | 90 | @pytest.mark.asyncio |
65 | 91 | async def test_send_message_streaming( |
66 | 92 | base_client: BaseClient, mock_transport: MagicMock, sample_message: Message |
|
0 commit comments