Skip to content

Commit 164f919

Browse files
authored
feat(server, grpc): Implement tenant context propagation for gRPC requests. (#781)
## Changes - adds tenant propagation to ServerCallContext for gRPC requests in grpc_handler - added a unit tests `TestTenantExtraction` - moved test from `test_rest_tenant.py` to `test_rest_fastapi_app.py` and deleted empty `test_rest_tenant.py` file ## Contributing guide Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #672 🦕
1 parent f124ddd commit 164f919

4 files changed

Lines changed: 429 additions & 208 deletions

File tree

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from collections.abc import Callable
2020

21-
from google.protobuf import empty_pb2
21+
from google.protobuf import empty_pb2, message
2222

2323
import a2a.types.a2a_pb2_grpc as a2a_grpc
2424

@@ -142,7 +142,7 @@ async def SendMessage(
142142
"""
143143
try:
144144
# Construct the server context object
145-
server_context = self.context_builder.build(context)
145+
server_context = self._build_call_context(context, request)
146146
task_or_message = await self.request_handler.on_message_send(
147147
request, server_context
148148
)
@@ -177,7 +177,7 @@ async def SendStreamingMessage(
177177
(Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent)
178178
or gRPC error responses if an A2AError is raised.
179179
"""
180-
server_context = self.context_builder.build(context)
180+
server_context = self._build_call_context(context, request)
181181
try:
182182
async for event in self.request_handler.on_message_send_stream(
183183
request, server_context
@@ -203,7 +203,7 @@ async def CancelTask(
203203
A `Task` object containing the updated Task or a gRPC error.
204204
"""
205205
try:
206-
server_context = self.context_builder.build(context)
206+
server_context = self._build_call_context(context, request)
207207
task = await self.request_handler.on_cancel_task(
208208
request, server_context
209209
)
@@ -236,7 +236,7 @@ async def SubscribeToTask(
236236
`StreamResponse` objects containing streaming events
237237
"""
238238
try:
239-
server_context = self.context_builder.build(context)
239+
server_context = self._build_call_context(context, request)
240240
async for event in self.request_handler.on_subscribe_to_task(
241241
request,
242242
server_context,
@@ -260,7 +260,7 @@ async def GetTaskPushNotificationConfig(
260260
A `TaskPushNotificationConfig` object containing the config.
261261
"""
262262
try:
263-
server_context = self.context_builder.build(context)
263+
server_context = self._build_call_context(context, request)
264264
return (
265265
await self.request_handler.on_get_task_push_notification_config(
266266
request,
@@ -296,7 +296,7 @@ async def CreateTaskPushNotificationConfig(
296296
(due to the `@validate` decorator).
297297
"""
298298
try:
299-
server_context = self.context_builder.build(context)
299+
server_context = self._build_call_context(context, request)
300300
return await self.request_handler.on_create_task_push_notification_config(
301301
request,
302302
server_context,
@@ -320,7 +320,7 @@ async def ListTaskPushNotificationConfigs(
320320
A `ListTaskPushNotificationConfigsResponse` object containing the configs.
321321
"""
322322
try:
323-
server_context = self.context_builder.build(context)
323+
server_context = self._build_call_context(context, request)
324324
return await self.request_handler.on_list_task_push_notification_configs(
325325
request,
326326
server_context,
@@ -344,7 +344,7 @@ async def DeleteTaskPushNotificationConfig(
344344
An empty `Empty` object.
345345
"""
346346
try:
347-
server_context = self.context_builder.build(context)
347+
server_context = self._build_call_context(context, request)
348348
await self.request_handler.on_delete_task_push_notification_config(
349349
request,
350350
server_context,
@@ -369,7 +369,7 @@ async def GetTask(
369369
A `Task` object.
370370
"""
371371
try:
372-
server_context = self.context_builder.build(context)
372+
server_context = self._build_call_context(context, request)
373373
task = await self.request_handler.on_get_task(
374374
request, server_context
375375
)
@@ -395,7 +395,7 @@ async def ListTasks(
395395
A `ListTasksResponse` object.
396396
"""
397397
try:
398-
server_context = self.context_builder.build(context)
398+
server_context = self._build_call_context(context, request)
399399
return await self.request_handler.on_list_tasks(
400400
request, server_context
401401
)
@@ -442,3 +442,12 @@ def _set_extension_metadata(
442442
for e in sorted(server_context.activated_extensions)
443443
]
444444
)
445+
446+
def _build_call_context(
447+
self,
448+
context: grpc.aio.ServicerContext,
449+
request: message.Message,
450+
) -> ServerCallContext:
451+
server_context = self.context_builder.build(context)
452+
server_context.tenant = getattr(request, 'tenant', '')
453+
return server_context

tests/server/apps/rest/test_rest_fastapi_app.py

Lines changed: 155 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@
99
from google.protobuf import json_format
1010
from httpx import ASGITransport, AsyncClient
1111

12-
from a2a.types import a2a_pb2
1312
from a2a.server.apps.rest import fastapi_app, rest_adapter
1413
from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
1514
from a2a.server.apps.rest.rest_adapter import RESTAdapter
1615
from a2a.server.request_handlers.request_handler import RequestHandler
16+
from a2a.types import a2a_pb2
1717
from a2a.types.a2a_pb2 import (
1818
AgentCard,
19+
ListTaskPushNotificationConfigsResponse,
20+
ListTasksResponse,
1921
Message,
2022
Part,
2123
Role,
2224
Task,
25+
TaskPushNotificationConfig,
2326
TaskState,
2427
TaskStatus,
2528
)
@@ -36,6 +39,8 @@ async def agent_card() -> AgentCard:
3639
# Mock the capabilities object with streaming disabled
3740
mock_capabilities = MagicMock()
3841
mock_capabilities.streaming = False
42+
mock_capabilities.push_notifications = True
43+
mock_capabilities.extended_agent_card = True
3944
mock_agent_card.capabilities = mock_capabilities
4045

4146
return mock_agent_card
@@ -60,6 +65,11 @@ async def request_handler() -> RequestHandler:
6065
return MagicMock(spec=RequestHandler)
6166

6267

68+
@pytest.fixture
69+
async def extended_card_modifier() -> MagicMock | None:
70+
return None
71+
72+
6373
@pytest.fixture
6474
async def streaming_app(
6575
streaming_agent_card: AgentCard, request_handler: RequestHandler
@@ -81,13 +91,17 @@ async def streaming_client(streaming_app: FastAPI) -> AsyncClient:
8191

8292
@pytest.fixture
8393
async def app(
84-
agent_card: AgentCard, request_handler: RequestHandler
94+
agent_card: AgentCard,
95+
request_handler: RequestHandler,
96+
extended_card_modifier: MagicMock | None,
8597
) -> FastAPI:
8698
"""Builds the FastAPI application for testing."""
8799

88-
return A2ARESTFastAPIApplication(agent_card, request_handler).build(
89-
agent_card_url='/well-known/agent.json', rpc_url=''
90-
)
100+
return A2ARESTFastAPIApplication(
101+
agent_card,
102+
request_handler,
103+
extended_card_modifier=extended_card_modifier,
104+
).build(agent_card_url='/well-known/agent.json', rpc_url='')
91105

92106

93107
@pytest.fixture
@@ -396,5 +410,141 @@ async def test_send_message_rejected_task(
396410
assert expected_response == actual_response
397411

398412

413+
@pytest.mark.anyio
414+
class TestTenantExtraction:
415+
@pytest.fixture(autouse=True)
416+
def configure_mocks(self, request_handler: MagicMock) -> None:
417+
# Setup default return values for all handlers
418+
request_handler.on_message_send.return_value = Message(
419+
message_id='test',
420+
role=Role.ROLE_AGENT,
421+
parts=[Part(text='response message')],
422+
)
423+
request_handler.on_cancel_task.return_value = Task(id='1')
424+
request_handler.on_get_task.return_value = Task(id='1')
425+
request_handler.on_list_tasks.return_value = ListTasksResponse()
426+
request_handler.on_create_task_push_notification_config.return_value = (
427+
TaskPushNotificationConfig()
428+
)
429+
request_handler.on_get_task_push_notification_config.return_value = (
430+
TaskPushNotificationConfig()
431+
)
432+
request_handler.on_list_task_push_notification_configs.return_value = (
433+
ListTaskPushNotificationConfigsResponse()
434+
)
435+
request_handler.on_delete_task_push_notification_config.return_value = (
436+
None
437+
)
438+
439+
@pytest.fixture
440+
def extended_card_modifier(self) -> MagicMock:
441+
modifier = MagicMock()
442+
modifier.return_value = AgentCard()
443+
return modifier
444+
445+
@pytest.mark.parametrize(
446+
'path_template, method, handler_method_name, json_body',
447+
[
448+
('/message:send', 'POST', 'on_message_send', {'message': {}}),
449+
('/tasks/1:cancel', 'POST', 'on_cancel_task', None),
450+
('/tasks/1', 'GET', 'on_get_task', None),
451+
('/tasks', 'GET', 'on_list_tasks', None),
452+
(
453+
'/tasks/1/pushNotificationConfigs/p1',
454+
'GET',
455+
'on_get_task_push_notification_config',
456+
None,
457+
),
458+
(
459+
'/tasks/1/pushNotificationConfigs/p1',
460+
'DELETE',
461+
'on_delete_task_push_notification_config',
462+
None,
463+
),
464+
(
465+
'/tasks/1/pushNotificationConfigs',
466+
'POST',
467+
'on_create_task_push_notification_config',
468+
{'config': {'url': 'http://foo'}},
469+
),
470+
(
471+
'/tasks/1/pushNotificationConfigs',
472+
'GET',
473+
'on_list_task_push_notification_configs',
474+
None,
475+
),
476+
],
477+
)
478+
async def test_tenant_extraction_parametrized(
479+
self,
480+
client: AsyncClient,
481+
request_handler: MagicMock,
482+
path_template: str,
483+
method: str,
484+
handler_method_name: str,
485+
json_body: dict | None,
486+
) -> None:
487+
"""Test tenant extraction for standard REST endpoints."""
488+
# Test with tenant
489+
tenant = 'my-tenant'
490+
tenant_path = f'/{tenant}{path_template}'
491+
492+
response = await client.request(method, tenant_path, json=json_body)
493+
response.raise_for_status()
494+
495+
# Verify handler call
496+
handler_mock = getattr(request_handler, handler_method_name)
497+
498+
assert handler_mock.called
499+
args, _ = handler_mock.call_args
500+
context = args[1]
501+
assert context.tenant == tenant
502+
503+
# Reset mock for non-tenant test
504+
handler_mock.reset_mock()
505+
506+
# Test without tenant
507+
response = await client.request(method, path_template, json=json_body)
508+
response.raise_for_status()
509+
510+
# Verify context.tenant == ""
511+
assert handler_mock.called
512+
args, _ = handler_mock.call_args
513+
context = args[1]
514+
assert context.tenant == ''
515+
516+
async def test_tenant_extraction_extended_agent_card(
517+
self,
518+
client: AsyncClient,
519+
extended_card_modifier: MagicMock,
520+
) -> None:
521+
"""Test tenant extraction specifically for extendedAgentCard endpoint."""
522+
# Test with tenant
523+
tenant = 'my-tenant'
524+
tenant_path = f'/{tenant}/extendedAgentCard'
525+
526+
response = await client.get(tenant_path)
527+
response.raise_for_status()
528+
529+
# Verify extended_card_modifier called with tenant context
530+
assert extended_card_modifier.called
531+
args, _ = extended_card_modifier.call_args
532+
context = args[1]
533+
assert context.tenant == tenant
534+
535+
# Reset mock for non-tenant test
536+
extended_card_modifier.reset_mock()
537+
538+
# Test without tenant
539+
response = await client.get('/extendedAgentCard')
540+
response.raise_for_status()
541+
542+
# Verify extended_card_modifier called with empty tenant context
543+
assert extended_card_modifier.called
544+
args, _ = extended_card_modifier.call_args
545+
context = args[1]
546+
assert context.tenant == ''
547+
548+
399549
if __name__ == '__main__':
400550
pytest.main([__file__])

0 commit comments

Comments
 (0)