-
Notifications
You must be signed in to change notification settings - Fork 423
Expand file tree
/
Copy pathbase_client.py
More file actions
305 lines (264 loc) · 10.3 KB
/
base_client.py
File metadata and controls
305 lines (264 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
from collections.abc import AsyncIterator, Callable
from types import TracebackType
from typing import Any
from typing_extensions import Self
from a2a.client.client import (
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
)
from a2a.client.client_task_manager import ClientTaskManager
from a2a.client.errors import A2AClientInvalidStateError
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
Message,
MessageSendConfiguration,
MessageSendParams,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskPushNotificationConfig,
TaskQueryParams,
TaskStatusUpdateEvent,
)
class BaseClient(Client):
    """Base implementation of the A2A client, containing transport-independent logic.

    This class builds the per-call send configuration, chooses between
    streaming and non-streaming delivery, tracks task state across streamed
    updates, and hands every resulting event to ``self.consume``. All wire
    calls are delegated to the injected ``ClientTransport``.
    """

    def __init__(
        self,
        card: AgentCard,
        config: ClientConfig,
        transport: ClientTransport,
        consumers: list[Consumer],
        middleware: list[ClientCallInterceptor],
    ):
        """Initializes the client.

        Args:
            card: The agent card; its ``capabilities.streaming`` flag is
                consulted when deciding whether to stream.
            config: Client-side configuration (``streaming``, ``polling``,
                ``accepted_output_modes``, ``push_notification_configs``).
            transport: The transport used for every call to the agent.
            consumers: Event consumers, forwarded to the ``Client`` base class.
            middleware: Call interceptors, forwarded to the ``Client`` base class.
        """
        super().__init__(consumers, middleware)
        self._card = card
        self._config = config
        self._transport = transport

    async def __aenter__(self) -> Self:
        """Enters the async context manager, returning the client itself."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exits the async context manager, ensuring close() is called."""
        await self.close()

    def _build_send_configuration(
        self, configuration: MessageSendConfiguration | None
    ) -> MessageSendConfiguration:
        """Merges per-call overrides onto the client's default send configuration.

        Defaults come from the client config. ``blocking`` is the inverse of
        ``polling``, and the first push notification config (if any) is used.
        Only fields explicitly set on ``configuration`` override those defaults.

        Args:
            configuration: Optional per-call overrides.

        Returns:
            The effective `MessageSendConfiguration` for this call.
        """
        base_config = MessageSendConfiguration(
            accepted_output_modes=self._config.accepted_output_modes,
            blocking=not self._config.polling,
            push_notification_config=(
                self._config.push_notification_configs[0]
                if self._config.push_notification_configs
                else None
            ),
        )
        if configuration is None:
            return base_config
        # Copy the explicitly-set attribute values directly instead of
        # round-tripping through model_dump(). model_dump() converts nested
        # models (e.g. push_notification_config) to plain dicts, and
        # model_copy(update=...) does not re-validate, so those fields would
        # otherwise end up holding dicts instead of model instances.
        overrides = {
            name: getattr(configuration, name)
            for name in configuration.model_fields_set
        }
        return base_config.model_copy(update=overrides)

    @staticmethod
    async def _close_stream(stream: AsyncIterator[Any]) -> None:
        """Closes ``stream`` if it exposes ``aclose()`` (e.g. an async generator).

        Cleanup in the transport runs as soon as iteration stops (early
        return, error, or caller abandoning the iterator) instead of waiting
        for garbage collection. That cleanup presumably includes releasing
        the underlying connection; this depends on the transport.
        """
        aclose = getattr(stream, 'aclose', None)
        if aclose is not None:
            await aclose()

    async def send_message(
        self,
        request: Message,
        *,
        configuration: MessageSendConfiguration | None = None,
        context: ClientCallContext | None = None,
        request_metadata: dict[str, Any] | None = None,
        extensions: list[str] | None = None,
    ) -> AsyncIterator[ClientEvent | Message]:
        """Sends a message to the agent.

        Streaming is used only when both the client config and the agent
        card's capabilities enable it. Otherwise a single request/response
        call is made. Every yielded item is also passed to ``self.consume``.

        Args:
            request: The message to send to the agent.
            configuration: Optional per-call overrides for message sending
                behavior. Only fields explicitly set on it take effect.
            context: The client call context.
            request_metadata: Extensions Metadata attached to the request.
            extensions: List of extensions to be activated.

        Yields:
            Either a single `Message`, or `ClientEvent` tuples of
            ``(task, update)``. ``update`` is ``None`` when the event was a
            full `Task` snapshot.

        Raises:
            A2AClientInvalidStateError: If a streaming response is empty, or
                if the server streams a `Message` after the first event.
        """
        config = self._build_send_configuration(configuration)
        params = MessageSendParams(
            message=request, configuration=config, metadata=request_metadata
        )

        if not self._config.streaming or not self._card.capabilities.streaming:
            response = await self._transport.send_message(
                params, context=context, extensions=extensions
            )
            # Normalize a bare Task into the (task, update) ClientEvent shape;
            # a Message is passed through unchanged.
            result = (
                (response, None) if isinstance(response, Task) else response
            )
            await self.consume(result, self._card)
            yield result
            return

        tracker = ClientTaskManager()
        stream = self._transport.send_message_streaming(
            params, context=context, extensions=extensions
        )
        try:
            # Use a default instead of letting StopAsyncIteration escape: inside
            # an async generator it would be converted to an opaque RuntimeError.
            first_event = await anext(stream, None)
            if first_event is None:
                raise A2AClientInvalidStateError(
                    'received an empty response stream from server; expected a Message or Task updates'
                )
            # The response from a server may be either exactly one Message or a
            # series of Task updates. Separate out the first message for special
            # case handling, which allows us to simplify further stream processing.
            if isinstance(first_event, Message):
                await self.consume(first_event, self._card)
                yield first_event
                return

            yield await self._process_response(tracker, first_event)
            async for event in stream:
                yield await self._process_response(tracker, event)
        finally:
            await self._close_stream(stream)

    async def _process_response(
        self,
        tracker: ClientTaskManager,
        event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
    ) -> ClientEvent:
        """Folds one streamed event into ``tracker`` and emits a `ClientEvent`.

        Args:
            tracker: Accumulates task state across the stream.
            event: The streamed event from the server.

        Returns:
            ``(task, update)``, where ``task`` is the tracker's current task and
            ``update`` is the event itself, or ``None`` if the event was a `Task`.

        Raises:
            A2AClientInvalidStateError: If ``event`` is a `Message`, which is
                only valid as the first item of a stream.
        """
        if isinstance(event, Message):
            raise A2AClientInvalidStateError(
                'received a streamed Message from server after first response; this is not supported'
            )
        await tracker.process(event)
        # get_task_or_raise presumably raises if no task is tracked yet.
        task = tracker.get_task_or_raise()
        update = None if isinstance(event, Task) else event
        client_event = (task, update)
        await self.consume(client_event, self._card)
        return client_event

    async def get_task(
        self,
        request: TaskQueryParams,
        *,
        context: ClientCallContext | None = None,
        extensions: list[str] | None = None,
    ) -> Task:
        """Retrieves the current state and history of a specific task.

        Args:
            request: The `TaskQueryParams` object specifying the task ID.
            context: The client call context.
            extensions: List of extensions to be activated.

        Returns:
            A `Task` object representing the current state of the task.
        """
        return await self._transport.get_task(
            request, context=context, extensions=extensions
        )

    async def cancel_task(
        self,
        request: TaskIdParams,
        *,
        context: ClientCallContext | None = None,
        extensions: list[str] | None = None,
    ) -> Task:
        """Requests the agent to cancel a specific task.

        Args:
            request: The `TaskIdParams` object specifying the task ID.
            context: The client call context.
            extensions: List of extensions to be activated.

        Returns:
            A `Task` object containing the updated task status.
        """
        return await self._transport.cancel_task(
            request, context=context, extensions=extensions
        )

    async def set_task_callback(
        self,
        request: TaskPushNotificationConfig,
        *,
        context: ClientCallContext | None = None,
        extensions: list[str] | None = None,
    ) -> TaskPushNotificationConfig:
        """Sets or updates the push notification configuration for a specific task.

        Args:
            request: The `TaskPushNotificationConfig` object with the new configuration.
            context: The client call context.
            extensions: List of extensions to be activated.

        Returns:
            The created or updated `TaskPushNotificationConfig` object.
        """
        return await self._transport.set_task_callback(
            request, context=context, extensions=extensions
        )

    async def get_task_callback(
        self,
        request: GetTaskPushNotificationConfigParams,
        *,
        context: ClientCallContext | None = None,
        extensions: list[str] | None = None,
    ) -> TaskPushNotificationConfig:
        """Retrieves the push notification configuration for a specific task.

        Args:
            request: The `GetTaskPushNotificationConfigParams` object specifying the task.
            context: The client call context.
            extensions: List of extensions to be activated.

        Returns:
            A `TaskPushNotificationConfig` object containing the configuration.
        """
        return await self._transport.get_task_callback(
            request, context=context, extensions=extensions
        )

    async def resubscribe(
        self,
        request: TaskIdParams,
        *,
        context: ClientCallContext | None = None,
        extensions: list[str] | None = None,
    ) -> AsyncIterator[ClientEvent]:
        """Resubscribes to a task's event stream.

        This is only available if both the client config and the agent card
        enable streaming. Because this is an async generator, the check (and
        any error) happens on first iteration, not at call time.

        Args:
            request: Parameters to identify the task to resubscribe to.
            context: The client call context.
            extensions: List of extensions to be activated.

        Yields:
            `ClientEvent` tuples of ``(task, update)``.

        Raises:
            NotImplementedError: If streaming is not supported by the client or server.
            A2AClientInvalidStateError: If the server streams a `Message`.
        """
        if not self._config.streaming or not self._card.capabilities.streaming:
            raise NotImplementedError(
                'client and/or server do not support resubscription.'
            )
        tracker = ClientTaskManager()
        # Note: resubscribe can only be called on an existing task. As such,
        # we should never see Message updates, despite the typing of the service
        # definition indicating it may be possible.
        stream = self._transport.resubscribe(
            request, context=context, extensions=extensions
        )
        try:
            async for event in stream:
                yield await self._process_response(tracker, event)
        finally:
            await self._close_stream(stream)

    async def get_card(
        self,
        *,
        context: ClientCallContext | None = None,
        extensions: list[str] | None = None,
        signature_verifier: Callable[[AgentCard], None] | None = None,
    ) -> AgentCard:
        """Retrieves the agent's card from the transport and caches it.

        The fetched card replaces the client's stored card. Later capability
        checks (e.g. whether to stream) therefore use the new card.

        Args:
            context: The client call context.
            extensions: List of extensions to be activated.
            signature_verifier: A callable used to verify the agent card's
                signatures; forwarded to the transport.

        Returns:
            The `AgentCard` for the agent.
        """
        card = await self._transport.get_card(
            context=context,
            extensions=extensions,
            signature_verifier=signature_verifier,
        )
        self._card = card
        return card

    async def close(self) -> None:
        """Closes the underlying transport."""
        await self._transport.close()