Skip to content

Commit 72a330d

Browse files
authored
feat(server, json-rpc): Implement tenant context propagation for JSON-RPC requests. (#778)
# Description - adds tenant propagation to ServerCallContext for JSON-RPC requests - adds unit and integration tests ## Contributing guide - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #672 🦕
1 parent 80d827a commit 72a330d

3 files changed

Lines changed: 306 additions & 184 deletions

File tree

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
385385

386386
# 3) Build call context and wrap the request for downstream handling
387387
call_context = self._context_builder.build(request)
388+
call_context.tenant = getattr(specific_request, 'tenant', '')
388389
call_context.state['method'] = method
389390
call_context.state['request_id'] = request_id
390391

tests/integration/test_tenant.py

Lines changed: 220 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -1,160 +1,245 @@
11
import pytest
22
from unittest.mock import AsyncMock, patch, MagicMock
33
import httpx
4+
from httpx import ASGITransport, AsyncClient
5+
46
from a2a.types.a2a_pb2 import (
57
AgentCard,
68
AgentInterface,
79
SendMessageRequest,
810
Message,
911
GetTaskRequest,
1012
AgentCapabilities,
13+
ListTasksRequest,
14+
ListTasksResponse,
15+
Task,
1116
)
1217
from a2a.client.transports import RestTransport, JsonRpcTransport, GrpcTransport
1318
from a2a.client.transports.tenant_decorator import TenantTransportDecorator
1419
from a2a.client import ClientConfig, ClientFactory
1520
from a2a.utils.constants import TransportProtocol
1621

22+
from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication
23+
from a2a.server.request_handlers.request_handler import RequestHandler
24+
from a2a.server.context import ServerCallContext
25+
26+
27+
class TestTenantDecorator:
28+
@pytest.fixture
29+
def agent_card(self):
30+
return AgentCard(
31+
supported_interfaces=[
32+
AgentInterface(
33+
url='http://example.com/rest',
34+
protocol_binding=TransportProtocol.HTTP_JSON,
35+
tenant='tenant-1',
36+
),
37+
AgentInterface(
38+
url='http://example.com/jsonrpc',
39+
protocol_binding=TransportProtocol.JSONRPC,
40+
tenant='tenant-2',
41+
),
42+
AgentInterface(
43+
url='http://example.com/grpc',
44+
protocol_binding=TransportProtocol.GRPC,
45+
tenant='tenant-3',
46+
),
47+
],
48+
capabilities=AgentCapabilities(streaming=True),
49+
)
1750

18-
@pytest.fixture
19-
def agent_card():
20-
return AgentCard(
21-
supported_interfaces=[
22-
AgentInterface(
23-
url='http://example.com/rest',
24-
protocol_binding=TransportProtocol.HTTP_JSON,
25-
tenant='tenant-1',
26-
),
27-
AgentInterface(
28-
url='http://example.com/jsonrpc',
29-
protocol_binding=TransportProtocol.JSONRPC,
30-
tenant='tenant-2',
31-
),
32-
AgentInterface(
33-
url='http://example.com/grpc',
34-
protocol_binding=TransportProtocol.GRPC,
35-
tenant='tenant-3',
36-
),
37-
],
38-
capabilities=AgentCapabilities(streaming=True),
39-
)
40-
41-
42-
@pytest.mark.asyncio
43-
async def test_tenant_decorator_rest(agent_card):
44-
mock_httpx = AsyncMock(spec=httpx.AsyncClient)
45-
mock_httpx.build_request.return_value = MagicMock()
46-
mock_httpx.send.return_value = MagicMock(
47-
status_code=200, json=lambda: {'message': {}}
48-
)
49-
50-
config = ClientConfig(
51-
httpx_client=mock_httpx,
52-
supported_protocol_bindings=[TransportProtocol.HTTP_JSON],
53-
)
54-
factory = ClientFactory(config)
55-
client = factory.create(agent_card)
56-
57-
assert isinstance(client._transport, TenantTransportDecorator)
58-
assert client._transport._tenant == 'tenant-1'
59-
60-
# Test SendMessage (POST) - Use transport directly to avoid streaming complexity in mock
61-
request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}]))
62-
await client._transport.send_message(request)
63-
64-
# Check that tenant was populated in request
65-
assert request.tenant == 'tenant-1'
66-
67-
# Check that path was prepended in the underlying transport
68-
mock_httpx.build_request.assert_called()
69-
send_call = next(
70-
c
71-
for c in mock_httpx.build_request.call_args_list
72-
if 'message:send' in c.args[1]
73-
)
74-
args, kwargs = send_call
75-
assert args[1] == 'http://example.com/rest/tenant-1/message:send'
76-
assert 'tenant' in kwargs['json']
77-
78-
79-
@pytest.mark.asyncio
80-
async def test_tenant_decorator_jsonrpc(agent_card):
81-
mock_httpx = AsyncMock(spec=httpx.AsyncClient)
82-
mock_httpx.build_request.return_value = MagicMock()
83-
mock_httpx.send.return_value = MagicMock(
84-
status_code=200,
85-
json=lambda: {'result': {'message': {}}, 'id': '1', 'jsonrpc': '2.0'},
86-
)
87-
88-
config = ClientConfig(
89-
httpx_client=mock_httpx,
90-
supported_protocol_bindings=[TransportProtocol.JSONRPC],
91-
)
92-
factory = ClientFactory(config)
93-
client = factory.create(agent_card)
94-
95-
assert isinstance(client._transport, TenantTransportDecorator)
96-
assert client._transport._tenant == 'tenant-2'
97-
98-
request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}]))
99-
await client._transport.send_message(request)
100-
101-
mock_httpx.build_request.assert_called()
102-
_, kwargs = mock_httpx.build_request.call_args
103-
assert kwargs['json']['params']['tenant'] == 'tenant-2'
104-
105-
106-
@pytest.mark.asyncio
107-
async def test_tenant_decorator_grpc(agent_card):
108-
mock_channel = MagicMock()
109-
config = ClientConfig(
110-
grpc_channel_factory=lambda url: mock_channel,
111-
supported_protocol_bindings=[TransportProtocol.GRPC],
112-
)
113-
114-
with patch('a2a.types.a2a_pb2_grpc.A2AServiceStub') as mock_stub_class:
115-
mock_stub = mock_stub_class.return_value
116-
mock_stub.SendMessage = AsyncMock(return_value={'message': {}})
51+
@pytest.mark.asyncio
52+
async def test_tenant_decorator_rest(self, agent_card):
53+
mock_httpx = AsyncMock(spec=httpx.AsyncClient)
54+
mock_httpx.build_request.return_value = MagicMock()
55+
mock_httpx.send.return_value = MagicMock(
56+
status_code=200, json=lambda: {'message': {}}
57+
)
11758

