Skip to content

Commit 1d3f8be

Browse files
authored
Merge branch 'main' into ishymko/metadata-refactor
2 parents 1b5db09 + 2e45c0d commit 1d3f8be

2 files changed

Lines changed: 43 additions & 1 deletion

File tree

src/a2a/client/transports/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from abc import ABC, abstractmethod
22
from collections.abc import AsyncGenerator, Callable
3+
from types import TracebackType
4+
5+
from typing_extensions import Self
36

47
from a2a.client.middleware import ClientCallContext
58
from a2a.types import (
@@ -19,6 +22,19 @@
1922
class ClientTransport(ABC):
2023
"""Abstract base class for a client transport."""
2124

25+
async def __aenter__(self) -> Self:
26+
"""Enters the async context manager, returning the transport itself."""
27+
return self
28+
29+
async def __aexit__(
30+
self,
31+
exc_type: type[BaseException] | None,
32+
exc_val: BaseException | None,
33+
exc_tb: TracebackType | None,
34+
) -> None:
35+
"""Exits the async context manager, ensuring close() is called."""
36+
await self.close()
37+
2238
@abstractmethod
2339
async def send_message(
2440
self,

tests/client/test_base_client.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import AsyncMock, MagicMock
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import pytest
44

@@ -61,6 +61,32 @@ def base_client(
6161
)
6262

6363

64+
@pytest.mark.asyncio
65+
async def test_transport_async_context_manager() -> None:
66+
with (
67+
patch.object(ClientTransport, '__abstractmethods__', set()),
68+
patch.object(ClientTransport, 'close', new_callable=AsyncMock),
69+
):
70+
transport = ClientTransport()
71+
async with transport as t:
72+
assert t is transport
73+
transport.close.assert_not_awaited()
74+
transport.close.assert_awaited_once()
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_transport_async_context_manager_on_exception() -> None:
79+
with (
80+
patch.object(ClientTransport, '__abstractmethods__', set()),
81+
patch.object(ClientTransport, 'close', new_callable=AsyncMock),
82+
):
83+
transport = ClientTransport()
84+
with pytest.raises(RuntimeError, match='boom'):
85+
async with transport:
86+
raise RuntimeError('boom')
87+
transport.close.assert_awaited_once()
88+
89+
6490
@pytest.mark.asyncio
6591
async def test_send_message_streaming(
6692
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message

0 commit comments

Comments
 (0)