Skip to content

Commit 63e9d74

Browse files
committed
Merge remote-tracking branch 'origin/1.0-dev' into ishymko/763-client-timeouts
2 parents a6dbc51 + 5955197 commit 63e9d74

12 files changed

Lines changed: 1014 additions & 224 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ omit = [
176176
"*/__init__.py",
177177
"src/a2a/types/a2a_pb2.py",
178178
"src/a2a/types/a2a_pb2_grpc.py",
179+
"src/a2a/compat/*/*_pb2*.py",
179180
]
180181

181182
[tool.coverage.report]

src/a2a/client/client_factory.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from a2a.client.transports.base import ClientTransport
1515
from a2a.client.transports.jsonrpc import JsonRpcTransport
1616
from a2a.client.transports.rest import RestTransport
17+
from a2a.client.transports.tenant_decorator import TenantTransportDecorator
1718
from a2a.types.a2a_pb2 import (
1819
AgentCapabilities,
1920
AgentCard,
@@ -216,28 +217,27 @@ def create(
216217
TransportProtocol.JSONRPC
217218
]
218219
transport_protocol = None
219-
transport_url = None
220+
selected_interface = None
220221
if self._config.use_client_preference:
221222
for protocol_binding in client_set:
222-
supported_interface = next(
223+
selected_interface = next(
223224
(
224225
si
225226
for si in card.supported_interfaces
226227
if si.protocol_binding == protocol_binding
227228
),
228229
None,
229230
)
230-
if supported_interface:
231+
if selected_interface:
231232
transport_protocol = protocol_binding
232-
transport_url = supported_interface.url
233233
break
234234
else:
235235
for supported_interface in card.supported_interfaces:
236236
if supported_interface.protocol_binding in client_set:
237237
transport_protocol = supported_interface.protocol_binding
238-
transport_url = supported_interface.url
238+
selected_interface = supported_interface
239239
break
240-
if not transport_protocol or not transport_url:
240+
if not transport_protocol or not selected_interface:
241241
raise ValueError('no compatible transports found.')
242242
if transport_protocol not in self._registry:
243243
raise ValueError(f'no client available for {transport_protocol}')
@@ -252,9 +252,14 @@ def create(
252252
self._config.extensions = all_extensions
253253

254254
transport = self._registry[transport_protocol](
255-
card, transport_url, self._config, interceptors or []
255+
card, selected_interface.url, self._config, interceptors or []
256256
)
257257

258+
if selected_interface.tenant:
259+
transport = TenantTransportDecorator(
260+
transport, selected_interface.tenant
261+
)
262+
258263
return BaseClient(
259264
card,
260265
self._config,

src/a2a/client/transports/grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,6 @@ async def _call_grpc_stream(
345345
)
346346
while True:
347347
response = await stream.read()
348-
if response == grpc.aio.EOF:
348+
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
349349
break
350350
yield response

src/a2a/client/transports/jsonrpc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,11 +458,11 @@ async def _apply_interceptors(
458458

459459
def _get_http_args(
460460
self, context: ClientCallContext | None
461-
) -> dict[str, Any] | None:
461+
) -> dict[str, Any]:
462462
http_kwargs: dict[str, Any] = {}
463463
if context and context.timeout is not None:
464464
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
465-
return http_kwargs if http_kwargs else None
465+
return http_kwargs
466466

467467
def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception:
468468
"""Creates the appropriate A2AError from a JSON-RPC error dictionary."""

src/a2a/client/transports/rest.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def send_message(
7979
request, context, extensions
8080
)
8181
response_data = await self._send_post_request(
82-
'/message:send', payload, modified_kwargs
82+
'/message:send', request.tenant, payload, modified_kwargs
8383
)
8484
response: SendMessageResponse = ParseDict(
8585
response_data, SendMessageResponse()
@@ -97,10 +97,10 @@ async def send_message_streaming(
9797
payload, modified_kwargs = await self._prepare_send_message(
9898
request, context, extensions
9999
)
100-
101100
async for event in self._send_stream_request(
102101
'POST',
103102
'/message:stream',
103+
request.tenant,
104104
http_kwargs=modified_kwargs,
105105
json=payload,
106106
):
@@ -130,6 +130,7 @@ async def get_task(
130130

131131
response_data = await self._send_get_request(
132132
f'/tasks/{request.id}',
133+
request.tenant,
133134
params,
134135
modified_kwargs,
135136
)
@@ -153,8 +154,10 @@ async def list_tasks(
153154
modified_kwargs,
154155
extensions if extensions is not None else self.extensions,
155156
)
157+
156158
response_data = await self._send_get_request(
157159
'/tasks',
160+
request.tenant,
158161
_model_to_query_params(request),
159162
modified_kwargs,
160163
)
@@ -181,8 +184,12 @@ async def cancel_task(
181184
modified_kwargs,
182185
context,
183186
)
187+
184188
response_data = await self._send_post_request(
185-
f'/tasks/{request.id}:cancel', payload, modified_kwargs
189+
f'/tasks/{request.id}:cancel',
190+
request.tenant,
191+
payload,
192+
modified_kwargs,
186193
)
187194
response: Task = ParseDict(response_data, Task())
188195
return response
@@ -203,8 +210,10 @@ async def create_task_push_notification_config(
203210
payload, modified_kwargs = await self._apply_interceptors(
204211
payload, modified_kwargs, context
205212
)
213+
206214
response_data = await self._send_post_request(
207215
f'/tasks/{request.task_id}/pushNotificationConfigs',
216+
request.tenant,
208217
payload,
209218
modified_kwargs,
210219
)
@@ -235,8 +244,10 @@ async def get_task_push_notification_config(
235244
del params['id']
236245
if 'task_id' in params:
237246
del params['task_id']
247+
238248
response_data = await self._send_get_request(
239249
f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}',
250+
request.tenant,
240251
params,
241252
modified_kwargs,
242253
)
@@ -265,8 +276,10 @@ async def list_task_push_notification_configs(
265276
)
266277
if 'task_id' in params:
267278
del params['task_id']
279+
268280
response_data = await self._send_get_request(
269281
f'/tasks/{request.task_id}/pushNotificationConfigs',
282+
request.tenant,
270283
params,
271284
modified_kwargs,
272285
)
@@ -297,8 +310,10 @@ async def delete_task_push_notification_config(
297310
del params['id']
298311
if 'task_id' in params:
299312
del params['task_id']
313+
300314
await self._send_delete_request(
301315
f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}',
316+
request.tenant,
302317
params,
303318
modified_kwargs,
304319
)
@@ -319,6 +334,7 @@ async def subscribe(
319334
async for event in self._send_stream_request(
320335
'GET',
321336
f'/tasks/{request.id}:subscribe',
337+
request.tenant,
322338
http_kwargs=modified_kwargs,
323339
):
324340
yield event
@@ -347,7 +363,7 @@ async def get_extended_agent_card(
347363
context,
348364
)
349365
response_data = await self._send_get_request(
350-
'/extendedAgentCard', {}, modified_kwargs
366+
'/extendedAgentCard', request.tenant, {}, modified_kwargs
351367
)
352368
response: AgentCard = ParseDict(response_data, AgentCard())
353369

@@ -363,6 +379,10 @@ async def close(self) -> None:
363379
"""Closes the httpx client."""
364380
await self.httpx_client.aclose()
365381

382+
def _get_path(self, base_path: str, tenant: str) -> str:
383+
"""Returns the full path, prepending the tenant if provided."""
384+
return f'/{tenant}{base_path}' if tenant else base_path
385+
366386
async def _apply_interceptors(
367387
self,
368388
request_payload: dict[str, Any],
@@ -376,7 +396,7 @@ async def _apply_interceptors(
376396

377397
def _get_http_args(
378398
self, context: ClientCallContext | None
379-
) -> dict[str, Any] | None:
399+
) -> dict[str, Any]:
380400
http_kwargs: dict[str, Any] = {}
381401
if context and context.timeout is not None:
382402
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
@@ -428,16 +448,18 @@ async def _send_stream_request(
428448
self,
429449
method: str,
430450
target: str,
451+
tenant: str,
431452
http_kwargs: dict[str, Any] | None = None,
432453
**kwargs: Any,
433454
) -> AsyncGenerator[StreamResponse]:
434455
final_kwargs = dict(http_kwargs or {})
435456
final_kwargs.update(kwargs)
457+
path = self._get_path(target, tenant)
436458

437459
async for sse_data in send_http_stream_request(
438460
self.httpx_client,
439461
method,
440-
f'{self.url}{target}',
462+
f'{self.url}{path}',
441463
self._handle_http_error,
442464
**final_kwargs,
443465
):
@@ -452,13 +474,15 @@ async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
452474
async def _send_post_request(
453475
self,
454476
target: str,
477+
tenant: str,
455478
rpc_request_payload: dict[str, Any],
456479
http_kwargs: dict[str, Any] | None = None,
457480
) -> dict[str, Any]:
481+
path = self._get_path(target, tenant)
458482
return await self._send_request(
459483
self.httpx_client.build_request(
460484
'POST',
461-
f'{self.url}{target}',
485+
f'{self.url}{path}',
462486
json=rpc_request_payload,
463487
**(http_kwargs or {}),
464488
)
@@ -467,13 +491,15 @@ async def _send_post_request(
467491
async def _send_get_request(
468492
self,
469493
target: str,
494+
tenant: str,
470495
query_params: dict[str, str],
471496
http_kwargs: dict[str, Any] | None = None,
472497
) -> dict[str, Any]:
498+
path = self._get_path(target, tenant)
473499
return await self._send_request(
474500
self.httpx_client.build_request(
475501
'GET',
476-
f'{self.url}{target}',
502+
f'{self.url}{path}',
477503
params=query_params,
478504
**(http_kwargs or {}),
479505
)
@@ -482,13 +508,15 @@ async def _send_get_request(
482508
async def _send_delete_request(
483509
self,
484510
target: str,
511+
tenant: str,
485512
query_params: dict[str, Any],
486513
http_kwargs: dict[str, Any] | None = None,
487514
) -> dict[str, Any]:
515+
path = self._get_path(target, tenant)
488516
return await self._send_request(
489517
self.httpx_client.build_request(
490518
'DELETE',
491-
f'{self.url}{target}',
519+
f'{self.url}{path}',
492520
params=query_params,
493521
**(http_kwargs or {}),
494522
)

0 commit comments

Comments
 (0)