Skip to content

Commit 942f4ae

Browse files
refactor(client)!: introduce ServiceParameters for extensions and include it in ClientCallContext (#784)
# Description This PR refactors the client API definitions to streamline extension handling and unify transport logic: - A new class to store extensions, integrated into the ClientCallContext. This reduces the need for having separate extension fields across the API definition. - Extracted common HTTP argument parsing logic into shared helper functions used by both REST and JSON-RPC transports. - Interceptor logic has been temporarily removed, as it will be redesigned and reintroduced in an upcoming PR.
1 parent 0ebca93 commit 942f4ae

24 files changed

Lines changed: 503 additions & 963 deletions

src/a2a/client/base_client.py

Lines changed: 39 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from collections.abc import AsyncGenerator, AsyncIterator, Callable
2-
from typing import Any
32

43
from a2a.client.client import (
54
Client,
6-
ClientCallContext,
75
ClientConfig,
86
ClientEvent,
97
Consumer,
108
)
119
from a2a.client.client_task_manager import ClientTaskManager
12-
from a2a.client.middleware import ClientCallInterceptor
10+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1311
from a2a.client.transports.base import ClientTransport
1412
from a2a.types.a2a_pb2 import (
1513
AgentCard,
@@ -23,8 +21,6 @@
2321
ListTaskPushNotificationConfigsResponse,
2422
ListTasksRequest,
2523
ListTasksResponse,
26-
Message,
27-
SendMessageConfiguration,
2824
SendMessageRequest,
2925
StreamResponse,
3026
SubscribeToTaskRequest,
@@ -51,12 +47,9 @@ def __init__(
5147

5248
async def send_message(
5349
self,
54-
request: Message,
50+
request: SendMessageRequest,
5551
*,
56-
configuration: SendMessageConfiguration | None = None,
5752
context: ClientCallContext | None = None,
58-
request_metadata: dict[str, Any] | None = None,
59-
extensions: list[str] | None = None,
6053
) -> AsyncIterator[ClientEvent]:
6154
"""Sends a message to the agent.
6255
@@ -66,35 +59,15 @@ async def send_message(
6659
6760
Args:
6861
request: The message to send to the agent.
69-
configuration: Optional per-call overrides for message sending behavior.
70-
context: The client call context.
71-
request_metadata: Extensions Metadata attached to the request.
72-
extensions: List of extensions to be activated.
62+
context: Optional client call context.
7363
7464
Yields:
7565
An async iterator of `ClientEvent`
7666
"""
77-
config = SendMessageConfiguration(
78-
accepted_output_modes=self._config.accepted_output_modes,
79-
blocking=not self._config.polling,
80-
push_notification_config=(
81-
self._config.push_notification_configs[0]
82-
if self._config.push_notification_configs
83-
else None
84-
),
85-
)
86-
87-
if configuration:
88-
config.MergeFrom(configuration)
89-
config.blocking = configuration.blocking
90-
91-
send_message_request = SendMessageRequest(
92-
message=request, configuration=config, metadata=request_metadata
93-
)
94-
67+
self._apply_client_config(request)
9568
if not self._config.streaming or not self._card.capabilities.streaming:
9669
response = await self._transport.send_message(
97-
send_message_request, context=context, extensions=extensions
70+
request, context=context
9871
)
9972

10073
# In non-streaming case we convert to a StreamResponse so that the
@@ -116,11 +89,29 @@ async def send_message(
11689
return
11790

11891
stream = self._transport.send_message_streaming(
119-
send_message_request, context=context, extensions=extensions
92+
request, context=context
12093
)
12194
async for client_event in self._process_stream(stream):
12295
yield client_event
12396

97+
def _apply_client_config(self, request: SendMessageRequest) -> None:
98+
if not request.configuration.blocking and self._config.polling:
99+
request.configuration.blocking = not self._config.polling
100+
if (
101+
not request.configuration.HasField('push_notification_config')
102+
and self._config.push_notification_configs
103+
):
104+
request.configuration.push_notification_config.CopyFrom(
105+
self._config.push_notification_configs[0]
106+
)
107+
if (
108+
not request.configuration.accepted_output_modes
109+
and self._config.accepted_output_modes
110+
):
111+
request.configuration.accepted_output_modes.extend(
112+
self._config.accepted_output_modes
113+
)
114+
124115
async def _process_stream(
125116
self, stream: AsyncIterator[StreamResponse]
126117
) -> AsyncGenerator[ClientEvent]:
@@ -147,21 +138,17 @@ async def get_task(
147138
request: GetTaskRequest,
148139
*,
149140
context: ClientCallContext | None = None,
150-
extensions: list[str] | None = None,
151141
) -> Task:
152142
"""Retrieves the current state and history of a specific task.
153143
154144
Args:
155145
request: The `GetTaskRequest` object specifying the task ID.
156-
context: The client call context.
157-
extensions: List of extensions to be activated.
146+
context: Optional client call context.
158147
159148
Returns:
160149
A `Task` object representing the current state of the task.
161150
"""
162-
return await self._transport.get_task(
163-
request, context=context, extensions=extensions
164-
)
151+
return await self._transport.get_task(request, context=context)
165152

166153
async def list_tasks(
167154
self,
@@ -177,118 +164,104 @@ async def cancel_task(
177164
request: CancelTaskRequest,
178165
*,
179166
context: ClientCallContext | None = None,
180-
extensions: list[str] | None = None,
181167
) -> Task:
182168
"""Requests the agent to cancel a specific task.
183169
184170
Args:
185171
request: The `CancelTaskRequest` object specifying the task ID.
186-
context: The client call context.
187-
extensions: List of extensions to be activated.
172+
context: Optional client call context.
188173
189174
Returns:
190175
A `Task` object containing the updated task status.
191176
"""
192-
return await self._transport.cancel_task(
193-
request, context=context, extensions=extensions
194-
)
177+
return await self._transport.cancel_task(request, context=context)
195178

196179
async def create_task_push_notification_config(
197180
self,
198181
request: CreateTaskPushNotificationConfigRequest,
199182
*,
200183
context: ClientCallContext | None = None,
201-
extensions: list[str] | None = None,
202184
) -> TaskPushNotificationConfig:
203185
"""Sets or updates the push notification configuration for a specific task.
204186
205187
Args:
206188
request: The `TaskPushNotificationConfig` object with the new configuration.
207-
context: The client call context.
208-
extensions: List of extensions to be activated.
189+
context: Optional client call context.
209190
210191
Returns:
211192
The created or updated `TaskPushNotificationConfig` object.
212193
"""
213194
return await self._transport.create_task_push_notification_config(
214-
request, context=context, extensions=extensions
195+
request, context=context
215196
)
216197

217198
async def get_task_push_notification_config(
218199
self,
219200
request: GetTaskPushNotificationConfigRequest,
220201
*,
221202
context: ClientCallContext | None = None,
222-
extensions: list[str] | None = None,
223203
) -> TaskPushNotificationConfig:
224204
"""Retrieves the push notification configuration for a specific task.
225205
226206
Args:
227207
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
228-
context: The client call context.
229-
extensions: List of extensions to be activated.
208+
context: Optional client call context.
230209
231210
Returns:
232211
A `TaskPushNotificationConfig` object containing the configuration.
233212
"""
234213
return await self._transport.get_task_push_notification_config(
235-
request, context=context, extensions=extensions
214+
request, context=context
236215
)
237216

238217
async def list_task_push_notification_configs(
239218
self,
240219
request: ListTaskPushNotificationConfigsRequest,
241220
*,
242221
context: ClientCallContext | None = None,
243-
extensions: list[str] | None = None,
244222
) -> ListTaskPushNotificationConfigsResponse:
245223
"""Lists push notification configurations for a specific task.
246224
247225
Args:
248226
request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
249-
context: The client call context.
250-
extensions: List of extensions to be activated.
227+
context: Optional client call context.
251228
252229
Returns:
253230
A `ListTaskPushNotificationConfigsResponse` object.
254231
"""
255232
return await self._transport.list_task_push_notification_configs(
256-
request, context=context, extensions=extensions
233+
request, context=context
257234
)
258235

259236
async def delete_task_push_notification_config(
260237
self,
261238
request: DeleteTaskPushNotificationConfigRequest,
262239
*,
263240
context: ClientCallContext | None = None,
264-
extensions: list[str] | None = None,
265241
) -> None:
266242
"""Deletes the push notification configuration for a specific task.
267243
268244
Args:
269245
request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
270-
context: The client call context.
271-
extensions: List of extensions to be activated.
246+
context: Optional client call context.
272247
"""
273248
await self._transport.delete_task_push_notification_config(
274-
request, context=context, extensions=extensions
249+
request, context=context
275250
)
276251

277252
async def subscribe(
278253
self,
279254
request: SubscribeToTaskRequest,
280255
*,
281256
context: ClientCallContext | None = None,
282-
extensions: list[str] | None = None,
283257
) -> AsyncIterator[ClientEvent]:
284258
"""Resubscribes to a task's event stream.
285259
286260
This is only available if both the client and server support streaming.
287261
288262
Args:
289263
request: Parameters to identify the task to resubscribe to.
290-
context: The client call context.
291-
extensions: List of extensions to be activated.
264+
context: Optional client call context.
292265
293266
Yields:
294267
An async iterator of `ClientEvent` objects.
@@ -304,9 +277,7 @@ async def subscribe(
304277
# Note: resubscribe can only be called on an existing task. As such,
305278
# we should never see Message updates, despite the typing of the service
306279
# definition indicating it may be possible.
307-
stream = self._transport.subscribe(
308-
request, context=context, extensions=extensions
309-
)
280+
stream = self._transport.subscribe(request, context=context)
310281
async for client_event in self._process_stream(stream):
311282
yield client_event
312283

@@ -315,7 +286,6 @@ async def get_extended_agent_card(
315286
request: GetExtendedAgentCardRequest,
316287
*,
317288
context: ClientCallContext | None = None,
318-
extensions: list[str] | None = None,
319289
signature_verifier: Callable[[AgentCard], None] | None = None,
320290
) -> AgentCard:
321291
"""Retrieves the agent's card.
@@ -325,8 +295,7 @@ async def get_extended_agent_card(
325295
326296
Args:
327297
request: The `GetExtendedAgentCardRequest` object specifying the request.
328-
context: The client call context.
329-
extensions: List of extensions to be activated.
298+
context: Optional client call context.
330299
signature_verifier: A callable used to verify the agent card's signatures.
331300
332301
Returns:
@@ -335,7 +304,6 @@ async def get_extended_agent_card(
335304
card = await self._transport.get_extended_agent_card(
336305
request,
337306
context=context,
338-
extensions=extensions,
339307
signature_verifier=signature_verifier,
340308
)
341309
self._card = card

src/a2a/client/client.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
ListTaskPushNotificationConfigsResponse,
2525
ListTasksRequest,
2626
ListTasksResponse,
27-
Message,
2827
PushNotificationConfig,
29-
SendMessageConfiguration,
28+
SendMessageRequest,
3029
StreamResponse,
3130
SubscribeToTaskRequest,
3231
Task,
@@ -77,9 +76,6 @@ class ClientConfig:
7776
)
7877
"""Push notification configurations to use for every request."""
7978

80-
extensions: list[str] = dataclasses.field(default_factory=list)
81-
"""A list of extension URIs the client supports."""
82-
8379

8480
ClientEvent = tuple[StreamResponse, Task | None]
8581

@@ -130,12 +126,9 @@ async def __aexit__(
130126
@abstractmethod
131127
async def send_message(
132128
self,
133-
request: Message,
129+
request: SendMessageRequest,
134130
*,
135-
configuration: SendMessageConfiguration | None = None,
136131
context: ClientCallContext | None = None,
137-
request_metadata: dict[str, Any] | None = None,
138-
extensions: list[str] | None = None,
139132
) -> AsyncIterator[ClientEvent]:
140133
"""Sends a message to the server.
141134
@@ -154,7 +147,6 @@ async def get_task(
154147
request: GetTaskRequest,
155148
*,
156149
context: ClientCallContext | None = None,
157-
extensions: list[str] | None = None,
158150
) -> Task:
159151
"""Retrieves the current state and history of a specific task."""
160152

@@ -173,7 +165,6 @@ async def cancel_task(
173165
request: CancelTaskRequest,
174166
*,
175167
context: ClientCallContext | None = None,
176-
extensions: list[str] | None = None,
177168
) -> Task:
178169
"""Requests the agent to cancel a specific task."""
179170

@@ -183,7 +174,6 @@ async def create_task_push_notification_config(
183174
request: CreateTaskPushNotificationConfigRequest,
184175
*,
185176
context: ClientCallContext | None = None,
186-
extensions: list[str] | None = None,
187177
) -> TaskPushNotificationConfig:
188178
"""Sets or updates the push notification configuration for a specific task."""
189179

@@ -193,7 +183,6 @@ async def get_task_push_notification_config(
193183
request: GetTaskPushNotificationConfigRequest,
194184
*,
195185
context: ClientCallContext | None = None,
196-
extensions: list[str] | None = None,
197186
) -> TaskPushNotificationConfig:
198187
"""Retrieves the push notification configuration for a specific task."""
199188

@@ -203,7 +192,6 @@ async def list_task_push_notification_configs(
203192
request: ListTaskPushNotificationConfigsRequest,
204193
*,
205194
context: ClientCallContext | None = None,
206-
extensions: list[str] | None = None,
207195
) -> ListTaskPushNotificationConfigsResponse:
208196
"""Lists push notification configurations for a specific task."""
209197

@@ -213,7 +201,6 @@ async def delete_task_push_notification_config(
213201
request: DeleteTaskPushNotificationConfigRequest,
214202
*,
215203
context: ClientCallContext | None = None,
216-
extensions: list[str] | None = None,
217204
) -> None:
218205
"""Deletes the push notification configuration for a specific task."""
219206

@@ -223,7 +210,6 @@ async def subscribe(
223210
request: SubscribeToTaskRequest,
224211
*,
225212
context: ClientCallContext | None = None,
226-
extensions: list[str] | None = None,
227213
) -> AsyncIterator[ClientEvent]:
228214
"""Resubscribes to a task's event stream."""
229215
return
@@ -235,7 +221,6 @@ async def get_extended_agent_card(
235221
request: GetExtendedAgentCardRequest,
236222
*,
237223
context: ClientCallContext | None = None,
238-
extensions: list[str] | None = None,
239224
signature_verifier: Callable[[AgentCard], None] | None = None,
240225
) -> AgentCard:
241226
"""Retrieves the agent's card."""

0 commit comments

Comments
 (0)