59+
config = ClientConfig(
60+
httpx_client=mock_httpx,
61+
supported_protocol_bindings=[TransportProtocol.HTTP_JSON],
62+
)
11863
factory = ClientFactory(config)
11964
client = factory.create(agent_card)
12065

12166
assert isinstance(client._transport, TenantTransportDecorator)
122-
assert client._transport._tenant == 'tenant-3'
67+
assert client._transport._tenant == 'tenant-1'
68+
69+
# Test SendMessage (POST) - Use transport directly to avoid streaming complexity in mock
70+
request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}]))
71+
await client._transport.send_message(request)
72+
73+
# Check that tenant was populated in request
74+
assert request.tenant == 'tenant-1'
12375

124-
await client._transport.send_message(
125-
SendMessageRequest(message=Message(parts=[{'text': 'hi'}]))
76+
# Check that path was prepended in the underlying transport
77+
mock_httpx.build_request.assert_called()
78+
send_call = next(
79+
c
80+
for c in mock_httpx.build_request.call_args_list
81+
if 'message:send' in c.args[1]
82+
)
83+
args, kwargs = send_call
84+
assert args[1] == 'http://example.com/rest/tenant-1/message:send'
85+
assert 'tenant' in kwargs['json']
86+
87+
@pytest.mark.asyncio
88+
async def test_tenant_decorator_jsonrpc(self, agent_card):
89+
mock_httpx = AsyncMock(spec=httpx.AsyncClient)
90+
mock_httpx.build_request.return_value = MagicMock()
91+
mock_httpx.send.return_value = MagicMock(
92+
status_code=200,
93+
json=lambda: {
94+
'result': {'message': {}},
95+
'id': '1',
96+
'jsonrpc': '2.0',
97+
},
98+
)
99+
100+
config = ClientConfig(
101+
httpx_client=mock_httpx,
102+
supported_protocol_bindings=[TransportProtocol.JSONRPC],
103+
)
104+
factory = ClientFactory(config)
105+
client = factory.create(agent_card)
106+
107+
assert isinstance(client._transport, TenantTransportDecorator)
108+
assert client._transport._tenant == 'tenant-2'
109+
110+
request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}]))
111+
await client._transport.send_message(request)
112+
113+
mock_httpx.build_request.assert_called()
114+
_, kwargs = mock_httpx.build_request.call_args
115+
assert kwargs['json']['params']['tenant'] == 'tenant-2'
116+
117+
@pytest.mark.asyncio
118+
async def test_tenant_decorator_grpc(self, agent_card):
119+
mock_channel = MagicMock()
120+
config = ClientConfig(
121+
grpc_channel_factory=lambda url: mock_channel,
122+
supported_protocol_bindings=[TransportProtocol.GRPC],
123+
)
124+
125+
with patch('a2a.types.a2a_pb2_grpc.A2AServiceStub') as mock_stub_class:
126+
mock_stub = mock_stub_class.return_value
127+
mock_stub.SendMessage = AsyncMock(return_value={'message': {}})
128+
129+
factory = ClientFactory(config)
130+
client = factory.create(agent_card)
131+
132+
assert isinstance(client._transport, TenantTransportDecorator)
133+
assert client._transport._tenant == 'tenant-3'
134+
135+
await client._transport.send_message(
136+
SendMessageRequest(message=Message(parts=[{'text': 'hi'}]))
137+
)
138+
139+
call_args = mock_stub.SendMessage.call_args
140+
assert call_args[0][0].tenant == 'tenant-3'
141+
142+
@pytest.mark.asyncio
143+
async def test_tenant_decorator_explicit_override(self, agent_card):
144+
mock_httpx = AsyncMock(spec=httpx.AsyncClient)
145+
mock_httpx.build_request.return_value = MagicMock()
146+
mock_httpx.send.return_value = MagicMock(
147+
status_code=200, json=lambda: {'message': {}}
148+
)
149+
150+
config = ClientConfig(
151+
httpx_client=mock_httpx,
152+
supported_protocol_bindings=[TransportProtocol.HTTP_JSON],
153+
)
154+
factory = ClientFactory(config)
155+
client = factory.create(agent_card)
156+
157+
request = SendMessageRequest(
158+
message=Message(parts=[{'text': 'hi'}]), tenant='explicit-tenant'
159+
)
160+
await client._transport.send_message(request)
161+
162+
assert request.tenant == 'explicit-tenant'
163+
164+
send_call = next(
165+
c
166+
for c in mock_httpx.build_request.call_args_list
167+
if 'message:send' in c.args[1]
168+
)
169+
args, _ = send_call
170+
assert args[1] == 'http://example.com/rest/explicit-tenant/message:send'
171+
172+
173+
class TestJSONRPCTenantIntegration:
174+
@pytest.fixture
175+
def mock_handler(self):
176+
handler = AsyncMock(spec=RequestHandler)
177+
handler.on_list_tasks.return_value = ListTasksResponse(
178+
tasks=[Task(id='task-1')]
179+
)
180+
return handler
181+
182+
@pytest.fixture
183+
def jsonrpc_agent_card(self):
184+
return AgentCard(
185+
supported_interfaces=[
186+
AgentInterface(
187+
url='http://testserver/jsonrpc',
188+
protocol_binding=TransportProtocol.JSONRPC,
189+
tenant='my-test-tenant',
190+
),
191+
],
192+
capabilities=AgentCapabilities(
193+
streaming=False,
194+
push_notifications=False,
195+
),
126196
)
127197

