forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient_factory.py
More file actions
417 lines (352 loc) · 14.8 KB
/
client_factory.py
File metadata and controls
417 lines (352 loc) · 14.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast
import httpx
from packaging.version import InvalidVersion, Version
from a2a.client.base_client import BaseClient
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.client import Client, ClientConfig
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.jsonrpc import JsonRpcTransport
from a2a.client.transports.rest import RestTransport
from a2a.client.transports.tenant_decorator import TenantTransportDecorator
from a2a.compat.v0_3.versions import is_legacy_version
from a2a.types.a2a_pb2 import (
AgentCapabilities,
AgentCard,
AgentInterface,
)
from a2a.utils.constants import (
PROTOCOL_VERSION_0_3,
PROTOCOL_VERSION_1_0,
PROTOCOL_VERSION_CURRENT,
VERSION_HEADER,
TransportProtocol,
)
if TYPE_CHECKING:
    # Only needed for annotations; avoided at runtime.
    from a2a.client.interceptors import ClientCallInterceptor

# gRPC support is optional. When its dependencies are not installed the
# transport classes are replaced by None and checked for at use sites.
try:
    from a2a.client.transports.grpc import GrpcTransport
except ImportError:
    GrpcTransport = None  # type: ignore # pyright: ignore

try:
    from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
except ImportError:
    CompatGrpcTransport = None  # type: ignore # pyright: ignore

logger = logging.getLogger(__name__)

