Skip to content

Commit 5b354e4

Browse files
authored
feat: handle tenant in Client (#758)
## Changes - Rest client transport `rest` prepends path with tenant if provided. - add `tenant_decorator.py` - add TenantTransportDecorator` to `tenant_decorator.py` which adds default tenant to requests in for provided ## Contributing - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #672 🦕
1 parent ced3f99 commit 5b354e4

9 files changed

Lines changed: 1010 additions & 219 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ omit = [
176176
"*/__init__.py",
177177
"src/a2a/types/a2a_pb2.py",
178178
"src/a2a/types/a2a_pb2_grpc.py",
179+
"src/a2a/compat/*/*_pb2*.py",
179180
]
180181

181182
[tool.coverage.report]

src/a2a/client/client_factory.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from a2a.client.transports.base import ClientTransport
1515
from a2a.client.transports.jsonrpc import JsonRpcTransport
1616
from a2a.client.transports.rest import RestTransport
17+
from a2a.client.transports.tenant_decorator import TenantTransportDecorator
1718
from a2a.types.a2a_pb2 import (
1819
AgentCapabilities,
1920
AgentCard,
@@ -216,28 +217,27 @@ def create(
216217
TransportProtocol.JSONRPC
217218
]
218219
transport_protocol = None
219-
transport_url = None
220+
selected_interface = None
220221
if self._config.use_client_preference:
221222
for protocol_binding in client_set:
222-
supported_interface = next(
223+
selected_interface = next(
223224
(
224225
si
225226
for si in card.supported_interfaces
226227
if si.protocol_binding == protocol_binding
227228
),
228229
None,
229230
)
230-
if supported_interface:
231+
if selected_interface:
231232
transport_protocol = protocol_binding
232-
transport_url = supported_interface.url
233233
break
234234
else:
235235
for supported_interface in card.supported_interfaces:
236236
if supported_interface.protocol_binding in client_set:
237237
transport_protocol = supported_interface.protocol_binding
238-
transport_url = supported_interface.url
238+
selected_interface = supported_interface
239239
break
240-
if not transport_protocol or not transport_url:
240+
if not transport_protocol or not selected_interface:
241241
raise ValueError('no compatible transports found.')
242242
if transport_protocol not in self._registry:
243243
raise ValueError(f'no client available for {transport_protocol}')
@@ -252,9 +252,14 @@ def create(
252252
self._config.extensions = all_extensions
253253

254254
transport = self._registry[transport_protocol](
255-
card, transport_url, self._config, interceptors or []
255+
card, selected_interface.url, self._config, interceptors or []
256256
)
257257

258+
if selected_interface.tenant:
259+
transport = TenantTransportDecorator(
260+
transport, selected_interface.tenant
261+
)
262+
258263
return BaseClient(
259264
card,
260265
self._config,

src/a2a/client/transports/rest.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ async def send_message(
7979
request, context, extensions
8080
)
8181
response_data = await self._send_post_request(
82-
'/message:send', payload, modified_kwargs
82+
'/message:send', request.tenant, payload, modified_kwargs
8383
)
8484
response: SendMessageResponse = ParseDict(
8585
response_data, SendMessageResponse()
@@ -97,10 +97,10 @@ async def send_message_streaming(
9797
payload, modified_kwargs = await self._prepare_send_message(
9898
request, context, extensions
9999
)
100-
101100
async for event in self._send_stream_request(
102101
'POST',
103102
'/message:stream',
103+
request.tenant,
104104
http_kwargs=modified_kwargs,
105105
json=payload,
106106
):
@@ -130,6 +130,7 @@ async def get_task(
130130

131131
response_data = await self._send_get_request(
132132
f'/tasks/{request.id}',
133+
request.tenant,
133134
params,
134135
modified_kwargs,
135136
)
@@ -153,8 +154,10 @@ async def list_tasks(
153154
modified_kwargs,
154155
extensions if extensions is not None else self.extensions,
155156
)
157+
156158
response_data = await self._send_get_request(
157159
'/tasks',
160+
request.tenant,
158161
_model_to_query_params(request),
159162
modified_kwargs,
160163
)
@@ -181,8 +184,12 @@ async def cancel_task(
181184
modified_kwargs,
182185
context,
183186
)
187+
184188
response_data = await self._send_post_request(
185-
f'/tasks/{request.id}:cancel', payload, modified_kwargs
189+
f'/tasks/{request.id}:cancel',
190+
request.tenant,
191+
payload,
192+
modified_kwargs,
186193
)
187194
response: Task = ParseDict(response_data, Task())
188195
return response
@@ -203,8 +210,10 @@ async def create_task_push_notification_config(
203210
payload, modified_kwargs = await self._apply_interceptors(
204211
payload, modified_kwargs, context
205212
)
213+
206214
response_data = await self._send_post_request(
207215
f'/tasks/{request.task_id}/pushNotificationConfigs',
216+
request.tenant,
208217
payload,
209218
modified_kwargs,
210219
)
@@ -235,8 +244,10 @@ async def get_task_push_notification_config(
235244
del params['id']
236245
if 'task_id' in params:
237246
del params['task_id']
247+
238248
response_data = await self._send_get_request(
239249
f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}',
250+
request.tenant,
240251
params,
241252
modified_kwargs,
242253
)
@@ -265,8 +276,10 @@ async def list_task_push_notification_configs(
265276
)
266277
if 'task_id' in params:
267278
del params['task_id']
279+
268280
response_data = await self._send_get_request(
269281
f'/tasks/{request.task_id}/pushNotificationConfigs',
282+
request.tenant,
270283
params,
271284
modified_kwargs,
272285
)
@@ -297,8 +310,10 @@ async def delete_task_push_notification_config(
297310
del params['id']
298311
if 'task_id' in params:
299312
del params['task_id']
313+
300314
await self._send_delete_request(
301315
f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}',
316+
request.tenant,
302317
params,
303318
modified_kwargs,
304319
)
@@ -319,6 +334,7 @@ async def subscribe(
319334
async for event in self._send_stream_request(
320335
'GET',
321336
f'/tasks/{request.id}:subscribe',
337+
request.tenant,
322338
http_kwargs=modified_kwargs,
323339
):
324340
yield event
@@ -347,7 +363,7 @@ async def get_extended_agent_card(
347363
context,
348364
)
349365
response_data = await self._send_get_request(
350-
'/extendedAgentCard', {}, modified_kwargs
366+
'/extendedAgentCard', request.tenant, {}, modified_kwargs
351367
)
352368
response: AgentCard = ParseDict(response_data, AgentCard())
353369

@@ -363,6 +379,10 @@ async def close(self) -> None:
363379
"""Closes the httpx client."""
364380
await self.httpx_client.aclose()
365381

382+
def _get_path(self, base_path: str, tenant: str) -> str:
383+
"""Returns the full path, prepending the tenant if provided."""
384+
return f'/{tenant}{base_path}' if tenant else base_path
385+
366386
async def _apply_interceptors(
367387
self,
368388
request_payload: dict[str, Any],
@@ -425,16 +445,18 @@ async def _send_stream_request(
425445
self,
426446
method: str,
427447
target: str,
448+
tenant: str,
428449
http_kwargs: dict[str, Any] | None = None,
429450
**kwargs: Any,
430451
) -> AsyncGenerator[StreamResponse]:
431452
final_kwargs = dict(http_kwargs or {})
432453
final_kwargs.update(kwargs)
454+
path = self._get_path(target, tenant)
433455

434456
async for sse_data in send_http_stream_request(
435457
self.httpx_client,
436458
method,
437-
f'{self.url}{target}',
459+
f'{self.url}{path}',
438460
self._handle_http_error,
439461
**final_kwargs,
440462
):
@@ -449,13 +471,15 @@ async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
449471
async def _send_post_request(
450472
self,
451473
target: str,
474+
tenant: str,
452475
rpc_request_payload: dict[str, Any],
453476
http_kwargs: dict[str, Any] | None = None,
454477
) -> dict[str, Any]:
478+
path = self._get_path(target, tenant)
455479
return await self._send_request(
456480
self.httpx_client.build_request(
457481
'POST',
458-
f'{self.url}{target}',
482+
f'{self.url}{path}',
459483
json=rpc_request_payload,
460484
**(http_kwargs or {}),
461485
)
@@ -464,13 +488,15 @@ async def _send_post_request(
464488
async def _send_get_request(
465489
self,
466490
target: str,
491+
tenant: str,
467492
query_params: dict[str, str],
468493
http_kwargs: dict[str, Any] | None = None,
469494
) -> dict[str, Any]:
495+
path = self._get_path(target, tenant)
470496
return await self._send_request(
471497
self.httpx_client.build_request(
472498
'GET',
473-
f'{self.url}{target}',
499+
f'{self.url}{path}',
474500
params=query_params,
475501
**(http_kwargs or {}),
476502
)
@@ -479,13 +505,15 @@ async def _send_get_request(
479505
async def _send_delete_request(
480506
self,
481507
target: str,
508+
tenant: str,
482509
query_params: dict[str, Any],
483510
http_kwargs: dict[str, Any] | None = None,
484511
) -> dict[str, Any]:
512+
path = self._get_path(target, tenant)
485513
return await self._send_request(
486514
self.httpx_client.build_request(
487515
'DELETE',
488-
f'{self.url}{target}',
516+
f'{self.url}{path}',
489517
params=query_params,
490518
**(http_kwargs or {}),
491519
)

0 commit comments

Comments
 (0)