77
88import httpx
99
10+ from packaging .version import InvalidVersion , Version
11+
1012from a2a .client .base_client import BaseClient
1113from a2a .client .card_resolver import A2ACardResolver
1214from a2a .client .client import Client , ClientConfig , Consumer
2123 AgentInterface ,
2224)
2325from a2a .utils .constants import (
26+ PROTOCOL_VERSION_0_3 ,
27+ PROTOCOL_VERSION_1_0 ,
2428 PROTOCOL_VERSION_CURRENT ,
2529 VERSION_HEADER ,
2630 TransportProtocol ,
3337 GrpcTransport = None # type: ignore # pyright: ignore
3438
3539
40+ try :
41+ from a2a .compat .v0_3 .grpc_transport import CompatGrpcTransport
42+ except ImportError :
43+ CompatGrpcTransport = None # type: ignore # pyright: ignore
44+
45+
3646logger = logging .getLogger (__name__ )
3747
3848
@@ -109,11 +119,92 @@ def _register_defaults(self, supported: list[str]) -> None:
109119 'To use GrpcClient, its dependencies must be installed. '
110120 'You can install them with \' pip install "a2a-sdk[grpc]"\' '
111121 )
122+
123+ def grpc_transport_producer (
124+ card : AgentCard ,
125+ url : str ,
126+ config : ClientConfig ,
127+ interceptors : list [ClientCallInterceptor ],
128+ ) -> ClientTransport :
129+ # The interface has already been selected and passed as `url`.
130+ # We determine its version to use the appropriate transport implementation.
131+ interface = ClientFactory ._find_best_interface (
132+ list (card .supported_interfaces ),
133+ protocol_bindings = [TransportProtocol .GRPC ],
134+ url = url ,
135+ )
136+ version = (
137+ interface .protocol_version
138+ if interface
139+ else PROTOCOL_VERSION_CURRENT
140+ )
141+
142+ if version and CompatGrpcTransport is not None :
143+ try :
144+ v = Version (version )
145+ if (
146+ Version (PROTOCOL_VERSION_0_3 )
147+ <= v
148+ < Version (PROTOCOL_VERSION_1_0 )
149+ ):
150+ return CompatGrpcTransport .create (
151+ card , url , config , interceptors
152+ )
153+ except InvalidVersion :
154+ pass
155+
156+ return GrpcTransport .create (card , url , config , interceptors )
157+
112158 self .register (
113159 TransportProtocol .GRPC ,
114- GrpcTransport . create ,
160+ grpc_transport_producer ,
115161 )
116162
163+ @staticmethod
164+ def _find_best_interface (
165+ interfaces : list [AgentInterface ],
166+ protocol_bindings : list [str ] | None = None ,
167+ url : str | None = None ,
168+ ) -> AgentInterface | None :
169+ """Finds the best interface based on protocol version priorities."""
170+
171+ candidates = []
172+ for i in interfaces :
173+ if (
174+ (protocol_bindings is None or i .protocol_binding in protocol_bindings )
175+ and (url is None or i .url == url )
176+ ):
177+ candidates .append (i )
178+
179+ if not candidates :
180+ return None
181+
182+ # Prefer interface with version 1.0
183+ for i in candidates :
184+ if i .protocol_version == PROTOCOL_VERSION_1_0 :
185+ return i
186+
187+ best_gt_1_0 = None
188+ best_ge_0_3 = None
189+ best_no_version = None
190+
191+ for i in candidates :
192+ if not i .protocol_version :
193+ if best_no_version is None :
194+ best_no_version = i
195+ continue
196+
197+ try :
198+ v = Version (i .protocol_version )
199+ if best_gt_1_0 is None and v > Version (PROTOCOL_VERSION_1_0 ):
200+ best_gt_1_0 = i
201+ if best_ge_0_3 is None and v >= Version (PROTOCOL_VERSION_0_3 ):
202+ best_ge_0_3 = i
203+ except InvalidVersion :
204+ pass
205+
206+ return best_gt_1_0 or best_ge_0_3 or best_no_version
207+
117208 @classmethod
118209 async def connect ( # noqa: PLR0913
119210 cls ,
@@ -220,13 +311,9 @@ def create(
220311 selected_interface = None
221312 if self ._config .use_client_preference :
222313 for protocol_binding in client_set :
223- selected_interface = next (
224- (
225- si
226- for si in card .supported_interfaces
227- if si .protocol_binding == protocol_binding
228- ),
229- None ,
314+ selected_interface = ClientFactory ._find_best_interface (
315+ list (card .supported_interfaces ),
316+ protocol_bindings = [protocol_binding ],
230317 )
231318 if selected_interface :
232319 transport_protocol = protocol_binding
@@ -235,7 +322,10 @@ def create(
235322 for supported_interface in card .supported_interfaces :
236323 if supported_interface .protocol_binding in client_set :
237324 transport_protocol = supported_interface .protocol_binding
238- selected_interface = supported_interface
325+ selected_interface = ClientFactory ._find_best_interface (
326+ list (card .supported_interfaces ),
327+ protocol_bindings = [transport_protocol ],
328+ )
239329 break
240330 if not transport_protocol or not selected_interface :
241331 raise ValueError ('no compatible transports found.' )
0 commit comments