Skip to content

Commit ca70e7d

Browse files
committed
feat(compat): REST and JSONRPC clients.
1 parent 9856054 commit ca70e7d

9 files changed

Lines changed: 2006 additions & 34 deletions

File tree

src/a2a/client/client_factory.py

Lines changed: 95 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
except ImportError:
4343
CompatGrpcTransport = None # type: ignore # pyright: ignore
4444

45-
4645
logger = logging.getLogger(__name__)
4746

4847

@@ -92,24 +91,88 @@ def _register_defaults(self, supported: list[str]) -> None:
9291
# Empty support list implies JSON-RPC only.
9392

9493
if TransportProtocol.JSONRPC in supported or not supported:
95-
self.register(
96-
TransportProtocol.JSONRPC,
97-
lambda card, url, config, interceptors: JsonRpcTransport(
94+
95+
def jsonrpc_transport_producer(
96+
card: AgentCard,
97+
url: str,
98+
config: ClientConfig,
99+
interceptors: list[ClientCallInterceptor],
100+
) -> ClientTransport:
101+
interface = ClientFactory._find_best_interface(
102+
list(card.supported_interfaces),
103+
protocol_bindings=[TransportProtocol.JSONRPC],
104+
url=url,
105+
)
106+
version = (
107+
interface.protocol_version
108+
if interface
109+
else PROTOCOL_VERSION_CURRENT
110+
)
111+
112+
if ClientFactory._is_legacy_version(version):
113+
from a2a.compat.v0_3.jsonrpc_transport import ( # noqa: PLC0415
114+
CompatJsonRpcTransport,
115+
)
116+
117+
return CompatJsonRpcTransport(
118+
cast('httpx.AsyncClient', config.httpx_client),
119+
card,
120+
url,
121+
interceptors,
122+
)
123+
124+
return JsonRpcTransport(
98125
cast('httpx.AsyncClient', config.httpx_client),
99126
card,
100127
url,
101128
interceptors,
102-
),
129+
)
130+
131+
self.register(
132+
TransportProtocol.JSONRPC,
133+
jsonrpc_transport_producer,
103134
)
104135
if TransportProtocol.HTTP_JSON in supported:
105-
self.register(
106-
TransportProtocol.HTTP_JSON,
107-
lambda card, url, config, interceptors: RestTransport(
136+
137+
def rest_transport_producer(
138+
card: AgentCard,
139+
url: str,
140+
config: ClientConfig,
141+
interceptors: list[ClientCallInterceptor],
142+
) -> ClientTransport:
143+
interface = ClientFactory._find_best_interface(
144+
list(card.supported_interfaces),
145+
protocol_bindings=[TransportProtocol.HTTP_JSON],
146+
url=url,
147+
)
148+
version = (
149+
interface.protocol_version
150+
if interface
151+
else PROTOCOL_VERSION_CURRENT
152+
)
153+
154+
if ClientFactory._is_legacy_version(version):
155+
from a2a.compat.v0_3.rest_transport import ( # noqa: PLC0415
156+
CompatRestTransport,
157+
)
158+
159+
return CompatRestTransport(
160+
cast('httpx.AsyncClient', config.httpx_client),
161+
card,
162+
url,
163+
interceptors,
164+
)
165+
166+
return RestTransport(
108167
cast('httpx.AsyncClient', config.httpx_client),
109168
card,
110169
url,
111170
interceptors,
112-
),
171+
)
172+
173+
self.register(
174+
TransportProtocol.HTTP_JSON,
175+
rest_transport_producer,
113176
)
114177
if TransportProtocol.GRPC in supported:
115178
if GrpcTransport is None:
@@ -137,27 +200,17 @@ def grpc_transport_producer(
137200
else PROTOCOL_VERSION_CURRENT
138201
)
139202

140-
compat_transport = CompatGrpcTransport
141-
if version and compat_transport is not None:
142-
try:
143-
v = Version(version)
144-
if (
145-
Version(PROTOCOL_VERSION_0_3)
146-
<= v
147-
< Version(PROTOCOL_VERSION_1_0)
148-
):
149-
return compat_transport.create(
150-
card, url, config, interceptors
151-
)
152-
except InvalidVersion:
153-
pass
154-
155-
grpc_transport = GrpcTransport
156-
if grpc_transport is not None:
157-
return grpc_transport.create(
203+
if (
204+
ClientFactory._is_legacy_version(version)
205+
and CompatGrpcTransport is not None
206+
):
207+
return CompatGrpcTransport.create(
158208
card, url, config, interceptors
159209
)
160210

211+
if GrpcTransport is not None:
212+
return GrpcTransport.create(card, url, config, interceptors)
213+
161214
raise ImportError(
162215
'GrpcTransport is not available. '
163216
'You can install it with \'pip install "a2a-sdk[grpc]"\''
@@ -168,6 +221,21 @@ def grpc_transport_producer(
168221
grpc_transport_producer,
169222
)
170223

224+
@staticmethod
225+
def _is_legacy_version(version: str | None) -> bool:
226+
"""Determines if the given version is a legacy protocol version (>=0.3 and <1.0)."""
227+
if not version:
228+
return False
229+
try:
230+
v = Version(version)
231+
return (
232+
Version(PROTOCOL_VERSION_0_3)
233+
<= v
234+
< Version(PROTOCOL_VERSION_1_0)
235+
)
236+
except InvalidVersion:
237+
return False
238+
171239
@staticmethod
172240
def _find_best_interface(
173241
interfaces: list[AgentInterface],

src/a2a/compat/v0_3/grpc_transport.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ async def send_message(
135135
proto_utils.FromProto.task(resp_proto.task)
136136
)
137137
)
138-
if which == 'message':
138+
if which == 'msg':
139139
return a2a_pb2.SendMessageResponse(
140140
message=conversions.to_core_message(
141-
proto_utils.FromProto.message(resp_proto.message)
141+
proto_utils.FromProto.message(resp_proto.msg)
142142
)
143143
)
144144
return a2a_pb2.SendMessageResponse()

0 commit comments

Comments
 (0)