Skip to content

Commit 4771b5a

Browse files
authored
feat(rest): add tenant support to rest (#773)
## Changes - add tenant to ServerCallContext - add tenant-prefixed routes for REST endpoints - introduce tenant extraction from REST API paths ## Contribution guide - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #672 🦕
1 parent 5955197 commit 4771b5a

4 files changed

Lines changed: 215 additions & 5 deletions

File tree

src/a2a/server/agent_execution/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ def add_activated_extension(self, uri: str) -> None:
160160
if self._call_context:
161161
self._call_context.activated_extensions.add(uri)
162162

163+
@property
164+
def tenant(self) -> str:
165+
"""The tenant associated with this request."""
166+
return self._call_context.tenant if self._call_context else ''
167+
163168
@property
164169
def requested_extensions(self) -> set[str]:
165170
"""Extensions that the client requested to activate."""

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ async def _handle_request(
110110
method: Callable[[Request, ServerCallContext], Awaitable[Any]],
111111
request: Request,
112112
) -> Response:
113-
call_context = self._context_builder.build(request)
113+
call_context = self._build_call_context(request)
114+
114115
response = await method(request, call_context)
115116
return JSONResponse(content=response)
116117

@@ -130,7 +131,7 @@ async def _handle_streaming_request(
130131
message=f'Failed to pre-consume request body: {e}'
131132
) from e
132133

133-
call_context = self._context_builder.build(request)
134+
call_context = self._build_call_context(request)
134135

135136
async def event_generator(
136137
stream: AsyncIterable[Any],
@@ -185,7 +186,7 @@ async def handle_authenticated_agent_card(
185186
card_to_serve = self.agent_card
186187

187188
if self.extended_card_modifier:
188-
context = self._context_builder.build(request)
189+
context = self._build_call_context(request)
189190
card_to_serve = await maybe_await(
190191
self.extended_card_modifier(card_to_serve, context)
191192
)
@@ -205,7 +206,7 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
205206
A dictionary where each key is a tuple of (path, http_method) and
206207
the value is the callable handler for that route.
207208
"""
208-
routes: dict[tuple[str, str], Callable[[Request], Any]] = {
209+
base_routes: dict[tuple[str, str], Callable[[Request], Any]] = {
209210
('/message:send', 'POST'): functools.partial(
210211
self._handle_request, self.handler.on_message_send
211212
),
@@ -251,9 +252,22 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
251252
self._handle_request, self.handler.list_tasks
252253
),
253254
}
255+
254256
if self.agent_card.capabilities.extended_agent_card:
255-
routes[('/extendedAgentCard', 'GET')] = functools.partial(
257+
base_routes[('/extendedAgentCard', 'GET')] = functools.partial(
256258
self._handle_request, self.handle_authenticated_agent_card
257259
)
258260

261+
routes: dict[tuple[str, str], Callable[[Request], Any]] = {
262+
(p, method): handler
263+
for (path, method), handler in base_routes.items()
264+
for p in (path, f'/{{tenant}}{path}')
265+
}
266+
259267
return routes
268+
269+
def _build_call_context(self, request: Request) -> ServerCallContext:
270+
call_context = self._context_builder.build(request)
271+
if 'tenant' in request.path_params:
272+
call_context.tenant = request.path_params['tenant']
273+
return call_context

src/a2a/server/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ class ServerCallContext(BaseModel):
2121

2222
state: State = Field(default={})
2323
user: User = Field(default=UnauthenticatedUser())
24+
tenant: str = Field(default='')
2425
requested_extensions: set[str] = Field(default_factory=set)
2526
activated_extensions: set[str] = Field(default_factory=set)
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
from fastapi import FastAPI
4+
from httpx import ASGITransport, AsyncClient
5+
6+
from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
7+
from a2a.server.request_handlers.request_handler import RequestHandler
8+
from a2a.types.a2a_pb2 import (
9+
AgentCard,
10+
ListTaskPushNotificationConfigsResponse,
11+
ListTasksResponse,
12+
Message,
13+
Part,
14+
Role,
15+
Task,
16+
TaskPushNotificationConfig,
17+
)
18+
19+
20+
@pytest.fixture
21+
async def agent_card() -> AgentCard:
22+
mock_agent_card = MagicMock(spec=AgentCard)
23+
mock_agent_card.url = 'http://mockurl.com'
24+
mock_capabilities = MagicMock()
25+
mock_capabilities.streaming = False
26+
mock_capabilities.push_notifications = True
27+
mock_capabilities.extended_agent_card = True
28+
mock_agent_card.capabilities = mock_capabilities
29+
return mock_agent_card
30+
31+
32+
@pytest.fixture
33+
async def request_handler() -> RequestHandler:
34+
handler = MagicMock(spec=RequestHandler)
35+
# Setup default return values for all handlers
36+
handler.on_message_send.return_value = Message(
37+
message_id='test',
38+
role=Role.ROLE_AGENT,
39+
parts=[Part(text='response message')],
40+
)
41+
handler.on_cancel_task.return_value = Task(id='1')
42+
handler.on_get_task.return_value = Task(id='1')
43+
handler.on_list_tasks.return_value = ListTasksResponse()
44+
handler.on_create_task_push_notification_config.return_value = (
45+
TaskPushNotificationConfig()
46+
)
47+
handler.on_get_task_push_notification_config.return_value = (
48+
TaskPushNotificationConfig()
49+
)
50+
handler.on_list_task_push_notification_configs.return_value = (
51+
ListTaskPushNotificationConfigsResponse()
52+
)
53+
handler.on_delete_task_push_notification_config.return_value = None
54+
return handler
55+
56+
57+
@pytest.fixture
58+
async def extended_card_modifier() -> MagicMock:
59+
modifier = MagicMock()
60+
modifier.return_value = AgentCard()
61+
return modifier
62+
63+
64+
@pytest.fixture
65+
async def app(
66+
agent_card: AgentCard,
67+
request_handler: RequestHandler,
68+
extended_card_modifier: MagicMock,
69+
) -> FastAPI:
70+
return A2ARESTFastAPIApplication(
71+
agent_card,
72+
request_handler,
73+
extended_card_modifier=extended_card_modifier,
74+
).build(agent_card_url='/well-known/agent.json', rpc_url='')
75+
76+
77+
@pytest.fixture
78+
async def client(app: FastAPI) -> AsyncClient:
79+
return AsyncClient(transport=ASGITransport(app=app), base_url='http://test')
80+
81+
82+
@pytest.mark.parametrize(
83+
'path_template, method, handler_method_name, json_body',
84+
[
85+
('/message:send', 'POST', 'on_message_send', {'message': {}}),
86+
('/tasks/1:cancel', 'POST', 'on_cancel_task', None),
87+
('/tasks/1', 'GET', 'on_get_task', None),
88+
('/tasks', 'GET', 'on_list_tasks', None),
89+
(
90+
'/tasks/1/pushNotificationConfigs/p1',
91+
'GET',
92+
'on_get_task_push_notification_config',
93+
None,
94+
),
95+
(
96+
'/tasks/1/pushNotificationConfigs/p1',
97+
'DELETE',
98+
'on_delete_task_push_notification_config',
99+
None,
100+
),
101+
(
102+
'/tasks/1/pushNotificationConfigs',
103+
'POST',
104+
'on_create_task_push_notification_config',
105+
{'config': {'url': 'http://foo'}},
106+
),
107+
(
108+
'/tasks/1/pushNotificationConfigs',
109+
'GET',
110+
'on_list_task_push_notification_configs',
111+
None,
112+
),
113+
],
114+
)
115+
@pytest.mark.anyio
116+
async def test_tenant_extraction_parametrized(
117+
client: AsyncClient,
118+
request_handler: MagicMock,
119+
extended_card_modifier: MagicMock,
120+
path_template: str,
121+
method: str,
122+
handler_method_name: str,
123+
json_body: dict | None,
124+
) -> None:
125+
"""Test tenant extraction for standard REST endpoints."""
126+
# Test with tenant
127+
tenant = 'my-tenant'
128+
tenant_path = f'/{tenant}{path_template}'
129+
130+
response = await client.request(method, tenant_path, json=json_body)
131+
response.raise_for_status()
132+
133+
# Verify handler call
134+
handler_mock = getattr(request_handler, handler_method_name)
135+
136+
assert handler_mock.called
137+
args, _ = handler_mock.call_args
138+
context = args[1]
139+
assert context.tenant == tenant
140+
141+
# Reset mock for non-tenant test
142+
handler_mock.reset_mock()
143+
144+
# Test without tenant
145+
response = await client.request(method, path_template, json=json_body)
146+
response.raise_for_status()
147+
148+
# Verify context.tenant == ""
149+
assert handler_mock.called
150+
args, _ = handler_mock.call_args
151+
context = args[1]
152+
assert context.tenant == ''
153+
154+
155+
@pytest.mark.anyio
156+
async def test_tenant_extraction_extended_agent_card(
157+
client: AsyncClient,
158+
extended_card_modifier: MagicMock,
159+
) -> None:
160+
"""Test tenant extraction specifically for extendedAgentCard endpoint.
161+
162+
This verifies that `extended_card_modifier` receives the correct context
163+
including the tenant, confirming that `_build_call_context` is used correctly.
164+
"""
165+
# Test with tenant
166+
tenant = 'my-tenant'
167+
tenant_path = f'/{tenant}/extendedAgentCard'
168+
169+
response = await client.get(tenant_path)
170+
response.raise_for_status()
171+
172+
# Verify extended_card_modifier called with tenant context
173+
assert extended_card_modifier.called
174+
args, _ = extended_card_modifier.call_args
175+
# args[0] is card_to_serve, args[1] is context
176+
context = args[1]
177+
assert context.tenant == tenant
178+
179+
# Reset mock for non-tenant test
180+
extended_card_modifier.reset_mock()
181+
182+
# Test without tenant
183+
response = await client.get('/extendedAgentCard')
184+
response.raise_for_status()
185+
186+
# Verify extended_card_modifier called with empty tenant context
187+
assert extended_card_modifier.called
188+
args, _ = extended_card_modifier.call_args
189+
context = args[1]
190+
assert context.tenant == ''

0 commit comments

Comments
 (0)