# A transport producer builds a ClientTransport from
# (agent card, selected interface URL, client config).
TransportProducer = Callable[[AgentCard, str, ClientConfig], ClientTransport]
class ClientFactory:
    """ClientFactory is used to generate the appropriate client for the agent.

    The factory is configured with a `ClientConfig` that is shared by all
    generated `Client`s. The expected use is:

    .. code-block:: python

        factory = ClientFactory(config)
        # Optionally register custom transport producers
        factory.register('my_custom_transport', my_custom_transport_producer)
        # Then, with an agent card, make a client with optional interceptors
        client = factory.create(card, interceptors)

    Now the client can be used consistently regardless of the transport. This
    aligns the client configuration with the server's capabilities.
    """

    def __init__(
        self,
        config: ClientConfig,
    ):
        """Initializes the factory and registers the built-in transports.

        Args:
            config: Configuration shared by every client this factory
                creates. NOTE: this object is mutated. If it has no
                `httpx_client`, a new `httpx.AsyncClient` is created and
                stored on it. The A2A version header is set on that client's
                default headers unless it is already present.

        Raises:
            ImportError: If gRPC is listed in
                `config.supported_protocol_bindings` but its optional
                dependencies are not installed.
        """
        client = config.httpx_client or httpx.AsyncClient()
        client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT)
        config.httpx_client = client
        self._config = config
        self._registry: dict[str, TransportProducer] = {}
        self._register_defaults(config.supported_protocol_bindings)

    @staticmethod
    def _protocol_version_for(
        card: AgentCard,
        url: str,
        protocol_binding: str,
    ) -> str:
        """Returns the protocol version declared for a card interface.

        Transport producers only receive the URL of the interface chosen by
        `create`. The interface is looked up again (by binding and URL) so the
        producer can choose between the current and the legacy (v0.3 compat)
        transport implementation.

        Args:
            card: The agent card the transport is being built for.
            url: URL of the selected interface.
            protocol_binding: The transport binding being instantiated.

        Returns:
            The matching interface's `protocol_version` (possibly empty), or
            `PROTOCOL_VERSION_CURRENT` if no interface on the card matches.
        """
        interface = ClientFactory._find_best_interface(
            list(card.supported_interfaces),
            protocol_bindings=[protocol_binding],
            url=url,
        )
        if not interface:
            return PROTOCOL_VERSION_CURRENT
        return interface.protocol_version

    def _register_defaults(self, supported: list[str]) -> None:
        """Registers producers for the built-in transports in `supported`.

        An empty `supported` list is treated as JSON-RPC only.

        Raises:
            ImportError: If gRPC is requested but its optional dependencies
                are not installed.
        """
        if TransportProtocol.JSONRPC in supported or not supported:

            def jsonrpc_transport_producer(
                card: AgentCard,
                url: str,
                config: ClientConfig,
            ) -> ClientTransport:
                version = ClientFactory._protocol_version_for(
                    card, url, TransportProtocol.JSONRPC
                )
                if is_legacy_version(version):
                    # Imported lazily: the compat layer is only needed for
                    # agents that still speak the legacy protocol.
                    from a2a.compat.v0_3.jsonrpc_transport import (  # noqa: PLC0415
                        CompatJsonRpcTransport,
                    )

                    return CompatJsonRpcTransport(
                        cast('httpx.AsyncClient', config.httpx_client),
                        card,
                        url,
                    )
                return JsonRpcTransport(
                    cast('httpx.AsyncClient', config.httpx_client),
                    card,
                    url,
                )

            self.register(
                TransportProtocol.JSONRPC,
                jsonrpc_transport_producer,
            )

        if TransportProtocol.HTTP_JSON in supported:

            def rest_transport_producer(
                card: AgentCard,
                url: str,
                config: ClientConfig,
            ) -> ClientTransport:
                version = ClientFactory._protocol_version_for(
                    card, url, TransportProtocol.HTTP_JSON
                )
                if is_legacy_version(version):
                    from a2a.compat.v0_3.rest_transport import (  # noqa: PLC0415
                        CompatRestTransport,
                    )

                    return CompatRestTransport(
                        cast('httpx.AsyncClient', config.httpx_client),
                        card,
                        url,
                    )
                return RestTransport(
                    cast('httpx.AsyncClient', config.httpx_client),
                    card,
                    url,
                )

            self.register(
                TransportProtocol.HTTP_JSON,
                rest_transport_producer,
            )

        if TransportProtocol.GRPC in supported:
            if GrpcTransport is None:
                raise ImportError(
                    'To use GrpcClient, its dependencies must be installed. '
                    'You can install them with \'pip install "a2a-sdk[grpc]"\''
                )
            # Bind to a local so type checkers see a non-None class inside
            # the closure.
            _grpc_transport = GrpcTransport

            def grpc_transport_producer(
                card: AgentCard,
                url: str,
                config: ClientConfig,
            ) -> ClientTransport:
                version = ClientFactory._protocol_version_for(
                    card, url, TransportProtocol.GRPC
                )
                if is_legacy_version(version):
                    if CompatGrpcTransport is not None:
                        return CompatGrpcTransport.create(card, url, config)
                    # Previously a silent fallback; make it visible since the
                    # agent may not understand the current protocol.
                    logger.warning(
                        'Interface %s declares legacy protocol version %s but '
                        'the compat gRPC transport is unavailable; falling '
                        'back to the current gRPC transport.',
                        url,
                        version,
                    )
                return _grpc_transport.create(card, url, config)

            self.register(
                TransportProtocol.GRPC,
                grpc_transport_producer,
            )

    @staticmethod
    def _find_best_interface(
        interfaces: list[AgentInterface],
        protocol_bindings: list[str] | None = None,
        url: str | None = None,
    ) -> AgentInterface | None:
        """Finds the best interface based on protocol version priorities.

        Candidates are the interfaces matching `protocol_bindings` (when given)
        and `url` (when given). The first candidate in declaration order is
        returned from the highest-priority non-empty group:

        1. version equal to 1.0 (compared semantically, so "1.0.0" counts),
        2. version greater than 1.0,
        3. version at least 0.3,
        4. no declared version.

        Candidates whose version cannot be parsed, or is below 0.3, are never
        selected.

        Returns:
            The selected interface, or None if no candidate qualifies.
        """
        candidates = [
            i
            for i in interfaces
            if (
                protocol_bindings is None
                or i.protocol_binding in protocol_bindings
            )
            and (url is None or i.url == url)
        ]
        if not candidates:
            return None

        version_1_0 = Version(PROTOCOL_VERSION_1_0)
        version_0_3 = Version(PROTOCOL_VERSION_0_3)
        best_gt_1_0 = None
        best_ge_0_3 = None
        best_no_version = None
        for i in candidates:
            if not i.protocol_version:
                if best_no_version is None:
                    best_no_version = i
                continue
            try:
                v = Version(i.protocol_version)
            except InvalidVersion:
                # Unparseable versions are never selected.
                continue
            if v == version_1_0:
                # Highest priority; the first match in order wins outright.
                return i
            if best_gt_1_0 is None and v > version_1_0:
                best_gt_1_0 = i
            if best_ge_0_3 is None and v >= version_0_3:
                best_ge_0_3 = i
        return best_gt_1_0 or best_ge_0_3 or best_no_version

    @classmethod
    async def connect(  # noqa: PLR0913
        cls,
        agent: str | AgentCard,
        client_config: ClientConfig | None = None,
        interceptors: list[ClientCallInterceptor] | None = None,
        relative_card_path: str | None = None,
        resolver_http_kwargs: dict[str, Any] | None = None,
        extra_transports: dict[str, TransportProducer] | None = None,
        signature_verifier: Callable[[AgentCard], None] | None = None,
    ) -> Client:
        """Convenience method for constructing a client.

        Constructs a client that connects to the specified agent. Note that
        creating multiple clients via this method is less efficient than
        constructing an instance of ClientFactory and reusing that.

        .. code-block:: python

            # This will search for an AgentCard at /.well-known/agent-card.json
            my_agent_url = 'https://travel.agents.example.com'
            client = await ClientFactory.connect(my_agent_url)

        Args:
            agent: The base URL of the agent, or the AgentCard to connect to.
            client_config: The ClientConfig to use when connecting to the
                agent. If it has no `httpx_client`, a temporary client is used
                only to fetch the agent card.
            interceptors: A list of interceptors to use for each request. These
                are used for things like attaching credentials or http headers
                to all outbound requests.
            relative_card_path: If the agent field is a URL, this value is used
                as the relative path when resolving the agent card. See
                A2ACardResolver.get_agent_card for more details.
            resolver_http_kwargs: Dictionary of arguments to provide to the
                httpx client when resolving the agent card. This value is
                provided to A2ACardResolver.get_agent_card as the http_kwargs
                parameter.
            extra_transports: Additional transport producers to register,
                keyed by transport label.
            signature_verifier: A callable used to verify the agent card's
                signatures.

        Returns:
            A `Client` object.

        Raises:
            ValueError: If no compatible transport is found (see `create`).
        """
        client_config = client_config or ClientConfig()

        if isinstance(agent, str):

            async def _resolve_card(http_client: httpx.AsyncClient) -> AgentCard:
                # Fetch (and optionally verify) the card from the agent URL.
                resolver = A2ACardResolver(http_client, agent)
                return await resolver.get_agent_card(
                    relative_card_path=relative_card_path,
                    http_kwargs=resolver_http_kwargs,
                    signature_verifier=signature_verifier,
                )

            if client_config.httpx_client:
                card = await _resolve_card(client_config.httpx_client)
            else:
                async with httpx.AsyncClient() as client:
                    card = await _resolve_card(client)
        else:
            card = agent

        factory = cls(client_config)
        for label, generator in (extra_transports or {}).items():
            factory.register(label, generator)
        return factory.create(card, interceptors)

    def register(self, label: str, generator: TransportProducer) -> None:
        """Register a new transport producer for a given transport label.

        Registering a label that already exists replaces its producer.
        """
        self._registry[label] = generator

    def create(
        self,
        card: AgentCard,
        interceptors: list[ClientCallInterceptor] | None = None,
    ) -> Client:
        """Create a new `Client` for the provided `AgentCard`.

        The transport binding is chosen from the client's
        `supported_protocol_bindings` (JSON-RPC if empty) that the card also
        declares. With `use_client_preference` the client's order decides;
        otherwise the card's declaration order does. Bindings for which no
        interface has an acceptable protocol version are skipped.

        Args:
            card: An `AgentCard` defining the characteristics of the agent.
            interceptors: A list of interceptors to use for each request. These
                are used for things like attaching credentials or http headers
                to all outbound requests.

        Returns:
            A `Client` object.

        Raises:
            ValueError: If there is no valid matching of the client
                configuration with the server configuration, or no producer is
                registered for the selected transport.
        """
        client_set = self._config.supported_protocol_bindings or [
            TransportProtocol.JSONRPC
        ]
        interfaces = list(card.supported_interfaces)

        if self._config.use_client_preference:
            binding_order = list(client_set)
        else:
            # Card declaration order; dict.fromkeys drops repeated bindings
            # while keeping their first occurrence.
            binding_order = list(
                dict.fromkeys(
                    interface.protocol_binding
                    for interface in interfaces
                    if interface.protocol_binding in client_set
                )
            )

        transport_protocol = None
        selected_interface = None
        for protocol_binding in binding_order:
            candidate = ClientFactory._find_best_interface(
                interfaces,
                protocol_bindings=[protocol_binding],
            )
            if candidate:
                transport_protocol = protocol_binding
                selected_interface = candidate
                break

        if not transport_protocol or not selected_interface:
            raise ValueError('no compatible transports found.')
        if transport_protocol not in self._registry:
            raise ValueError(f'no client available for {transport_protocol}')

        transport = self._registry[transport_protocol](
            card, selected_interface.url, self._config
        )
        if selected_interface.tenant:
            # Scope every request to the tenant declared on the interface.
            transport = TenantTransportDecorator(
                transport, selected_interface.tenant
            )
        return BaseClient(
            card,
            self._config,
            transport,
            interceptors or [],
        )
def minimal_agent_card(
    url: str, transports: list[str] | None = None
) -> AgentCard:
    """Builds a bare-bones AgentCard for bootstrapping client creation.

    The resulting card cannot be used on its own to interact with the remote
    agent. It is a shorthand that turns a known URL and a set of transport
    bindings into a card a client can use to call the agent server's get-card
    endpoint and fetch the real agent card. This matters for gRPC, because such
    servers typically do not expose a card at a well-known path.

    Args:
        url: URL assigned to every generated interface.
        transports: Protocol bindings to declare, one interface each. `None`
            means no interfaces.

    Returns:
        An `AgentCard` with empty metadata, the requested interfaces, and the
        `extended_agent_card` capability enabled.
    """
    bindings = transports if transports is not None else []
    interfaces = [
        AgentInterface(protocol_binding=binding, url=url) for binding in bindings
    ]
    return AgentCard(
        supported_interfaces=interfaces,
        capabilities=AgentCapabilities(extended_agent_card=True),
        name='',
        description='',
        version='',
        skills=[],
        default_input_modes=[],
        default_output_modes=[],
    )