Skip to content

Commit dcab123

Browse files
committed
feat(compat): GRPC client compatible with 0.3 servers.
1 parent 3197a73 commit dcab123

8 files changed

Lines changed: 898 additions & 13 deletions

File tree

src/a2a/client/client_factory.py

Lines changed: 99 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import httpx
99

10+
from packaging.version import InvalidVersion, Version
11+
1012
from a2a.client.base_client import BaseClient
1113
from a2a.client.card_resolver import A2ACardResolver
1214
from a2a.client.client import Client, ClientConfig, Consumer
@@ -21,6 +23,8 @@
2123
AgentInterface,
2224
)
2325
from a2a.utils.constants import (
26+
PROTOCOL_VERSION_0_3,
27+
PROTOCOL_VERSION_1_0,
2428
PROTOCOL_VERSION_CURRENT,
2529
VERSION_HEADER,
2630
TransportProtocol,
@@ -33,6 +37,12 @@
3337
GrpcTransport = None # type: ignore # pyright: ignore
3438

3539

40+
try:
41+
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
42+
except ImportError:
43+
CompatGrpcTransport = None # type: ignore # pyright: ignore
44+
45+
3646
logger = logging.getLogger(__name__)
3747

3848

@@ -109,11 +119,92 @@ def _register_defaults(self, supported: list[str]) -> None:
109119
'To use GrpcClient, its dependencies must be installed. '
110120
'You can install them with \'pip install "a2a-sdk[grpc]"\''
111121
)
122+
123+
def grpc_transport_producer(
124+
card: AgentCard,
125+
url: str,
126+
config: ClientConfig,
127+
interceptors: list[ClientCallInterceptor],
128+
) -> ClientTransport:
129+
# The interface has already been selected and passed as `url`.
130+
# We determine its version to use the appropriate transport implementation.
131+
interface = ClientFactory._find_best_interface(
132+
list(card.supported_interfaces),
133+
protocol_bindings=[TransportProtocol.GRPC],
134+
url=url,
135+
)
136+
version = (
137+
interface.protocol_version
138+
if interface
139+
else PROTOCOL_VERSION_CURRENT
140+
)
141+
142+
if version and CompatGrpcTransport is not None:
143+
try:
144+
v = Version(version)
145+
if (
146+
Version(PROTOCOL_VERSION_0_3)
147+
<= v
148+
< Version(PROTOCOL_VERSION_1_0)
149+
):
150+
return CompatGrpcTransport.create(
151+
card, url, config, interceptors
152+
)
153+
except InvalidVersion:
154+
pass
155+
156+
return GrpcTransport.create(card, url, config, interceptors)
157+
112158
self.register(
113159
TransportProtocol.GRPC,
114-
GrpcTransport.create,
160+
grpc_transport_producer,
115161
)
116162

163+
@staticmethod
164+
def _find_best_interface(
165+
interfaces: list[AgentInterface],
166+
protocol_bindings: list[str] | None = None,
167+
url: str | None = None,
168+
) -> AgentInterface | None:
169+
"""Finds the best interface based on protocol version priorities."""
170+
171+
candidates = []
172+
for i in interfaces:
173+
if (
174+
(protocol_bindings is None or i.protocol_binding in protocol_bindings)
175+
and (url is None or i.url == url)
176+
):
177+
candidates.append(i)
178+
179+
if not candidates:
180+
return None
181+
182+
# Prefer interface with version 1.0
183+
for i in candidates:
184+
if i.protocol_version == PROTOCOL_VERSION_1_0:
185+
return i
186+
187+
best_gt_1_0 = None
188+
best_ge_0_3 = None
189+
best_no_version = None
190+
191+
for i in candidates:
192+
if not i.protocol_version:
193+
if best_no_version is None:
194+
best_no_version = i
195+
continue
196+
197+
try:
198+
v = Version(i.protocol_version)
199+
if best_gt_1_0 is None and v > Version(PROTOCOL_VERSION_1_0):
200+
best_gt_1_0 = i
201+
if best_ge_0_3 is None and v >= Version(PROTOCOL_VERSION_0_3):
202+
best_ge_0_3 = i
203+
except InvalidVersion:
204+
pass
205+
206+
return best_gt_1_0 or best_ge_0_3 or best_no_version
207+
117208
@classmethod
118209
async def connect( # noqa: PLR0913
119210
cls,
@@ -220,13 +311,9 @@ def create(
220311
selected_interface = None
221312
if self._config.use_client_preference:
222313
for protocol_binding in client_set:
223-
selected_interface = next(
224-
(
225-
si
226-
for si in card.supported_interfaces
227-
if si.protocol_binding == protocol_binding
228-
),
229-
None,
314+
selected_interface = ClientFactory._find_best_interface(
315+
list(card.supported_interfaces),
316+
protocol_bindings=[protocol_binding],
230317
)
231318
if selected_interface:
232319
transport_protocol = protocol_binding
@@ -235,7 +322,10 @@ def create(
235322
for supported_interface in card.supported_interfaces:
236323
if supported_interface.protocol_binding in client_set:
237324
transport_protocol = supported_interface.protocol_binding
238-
selected_interface = supported_interface
325+
selected_interface = ClientFactory._find_best_interface(
326+
list(card.supported_interfaces),
327+
protocol_bindings=[transport_protocol],
328+
)
239329
break
240330
if not transport_protocol or not selected_interface:
241331
raise ValueError('no compatible transports found.')

src/a2a/compat/v0_3/conversions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1144,7 +1144,10 @@ def to_compat_get_task_push_notification_config_request(
11441144
types_v03.GetTaskPushNotificationConfigParams | types_v03.TaskIdParams
11451145
)
11461146
if core_req.id:
1147-
params = types_v03.GetTaskPushNotificationConfigParams(
1147+
params: (
1148+
types_v03.TaskIdParams
1149+
| types_v03.GetTaskPushNotificationConfigParams
1150+
) = types_v03.GetTaskPushNotificationConfigParams(
11481151
id=core_req.task_id, push_notification_config_id=core_req.id
11491152
)
11501153
else:

0 commit comments

Comments
 (0)