-
Notifications
You must be signed in to change notification settings - Fork 424
Expand file tree
/
Copy pathdistributed_event_queue.py
More file actions
237 lines (194 loc) · 7.58 KB
/
distributed_event_queue.py
File metadata and controls
237 lines (194 loc) · 7.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""DistributedEventQueue — EventQueue with SNS fan-out for multi-instance A2A."""
import asyncio
import json
import logging
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any

from a2a.server.events.event_queue import (
    DEFAULT_MAX_QUEUE_SIZE,
    Event,
    EventQueue,
)
from a2a.types import (
    Message,
    Task,
    TaskArtifactUpdateEvent,
    TaskStatusUpdateEvent,
)
logger = logging.getLogger(__name__)

# Wire-format ``type`` discriminator values. Every SNS message carries one of
# these so receivers can route it: ``'close'`` signals a graceful queue close
# across instances; ``'event'`` wraps a serialised A2A event payload.
_CLOSE_TYPE = 'close'
_EVENT_TYPE = 'event'

# Map of ``kind`` discriminator → concrete Pydantic model class, used to pick
# the right ``model_validate`` target when decoding an event off the wire.
# Kinds absent from this map are dropped (with a warning) by ``decode_event``.
_KIND_TO_TYPE: dict[str, type[Event]] = {
    'message': Message,
    'task': Task,
    'artifact-update': TaskArtifactUpdateEvent,
    'status-update': TaskStatusUpdateEvent,
}
def _serialise_event(
    event: Event,
    task_id: str,
    instance_id: str,
) -> str:
    """Builds the SNS wire-format JSON string for a single event.

    Args:
        event: The event to serialise.
        task_id: The task ID this event belongs to.
        instance_id: The originating instance ID (for dedup).

    Returns:
        A JSON string suitable for use as an SNS ``Message`` payload.
    """
    # Round-trip through model_dump_json so Pydantic converts any
    # non-JSON-native field types before the data is embedded.
    event_data: dict[str, Any] = json.loads(event.model_dump_json())
    envelope: dict[str, Any] = {
        'instance_id': instance_id,
        'task_id': task_id,
        'type': _EVENT_TYPE,
        'event_kind': event.kind,
        'event_data': event_data,
    }
    return json.dumps(envelope)
def _serialise_close(task_id: str, instance_id: str) -> str:
    """Builds the SNS wire-format JSON string for a close signal.

    Args:
        task_id: The task ID whose queue is being closed.
        instance_id: The originating instance ID.

    Returns:
        A JSON string suitable for use as an SNS ``Message`` payload.
    """
    envelope: dict[str, Any] = {
        'instance_id': instance_id,
        'task_id': task_id,
        'type': _CLOSE_TYPE,
    }
    return json.dumps(envelope)
def deserialise_wire_message(
    raw: str,
) -> dict[str, Any]:
    """Parses a raw SNS/SQS wire-format JSON string.

    Args:
        raw: The raw JSON string from an SQS message body.

    Returns:
        The parsed wire-format dictionary. The caller is responsible for
        routing based on the ``type`` field (``'event'`` or ``'close'``).

    Raises:
        ValueError: If the JSON is malformed or the ``type`` field is absent.
    """
    try:
        parsed: dict[str, Any] = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f'Malformed wire message: {raw!r}') from exc

    # A membership test (not ``.get``) so an explicit ``"type": null`` is
    # still accepted here and left to the router to reject.
    if 'type' not in parsed:
        raise ValueError(f"Wire message missing 'type' field: {parsed!r}")
    return parsed
def decode_event(msg: dict[str, Any]) -> Event | None:
    """Decodes an event from a parsed wire-format dictionary.

    Args:
        msg: A parsed wire-format dictionary with ``event_kind`` and
            ``event_data`` fields.

    Returns:
        The decoded Event, or ``None`` if the ``kind`` is unrecognised.
    """
    kind = msg.get('event_kind')
    payload = msg.get('event_data')

    # Malformed frames are logged and dropped rather than raised, so one bad
    # SQS message cannot take down the poller loop.
    if kind is None or payload is None:
        logger.warning('Wire message missing event_kind or event_data: %s', msg)
        return None

    if (model_cls := _KIND_TO_TYPE.get(kind)) is None:
        logger.warning('Unknown event kind in wire message: %s', kind)
        return None

    return model_cls.model_validate(payload)
class DistributedEventQueue(EventQueue):
    """EventQueue subclass that publishes events to SNS for multi-instance delivery.

    When ``enqueue_event`` is called by an agent handler, the event is:

    1. Enqueued locally (for the current instance's SSE stream), **and**
    2. Published asynchronously to SNS (for fan-out to all other instances).

    When the SQS poller on a remote instance receives the SNS notification, it
    calls ``enqueue_local`` directly — bypassing SNS re-publication — to avoid
    infinite broadcast loops.

    Args:
        publish_fn: Async callable ``(message: str) -> None`` that publishes
            the serialised wire message to SNS. Provided by
            :class:`SnsQueueManager` and injected at construction time.
        task_id: The task ID this queue serves.
        instance_id: The unique ID of the local instance (used for dedup).
        max_queue_size: Maximum number of events to buffer locally.
            Defaults to ``DEFAULT_MAX_QUEUE_SIZE``.
    """

    def __init__(
        self,
        publish_fn: Callable[[str], Awaitable[None]],
        task_id: str,
        instance_id: str,
        *,
        max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
    ) -> None:
        """Initialises the DistributedEventQueue."""
        super().__init__(max_queue_size=max_queue_size)
        self._publish_fn = publish_fn
        self._task_id = task_id
        self._instance_id = instance_id
        # Strong references to in-flight fire-and-forget publish tasks.
        # The event loop holds only weak references to tasks, so without
        # this set a pending publish could be garbage-collected before it
        # completes (see asyncio.create_task documentation).
        self._background_tasks: set[asyncio.Task[None]] = set()
        logger.debug(
            'DistributedEventQueue initialised (task_id=%s, instance=%s).',
            task_id,
            instance_id,
        )

    def _spawn_publish(self, coro: Coroutine[Any, Any, None]) -> None:
        """Schedules a publish coroutine fire-and-forget, keeping a strong ref.

        Args:
            coro: The publish coroutine to run in the background.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        # Drop the reference once the task finishes so the set cannot grow
        # without bound over the queue's lifetime.
        task.add_done_callback(self._background_tasks.discard)

    async def enqueue_event(self, event: Event) -> None:
        """Enqueues the event locally and publishes it to SNS.

        The SNS publish is fire-and-forget so that local delivery is never
        delayed by network I/O.

        Args:
            event: The event to enqueue and broadcast.
        """
        await super().enqueue_event(event)
        self._spawn_publish(self._publish_event(event))

    async def enqueue_local(self, event: Event) -> None:
        """Enqueues an event locally without re-publishing to SNS.

        Called by the SQS poller when delivering a remote event to this
        instance. Using this method prevents the event from being
        re-broadcast back to SNS, which would create an infinite loop.

        Args:
            event: The event received from the SQS queue.
        """
        await super().enqueue_event(event)

    async def close(self, immediate: bool = False) -> None:
        """Closes the queue locally and publishes a close signal to SNS.

        The close signal allows other instances to also close their local
        queues for the same task, ensuring clean shutdown across the cluster.

        Args:
            immediate: If ``True``, discard buffered events immediately
                rather than waiting for them to drain.
        """
        # Only the first close broadcasts; repeat closes are local no-ops.
        if not self.is_closed():
            self._spawn_publish(self._publish_close())
        await super().close(immediate=immediate)

    async def _publish_event(self, event: Event) -> None:
        """Fire-and-forget coroutine: serialises and publishes one event.

        Failures are logged, never raised — local delivery already happened
        and must not be disturbed by a broadcast error.

        Args:
            event: The event to publish.
        """
        try:
            message = _serialise_event(event, self._task_id, self._instance_id)
            await self._publish_fn(message)
            logger.debug(
                'Event published to SNS (task_id=%s, kind=%s).',
                self._task_id,
                event.kind,
            )
        except Exception:
            logger.exception(
                'Failed to publish event to SNS (task_id=%s).', self._task_id
            )

    async def _publish_close(self) -> None:
        """Fire-and-forget coroutine: publishes the close signal to SNS."""
        try:
            message = _serialise_close(self._task_id, self._instance_id)
            await self._publish_fn(message)
            logger.debug(
                'Close signal published to SNS (task_id=%s).', self._task_id
            )
        except Exception:
            logger.exception(
                'Failed to publish close signal to SNS (task_id=%s).',
                self._task_id,
            )