11from collections .abc import AsyncGenerator , AsyncIterator , Callable
2- from typing import Any
32
43from a2a .client .client import (
54 Client ,
6- ClientCallContext ,
75 ClientConfig ,
86 ClientEvent ,
97 Consumer ,
108)
119from a2a .client .client_task_manager import ClientTaskManager
12- from a2a .client .middleware import ClientCallInterceptor
10+ from a2a .client .middleware import ClientCallContext , ClientCallInterceptor
1311from a2a .client .transports .base import ClientTransport
1412from a2a .types .a2a_pb2 import (
1513 AgentCard ,
2321 ListTaskPushNotificationConfigsResponse ,
2422 ListTasksRequest ,
2523 ListTasksResponse ,
26- Message ,
27- SendMessageConfiguration ,
2824 SendMessageRequest ,
2925 StreamResponse ,
3026 SubscribeToTaskRequest ,
@@ -51,12 +47,9 @@ def __init__(
5147
5248 async def send_message (
5349 self ,
54- request : Message ,
50+ request : SendMessageRequest ,
5551 * ,
56- configuration : SendMessageConfiguration | None = None ,
5752 context : ClientCallContext | None = None ,
58- request_metadata : dict [str , Any ] | None = None ,
59- extensions : list [str ] | None = None ,
6053 ) -> AsyncIterator [ClientEvent ]:
6154 """Sends a message to the agent.
6255
@@ -66,35 +59,15 @@ async def send_message(
6659
6760 Args:
6861 request: The message to send to the agent.
69- configuration: Optional per-call overrides for message sending behavior.
70- context: The client call context.
71- request_metadata: Extensions Metadata attached to the request.
72- extensions: List of extensions to be activated.
62+ context: Optional client call context.
7363
7464 Yields:
7565 An async iterator of `ClientEvent`
7666 """
77- config = SendMessageConfiguration (
78- accepted_output_modes = self ._config .accepted_output_modes ,
79- blocking = not self ._config .polling ,
80- push_notification_config = (
81- self ._config .push_notification_configs [0 ]
82- if self ._config .push_notification_configs
83- else None
84- ),
85- )
86-
87- if configuration :
88- config .MergeFrom (configuration )
89- config .blocking = configuration .blocking
90-
91- send_message_request = SendMessageRequest (
92- message = request , configuration = config , metadata = request_metadata
93- )
94-
67+ self ._apply_client_config (request )
9568 if not self ._config .streaming or not self ._card .capabilities .streaming :
9669 response = await self ._transport .send_message (
97- send_message_request , context = context , extensions = extensions
70+ request , context = context
9871 )
9972
10073 # In non-streaming case we convert to a StreamResponse so that the
@@ -116,11 +89,29 @@ async def send_message(
11689 return
11790
11891 stream = self ._transport .send_message_streaming (
119- send_message_request , context = context , extensions = extensions
92+ request , context = context
12093 )
12194 async for client_event in self ._process_stream (stream ):
12295 yield client_event
12396
97+ def _apply_client_config (self , request : SendMessageRequest ) -> None :
98+ if not request .configuration .blocking and self ._config .polling :
99+ request .configuration .blocking = not self ._config .polling
100+ if (
101+ not request .configuration .HasField ('push_notification_config' )
102+ and self ._config .push_notification_configs
103+ ):
104+ request .configuration .push_notification_config .CopyFrom (
105+ self ._config .push_notification_configs [0 ]
106+ )
107+ if (
108+ not request .configuration .accepted_output_modes
109+ and self ._config .accepted_output_modes
110+ ):
111+ request .configuration .accepted_output_modes .extend (
112+ self ._config .accepted_output_modes
113+ )
114+
124115 async def _process_stream (
125116 self , stream : AsyncIterator [StreamResponse ]
126117 ) -> AsyncGenerator [ClientEvent ]:
@@ -147,21 +138,17 @@ async def get_task(
147138 request : GetTaskRequest ,
148139 * ,
149140 context : ClientCallContext | None = None ,
150- extensions : list [str ] | None = None ,
151141 ) -> Task :
152142 """Retrieves the current state and history of a specific task.
153143
154144 Args:
155145 request: The `GetTaskRequest` object specifying the task ID.
156- context: The client call context.
157- extensions: List of extensions to be activated.
146+ context: Optional client call context.
158147
159148 Returns:
160149 A `Task` object representing the current state of the task.
161150 """
162- return await self ._transport .get_task (
163- request , context = context , extensions = extensions
164- )
151+ return await self ._transport .get_task (request , context = context )
165152
166153 async def list_tasks (
167154 self ,
@@ -177,118 +164,104 @@ async def cancel_task(
177164 request : CancelTaskRequest ,
178165 * ,
179166 context : ClientCallContext | None = None ,
180- extensions : list [str ] | None = None ,
181167 ) -> Task :
182168 """Requests the agent to cancel a specific task.
183169
184170 Args:
185171 request: The `CancelTaskRequest` object specifying the task ID.
186- context: The client call context.
187- extensions: List of extensions to be activated.
172+ context: Optional client call context.
188173
189174 Returns:
190175 A `Task` object containing the updated task status.
191176 """
192- return await self ._transport .cancel_task (
193- request , context = context , extensions = extensions
194- )
177+ return await self ._transport .cancel_task (request , context = context )
195178
196179 async def create_task_push_notification_config (
197180 self ,
198181 request : CreateTaskPushNotificationConfigRequest ,
199182 * ,
200183 context : ClientCallContext | None = None ,
201- extensions : list [str ] | None = None ,
202184 ) -> TaskPushNotificationConfig :
203185 """Sets or updates the push notification configuration for a specific task.
204186
205187 Args:
206188 request: The `TaskPushNotificationConfig` object with the new configuration.
207- context: The client call context.
208- extensions: List of extensions to be activated.
189+ context: Optional client call context.
209190
210191 Returns:
211192 The created or updated `TaskPushNotificationConfig` object.
212193 """
213194 return await self ._transport .create_task_push_notification_config (
214- request , context = context , extensions = extensions
195+ request , context = context
215196 )
216197
217198 async def get_task_push_notification_config (
218199 self ,
219200 request : GetTaskPushNotificationConfigRequest ,
220201 * ,
221202 context : ClientCallContext | None = None ,
222- extensions : list [str ] | None = None ,
223203 ) -> TaskPushNotificationConfig :
224204 """Retrieves the push notification configuration for a specific task.
225205
226206 Args:
227207 request: The `GetTaskPushNotificationConfigParams` object specifying the task.
228- context: The client call context.
229- extensions: List of extensions to be activated.
208+ context: Optional client call context.
230209
231210 Returns:
232211 A `TaskPushNotificationConfig` object containing the configuration.
233212 """
234213 return await self ._transport .get_task_push_notification_config (
235- request , context = context , extensions = extensions
214+ request , context = context
236215 )
237216
238217 async def list_task_push_notification_configs (
239218 self ,
240219 request : ListTaskPushNotificationConfigsRequest ,
241220 * ,
242221 context : ClientCallContext | None = None ,
243- extensions : list [str ] | None = None ,
244222 ) -> ListTaskPushNotificationConfigsResponse :
245223 """Lists push notification configurations for a specific task.
246224
247225 Args:
248226 request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
249- context: The client call context.
250- extensions: List of extensions to be activated.
227+ context: Optional client call context.
251228
252229 Returns:
253230 A `ListTaskPushNotificationConfigsResponse` object.
254231 """
255232 return await self ._transport .list_task_push_notification_configs (
256- request , context = context , extensions = extensions
233+ request , context = context
257234 )
258235
259236 async def delete_task_push_notification_config (
260237 self ,
261238 request : DeleteTaskPushNotificationConfigRequest ,
262239 * ,
263240 context : ClientCallContext | None = None ,
264- extensions : list [str ] | None = None ,
265241 ) -> None :
266242 """Deletes the push notification configuration for a specific task.
267243
268244 Args:
269245 request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
270- context: The client call context.
271- extensions: List of extensions to be activated.
246+ context: Optional client call context.
272247 """
273248 await self ._transport .delete_task_push_notification_config (
274- request , context = context , extensions = extensions
249+ request , context = context
275250 )
276251
277252 async def subscribe (
278253 self ,
279254 request : SubscribeToTaskRequest ,
280255 * ,
281256 context : ClientCallContext | None = None ,
282- extensions : list [str ] | None = None ,
283257 ) -> AsyncIterator [ClientEvent ]:
284258 """Resubscribes to a task's event stream.
285259
286260 This is only available if both the client and server support streaming.
287261
288262 Args:
289263 request: Parameters to identify the task to resubscribe to.
290- context: The client call context.
291- extensions: List of extensions to be activated.
264+ context: Optional client call context.
292265
293266 Yields:
294267 An async iterator of `ClientEvent` objects.
@@ -304,9 +277,7 @@ async def subscribe(
304277 # Note: resubscribe can only be called on an existing task. As such,
305278 # we should never see Message updates, despite the typing of the service
306279 # definition indicating it may be possible.
307- stream = self ._transport .subscribe (
308- request , context = context , extensions = extensions
309- )
280+ stream = self ._transport .subscribe (request , context = context )
310281 async for client_event in self ._process_stream (stream ):
311282 yield client_event
312283
@@ -315,7 +286,6 @@ async def get_extended_agent_card(
315286 request : GetExtendedAgentCardRequest ,
316287 * ,
317288 context : ClientCallContext | None = None ,
318- extensions : list [str ] | None = None ,
319289 signature_verifier : Callable [[AgentCard ], None ] | None = None ,
320290 ) -> AgentCard :
321291 """Retrieves the agent's card.
@@ -325,8 +295,7 @@ async def get_extended_agent_card(
325295
326296 Args:
327297 request: The `GetExtendedAgentCardRequest` object specifying the request.
328- context: The client call context.
329- extensions: List of extensions to be activated.
298+ context: Optional client call context.
330299 signature_verifier: A callable used to verify the agent card's signatures.
331300
332301 Returns:
@@ -335,7 +304,6 @@ async def get_extended_agent_card(
335304 card = await self ._transport .get_extended_agent_card (
336305 request ,
337306 context = context ,
338- extensions = extensions ,
339307 signature_verifier = signature_verifier ,
340308 )
341309 self ._card = card
0 commit comments