128-
call_args = mock_stub.SendMessage.call_args
129-
assert call_args[0][0].tenant == 'tenant-3'
130-
131-
132-
@pytest.mark.asyncio
133-
async def test_tenant_decorator_explicit_override(agent_card):
134-
mock_httpx = AsyncMock(spec=httpx.AsyncClient)
135-
mock_httpx.build_request.return_value = MagicMock()
136-
mock_httpx.send.return_value = MagicMock(
137-
status_code=200, json=lambda: {'message': {}}
138-
)
139-
140-
config = ClientConfig(
141-
httpx_client=mock_httpx,
142-
supported_protocol_bindings=[TransportProtocol.HTTP_JSON],
143-
)
144-
factory = ClientFactory(config)
145-
client = factory.create(agent_card)
146-
147-
request = SendMessageRequest(
148-
message=Message(parts=[{'text': 'hi'}]), tenant='explicit-tenant'
149-
)
150-
await client._transport.send_message(request)
151-
152-
assert request.tenant == 'explicit-tenant'
153-
154-
send_call = next(
155-
c
156-
for c in mock_httpx.build_request.call_args_list
157-
if 'message:send' in c.args[1]
158-
)
159-
args, _ = send_call
160-
assert args[1] == 'http://example.com/rest/explicit-tenant/message:send'
198+
@pytest.fixture
199+
def server_app(self, jsonrpc_agent_card, mock_handler):
200+
app = A2AStarletteApplication(
201+
agent_card=jsonrpc_agent_card,
202+
http_handler=mock_handler,
203+
).build(rpc_url='/jsonrpc')
204+
return app
205+
206+
@pytest.mark.asyncio
207+
async def test_jsonrpc_tenant_context_population(
208+
self, server_app, mock_handler, jsonrpc_agent_card
209+
):
210+
"""
211+
Integration test to verify that a tenant configured in the client
212+
is correctly propagated to the ServerCallContext in the server
213+
via the JSON-RPC transport.
214+
"""
215+
# 1. Setup the client using the server app as the transport
216+
# We use ASGITransport so httpx calls go directly to the Starlette app
217+
transport = ASGITransport(app=server_app)
218+
async with AsyncClient(
219+
transport=transport, base_url='http://testserver'
220+
) as httpx_client:
221+
# Create the A2A client properly configured
222+
config = ClientConfig(
223+
httpx_client=httpx_client,
224+
supported_protocol_bindings=[TransportProtocol.JSONRPC],
225+
)
226+
factory = ClientFactory(config)
227+
client = factory.create(jsonrpc_agent_card)
228+
229+
# 2. Make the call (list_tasks)
230+
response = await client.list_tasks(ListTasksRequest())
231+
232+
# 3. Verify response
233+
assert len(response.tasks) == 1
234+
assert response.tasks[0].id == 'task-1'
235+
236+
# 4. Verify ServerCallContext on the server side
237+
mock_handler.on_list_tasks.assert_called_once()
238+
call_args = mock_handler.on_list_tasks.call_args
239+
240+
# call_args[0] are positional args: (request, context)
241+
# Check call_args signature in jsonrpc_handler.py: await self.handler.list_tasks(request_obj, context)
242+
243+
server_context = call_args[0][1]
244+
assert isinstance(server_context, ServerCallContext)
245+
assert server_context.tenant == 'my-test-tenant'

0 commit comments

Comments
 (0)