Skip to content

Commit a6dbc51

Browse files
committed
refactor(client): allow transport agnostic per invocation timeouts
Replace magic `http_kwargs` key with explicit `timeout` property on `ClientCallContext`.
1 parent ced3f99 commit a6dbc51

7 files changed

Lines changed: 174 additions & 41 deletions

File tree

src/a2a/client/middleware.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class ClientCallContext(BaseModel):
1919
"""
2020

2121
state: MutableMapping[str, Any] = Field(default_factory=dict)
22+
timeout: float | None = None
2223

2324

2425
class ClientCallInterceptor(ABC):

src/a2a/client/transports/grpc.py

Lines changed: 72 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,8 @@ async def send_message(
134134
extensions: list[str] | None = None,
135135
) -> SendMessageResponse:
136136
"""Sends a non-streaming message request to the agent."""
137-
return await self.stub.SendMessage(
138-
request,
139-
metadata=self._get_grpc_metadata(extensions),
137+
return await self._call_grpc(
138+
self.stub.SendMessage, request, context, extensions
140139
)
141140

142141
@_handle_grpc_stream_exception
@@ -148,14 +147,9 @@ async def send_message_streaming(
148147
extensions: list[str] | None = None,
149148
) -> AsyncGenerator[StreamResponse]:
150149
"""Sends a streaming message request to the agent and yields responses as they arrive."""
151-
stream = self.stub.SendStreamingMessage(
152-
request,
153-
metadata=self._get_grpc_metadata(extensions),
154-
)
155-
while True:
156-
response = await stream.read()
157-
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
158-
break
150+
async for response in self._call_grpc_stream(
151+
self.stub.SendStreamingMessage, request, context, extensions
152+
):
159153
yield response
160154

161155
@_handle_grpc_stream_exception
@@ -167,14 +161,9 @@ async def subscribe(
167161
extensions: list[str] | None = None,
168162
) -> AsyncGenerator[StreamResponse]:
169163
"""Reconnects to get task updates."""
170-
stream = self.stub.SubscribeToTask(
171-
request,
172-
metadata=self._get_grpc_metadata(extensions),
173-
)
174-
while True:
175-
response = await stream.read()
176-
if response == grpc.aio.EOF: # pyright: ignore[reportAttributeAccessIssue]
177-
break
164+
async for response in self._call_grpc_stream(
165+
self.stub.SubscribeToTask, request, context, extensions
166+
):
178167
yield response
179168

180169
@_handle_grpc_exception
@@ -186,9 +175,8 @@ async def get_task(
186175
extensions: list[str] | None = None,
187176
) -> Task:
188177
"""Retrieves the current state and history of a specific task."""
189-
return await self.stub.GetTask(
190-
request,
191-
metadata=self._get_grpc_metadata(extensions),
178+
return await self._call_grpc(
179+
self.stub.GetTask, request, context, extensions
192180
)
193181

194182
@_handle_grpc_exception
@@ -200,9 +188,8 @@ async def list_tasks(
200188
extensions: list[str] | None = None,
201189
) -> ListTasksResponse:
202190
"""Retrieves tasks for an agent."""
203-
return await self.stub.ListTasks(
204-
request,
205-
metadata=self._get_grpc_metadata(extensions),
191+
return await self._call_grpc(
192+
self.stub.ListTasks, request, context, extensions
206193
)
207194

208195
@_handle_grpc_exception
@@ -214,9 +201,8 @@ async def cancel_task(
214201
extensions: list[str] | None = None,
215202
) -> Task:
216203
"""Requests the agent to cancel a specific task."""
217-
return await self.stub.CancelTask(
218-
request,
219-
metadata=self._get_grpc_metadata(extensions),
204+
return await self._call_grpc(
205+
self.stub.CancelTask, request, context, extensions
220206
)
221207

222208
@_handle_grpc_exception
@@ -228,9 +214,11 @@ async def create_task_push_notification_config(
228214
extensions: list[str] | None = None,
229215
) -> TaskPushNotificationConfig:
230216
"""Sets or updates the push notification configuration for a specific task."""
231-
return await self.stub.CreateTaskPushNotificationConfig(
217+
return await self._call_grpc(
218+
self.stub.CreateTaskPushNotificationConfig,
232219
request,
233-
metadata=self._get_grpc_metadata(extensions),
220+
context,
221+
extensions,
234222
)
235223

236224
@_handle_grpc_exception
@@ -242,9 +230,11 @@ async def get_task_push_notification_config(
242230
extensions: list[str] | None = None,
243231
) -> TaskPushNotificationConfig:
244232
"""Retrieves the push notification configuration for a specific task."""
245-
return await self.stub.GetTaskPushNotificationConfig(
233+
return await self._call_grpc(
234+
self.stub.GetTaskPushNotificationConfig,
246235
request,
247-
metadata=self._get_grpc_metadata(extensions),
236+
context,
237+
extensions,
248238
)
249239

250240
@_handle_grpc_exception
@@ -256,9 +246,11 @@ async def list_task_push_notification_configs(
256246
extensions: list[str] | None = None,
257247
) -> ListTaskPushNotificationConfigsResponse:
258248
"""Lists push notification configurations for a specific task."""
259-
return await self.stub.ListTaskPushNotificationConfigs(
249+
return await self._call_grpc(
250+
self.stub.ListTaskPushNotificationConfigs,
260251
request,
261-
metadata=self._get_grpc_metadata(extensions),
252+
context,
253+
extensions,
262254
)
263255

264256
@_handle_grpc_exception
@@ -270,9 +262,11 @@ async def delete_task_push_notification_config(
270262
extensions: list[str] | None = None,
271263
) -> None:
272264
"""Deletes the push notification configuration for a specific task."""
273-
await self.stub.DeleteTaskPushNotificationConfig(
265+
await self._call_grpc(
266+
self.stub.DeleteTaskPushNotificationConfig,
274267
request,
275-
metadata=self._get_grpc_metadata(extensions),
268+
context,
269+
extensions,
276270
)
277271

278272
@_handle_grpc_exception
@@ -285,9 +279,8 @@ async def get_extended_agent_card(
285279
signature_verifier: Callable[[AgentCard], None] | None = None,
286280
) -> AgentCard:
287281
"""Retrieves the agent's card."""
288-
card = await self.stub.GetExtendedAgentCard(
289-
request,
290-
metadata=self._get_grpc_metadata(extensions),
282+
card = await self._call_grpc(
283+
self.stub.GetExtendedAgentCard, request, context, extensions
291284
)
292285

293286
if signature_verifier:
@@ -315,3 +308,43 @@ def _get_grpc_metadata(
315308
)
316309

317310
return metadata
311+
312+
def _get_grpc_timeout(
313+
self, context: ClientCallContext | None
314+
) -> float | None:
315+
return context.timeout if context else None
316+
317+
async def _call_grpc(
318+
self,
319+
method: Callable[..., Any],
320+
request: Any,
321+
context: ClientCallContext | None,
322+
extensions: list[str] | None,
323+
**kwargs: Any,
324+
) -> Any:
325+
return await method(
326+
request,
327+
metadata=self._get_grpc_metadata(extensions),
328+
timeout=self._get_grpc_timeout(context),
329+
**kwargs,
330+
)
331+
332+
async def _call_grpc_stream(
333+
self,
334+
method: Callable[..., Any],
335+
request: Any,
336+
context: ClientCallContext | None,
337+
extensions: list[str] | None,
338+
**kwargs: Any,
339+
) -> AsyncGenerator[StreamResponse]:
340+
stream = method(
341+
request,
342+
metadata=self._get_grpc_metadata(extensions),
343+
timeout=self._get_grpc_timeout(context),
344+
**kwargs,
345+
)
346+
while True:
347+
response = await stream.read()
348+
if response == grpc.aio.EOF:
349+
break
350+
yield response

src/a2a/client/transports/jsonrpc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,10 @@ async def _apply_interceptors(
459459
def _get_http_args(
460460
self, context: ClientCallContext | None
461461
) -> dict[str, Any] | None:
462-
return context.state.get('http_kwargs') if context else None
462+
http_kwargs: dict[str, Any] = {}
463+
if context and context.timeout is not None:
464+
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
465+
return http_kwargs if http_kwargs else None
463466

464467
def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception:
465468
"""Creates the appropriate A2AError from a JSON-RPC error dictionary."""

src/a2a/client/transports/rest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,10 @@ async def _apply_interceptors(
377377
def _get_http_args(
378378
self, context: ClientCallContext | None
379379
) -> dict[str, Any] | None:
380-
return context.state.get('http_kwargs') if context else None
380+
http_kwargs: dict[str, Any] = {}
381+
if context and context.timeout is not None:
382+
http_kwargs['timeout'] = httpx.Timeout(context.timeout)
383+
return http_kwargs
381384

382385
async def _prepare_send_message(
383386
self,

tests/client/transports/test_grpc_client.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,32 @@ async def test_send_message_task_response(
228228
assert response.task.id == sample_task.id
229229

230230

231+
@pytest.mark.asyncio
232+
async def test_send_message_with_timeout_context(
233+
grpc_transport: GrpcTransport,
234+
mock_grpc_stub: AsyncMock,
235+
sample_message_send_params: SendMessageRequest,
236+
sample_task: Task,
237+
) -> None:
238+
"""Test send_message passes context timeout to grpc stub."""
239+
from a2a.client.middleware import ClientCallContext
240+
241+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
242+
task=sample_task
243+
)
244+
context = ClientCallContext(timeout=12.5)
245+
246+
await grpc_transport.send_message(
247+
sample_message_send_params,
248+
context=context,
249+
)
250+
251+
mock_grpc_stub.SendMessage.assert_awaited_once()
252+
_, kwargs = mock_grpc_stub.SendMessage.call_args
253+
assert 'timeout' in kwargs
254+
assert kwargs['timeout'] == 12.5
255+
256+
231257
@pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys()))
232258
@pytest.mark.asyncio
233259
async def test_grpc_mapped_errors(
@@ -360,6 +386,7 @@ async def test_get_task(
360386
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
361387
),
362388
],
389+
timeout=None,
363390
)
364391
assert response.id == sample_task.id
365392

@@ -389,6 +416,7 @@ async def test_list_tasks(
389416
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
390417
),
391418
],
419+
timeout=None,
392420
)
393421
assert result.total_size == 2
394422
assert not result.next_page_token
@@ -417,6 +445,7 @@ async def test_get_task_with_history(
417445
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
418446
),
419447
],
448+
timeout=None,
420449
)
421450

