Skip to content

Commit 2e45c0d

Browse files
feat: add async context manager support to ClientTransport (#682)
# Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #674 🦕 ## Problem `ClientTransport` defines an abstract `close()` method but does not implement `__aenter__`/`__aexit__`. This means transports cannot be used with `async with`, which is the idiomatic Python pattern for managing async resources. If an exception occurs between creating a transport and calling `close()`, the underlying connection (e.g., gRPC channel, httpx client) is never cleaned up: ```python transport = GrpcTransport(channel=channel, agent_card=agent_card) result = await transport.send_message(params) # if this raises, close() is never called await transport.close() ``` ## Fix Added `__aenter__` and `__aexit__` methods to `ClientTransport` in `src/a2a/client/transports/base.py`: `__aenter__` returns `self`. `__aexit__` awaits `close()`. This enables the standard async context manager pattern on all transport implementations (`GrpcTransport`, `RestTransport`, `JsonRpcTransport`): ```python async with GrpcTransport(channel=channel, agent_card=agent_card) as transport: result = await transport.send_message(params) # close() called automatically, even on exceptions ``` This is a non-breaking, additive change. Calling `close()` manually or via `try/finally` continues to work exactly as before. ## Test Tests were added to `tests/client/test_base_client.py` since it already imports and mocks `ClientTransport`. Happy to move them to a dedicated file if maintainers prefer. ## Note As mentioned in #674, the same pattern could also be applied to `BaseClient`, which wraps `ClientTransport` and also exposes a `close()` method. I've kept this PR scoped to `ClientTransport` only. Happy to extend it to `BaseClient` in this same PR or a follow-up if maintainers prefer. Release-As: 0.3.23 --------- Co-authored-by: Ivan Shymko <vana.shimko@gmail.com>
1 parent c91d4fb commit 2e45c0d

2 files changed

Lines changed: 43 additions & 1 deletion

File tree

src/a2a/client/transports/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from abc import ABC, abstractmethod
22
from collections.abc import AsyncGenerator, Callable
3+
from types import TracebackType
4+
5+
from typing_extensions import Self
36

47
from a2a.client.middleware import ClientCallContext
58
from a2a.types import (
@@ -19,6 +22,19 @@
1922
class ClientTransport(ABC):
2023
"""Abstract base class for a client transport."""
2124

25+
async def __aenter__(self) -> Self:
26+
"""Enters the async context manager, returning the transport itself."""
27+
return self
28+
29+
async def __aexit__(
30+
self,
31+
exc_type: type[BaseException] | None,
32+
exc_val: BaseException | None,
33+
exc_tb: TracebackType | None,
34+
) -> None:
35+
"""Exits the async context manager, ensuring close() is called."""
36+
await self.close()
37+
2238
@abstractmethod
2339
async def send_message(
2440
self,

tests/client/test_base_client.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import AsyncMock, MagicMock
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import pytest
44

@@ -61,6 +61,32 @@ def base_client(
6161
)
6262

6363

64+
@pytest.mark.asyncio
65+
async def test_transport_async_context_manager() -> None:
66+
with (
67+
patch.object(ClientTransport, '__abstractmethods__', set()),
68+
patch.object(ClientTransport, 'close', new_callable=AsyncMock),
69+
):
70+
transport = ClientTransport()
71+
async with transport as t:
72+
assert t is transport
73+
transport.close.assert_not_awaited()
74+
transport.close.assert_awaited_once()
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_transport_async_context_manager_on_exception() -> None:
79+
with (
80+
patch.object(ClientTransport, '__abstractmethods__', set()),
81+
patch.object(ClientTransport, 'close', new_callable=AsyncMock),
82+
):
83+
transport = ClientTransport()
84+
with pytest.raises(RuntimeError, match='boom'):
85+
async with transport:
86+
raise RuntimeError('boom')
87+
transport.close.assert_awaited_once()
88+
89+
6490
@pytest.mark.asyncio
6591
async def test_send_message_streaming(
6692
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message

0 commit comments

Comments
 (0)