422451

@@ -443,6 +472,7 @@ async def test_cancel_task(
443472
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
444473
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3'),
445474
],
475+
timeout=None,
446476
)
447477
assert response.status.state == TaskState.TASK_STATE_CANCELED
448478

@@ -476,6 +506,7 @@ async def test_create_task_push_notification_config_with_valid_task(
476506
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
477507
),
478508
],
509+
timeout=None,
479510
)
480511
assert response.task_id == sample_task_push_notification_config.task_id
481512

@@ -539,6 +570,7 @@ async def test_get_task_push_notification_config_with_valid_task(
539570
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
540571
),
541572
],
573+
timeout=None,
542574
)
543575
assert response.task_id == sample_task_push_notification_config.task_id
544576

@@ -593,6 +625,7 @@ async def test_list_task_push_notification_configs(
593625
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
594626
),
595627
],
628+
timeout=None,
596629
)
597630
assert len(response.configs) == 1
598631
assert response.configs[0].task_id == 'task-1'
@@ -626,6 +659,7 @@ async def test_delete_task_push_notification_config(
626659
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
627660
),
628661
],
662+
timeout=None,
629663
)
630664

631665

tests/client/transports/test_jsonrpc_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,32 @@ async def test_send_message_json_decode_error(
235235
with pytest.raises(A2AClientError):
236236
await transport.send_message(request)
237237

238+
@pytest.mark.asyncio
239+
async def test_send_message_with_timeout_context(
240+
self, transport, mock_httpx_client
241+
):
242+
"""Test that send_message passes context timeout to build_request."""
243+
from a2a.client.middleware import ClientCallContext
244+
245+
mock_response = MagicMock()
246+
mock_response.json.return_value = {
247+
'jsonrpc': '2.0',
248+
'id': '1',
249+
'result': {},
250+
}
251+
mock_response.raise_for_status = MagicMock()
252+
mock_httpx_client.send.return_value = mock_response
253+
254+
request = create_send_message_request()
255+
context = ClientCallContext(timeout=15.0)
256+
257+
await transport.send_message(request, context=context)
258+
259+
mock_httpx_client.build_request.assert_called_once()
260+
_, kwargs = mock_httpx_client.build_request.call_args
261+
assert 'timeout' in kwargs
262+
assert kwargs['timeout'] == httpx.Timeout(15.0)
263+
238264

239265
class TestGetTask:
240266
"""Tests for the get_task method."""

tests/client/transports/test_rest_client.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,39 @@ async def test_rest_mapped_errors(
135135
with pytest.raises(error_cls):
136136
await client.send_message(request=params)
137137

138+
@pytest.mark.asyncio
139+
async def test_send_message_with_timeout_context(
140+
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
141+
):
142+
"""Test that send_message passes context timeout to build_request."""
143+
from a2a.client.middleware import ClientCallContext
144+
145+
client = RestTransport(
146+
httpx_client=mock_httpx_client,
147+
agent_card=mock_agent_card,
148+
url='http://agent.example.com/api',
149+
)
150+
params = SendMessageRequest(
151+
message=create_text_message_object(content='Hello')
152+
)
153+
context = ClientCallContext(timeout=10.0)
154+
155+
mock_build_request = MagicMock(
156+
return_value=AsyncMock(spec=httpx.Request)
157+
)
158+
mock_httpx_client.build_request = mock_build_request
159+
160+
mock_response = AsyncMock(spec=httpx.Response)
161+
mock_response.status_code = 200
162+
mock_httpx_client.send.return_value = mock_response
163+
164+
await client.send_message(request=params, context=context)
165+
166+
mock_build_request.assert_called_once()
167+
_, kwargs = mock_build_request.call_args
168+
assert 'timeout' in kwargs
169+
assert kwargs['timeout'] == httpx.Timeout(10.0)
170+
138171

139172
class TestRestTransportExtensions:
140173
@pytest.mark.asyncio

0 commit comments

Comments
 (0)