Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ test_venv/
coverage.xml
.nox
spec.json
.idea
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
signing = ["PyJWT>=2.0.0"]
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
aws = ["aioboto3>=13.0.0"]

sql = ["a2a-sdk[postgresql,mysql,sqlite]"]

Expand All @@ -47,6 +48,7 @@ all = [
"a2a-sdk[grpc]",
"a2a-sdk[telemetry]",
"a2a-sdk[signing]",
"a2a-sdk[aws]",
]

[project.urls]
Expand Down
64 changes: 64 additions & 0 deletions src/a2a/server/events/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Event handling components for the A2A server."""

import logging

from a2a.server.events.event_consumer import EventConsumer
from a2a.server.events.event_queue import Event, EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
Expand All @@ -10,12 +12,74 @@
)


logger = logging.getLogger(__name__)

try:
    from a2a.server.events.distributed_event_queue import (
        DistributedEventQueue,  # type: ignore
    )
    from a2a.server.events.queue_lifecycle_manager import (
        QueueLifecycleManager,  # type: ignore
        QueueProvisionResult,  # type: ignore
    )
    from a2a.server.events.sns_queue_manager import (
        SnsQueueManager,  # type: ignore
    )
except ImportError as e:
    # The aws extra (aioboto3 et al.) is optional; fall back to placeholder
    # classes that fail loudly — but only on instantiation, so that plain
    # `import` of this package keeps working without the extra installed.
    _original_aws_error = e
    logger.debug(
        'AWS distributed event components not loaded. '
        'Install the aws extra to enable them. Error: %s',
        e,
    )

    def _raise_aws_extra_required(name: str) -> None:
        """Raises the uniform ImportError for a missing aws extra.

        Shared by all placeholder classes so the remedy message and the
        chained original import error stay consistent in one place.

        Args:
            name: The public class name the caller tried to instantiate.

        Raises:
            ImportError: Always, chained from the original import failure.
        """
        raise ImportError(
            f'To use {name}, install the aws extra: '
            '\'pip install "a2a-sdk[aws]"\''
        ) from _original_aws_error

    class DistributedEventQueue:  # type: ignore
        """Placeholder when aws extra is not installed."""

        def __init__(self, *args, **kwargs):
            _raise_aws_extra_required('DistributedEventQueue')

    class SnsQueueManager:  # type: ignore
        """Placeholder when aws extra is not installed."""

        def __init__(self, *args, **kwargs):
            _raise_aws_extra_required('SnsQueueManager')

    class QueueLifecycleManager:  # type: ignore
        """Placeholder when aws extra is not installed."""

        def __init__(self, *args, **kwargs):
            _raise_aws_extra_required('QueueLifecycleManager')

    class QueueProvisionResult:  # type: ignore
        """Placeholder when aws extra is not installed."""

        def __init__(self, *args, **kwargs):
            _raise_aws_extra_required('QueueProvisionResult')


# Public names of the events package. The AWS-backed entries resolve to the
# real implementations when the aws extra is installed, otherwise to the
# placeholder classes defined above that raise ImportError on instantiation.
__all__ = [
    'DistributedEventQueue',
    'Event',
    'EventConsumer',
    'EventQueue',
    'InMemoryQueueManager',
    'NoTaskQueue',
    'QueueLifecycleManager',
    'QueueManager',
    'QueueProvisionResult',
    'SnsQueueManager',
    'TaskQueueExists',
]
237 changes: 237 additions & 0 deletions src/a2a/server/events/distributed_event_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""DistributedEventQueue — EventQueue with SNS fan-out for multi-instance A2A."""

import asyncio
import json
import logging

from collections.abc import Awaitable, Callable, Coroutine
from typing import Any

from a2a.server.events.event_queue import (
DEFAULT_MAX_QUEUE_SIZE,
Event,
EventQueue,
)
from a2a.types import (
Message,
Task,
TaskArtifactUpdateEvent,
TaskStatusUpdateEvent,
)


logger = logging.getLogger(__name__)

# Wire-format type tag used for graceful queue close across instances.
_CLOSE_TYPE = 'close'
# Wire-format type tag for ordinary event payloads.
_EVENT_TYPE = 'event'

# Map of ``kind`` discriminator → concrete Pydantic model class.
# Used by decode_event to rehydrate the wire payload into the right model.
_KIND_TO_TYPE: dict[str, type[Event]] = {
    'message': Message,
    'task': Task,
    'artifact-update': TaskArtifactUpdateEvent,
    'status-update': TaskStatusUpdateEvent,
}


def _serialise_event(
    event: Event,
    task_id: str,
    instance_id: str,
) -> str:
    """Serialises an event into the SNS wire-format JSON string.

    Args:
        event: The event to serialise.
        task_id: The task ID this event belongs to.
        instance_id: The originating instance ID (for dedup).

    Returns:
        A JSON string suitable for use as an SNS ``Message`` payload.
    """
    payload: dict[str, Any] = {
        'instance_id': instance_id,
        'task_id': task_id,
        'type': _EVENT_TYPE,
        'event_kind': event.kind,
        # mode='json' yields JSON-safe primitives directly, avoiding the
        # serialise-then-reparse round-trip of json.loads(model_dump_json()).
        'event_data': event.model_dump(mode='json'),
    }
    return json.dumps(payload)


def _serialise_close(task_id: str, instance_id: str) -> str:
    """Builds the SNS wire-format JSON string signalling a queue close.

    Args:
        task_id: The task ID whose queue is being closed.
        instance_id: The originating instance ID.

    Returns:
        A JSON string suitable for use as an SNS ``Message`` payload.
    """
    return json.dumps(
        {
            'instance_id': instance_id,
            'task_id': task_id,
            'type': _CLOSE_TYPE,
        }
    )


def deserialise_wire_message(
    raw: str,
) -> dict[str, Any]:
    """Parses a raw SNS/SQS wire-format JSON string.

    Args:
        raw: The raw JSON string from an SQS message body.

    Returns:
        The parsed wire-format dictionary. The caller is responsible for
        routing based on the ``type`` field (``'event'`` or ``'close'``).

    Raises:
        ValueError: If the JSON is malformed or the ``type`` field is absent.
    """
    try:
        parsed: dict[str, Any] = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f'Malformed wire message: {raw!r}') from exc

    # The routing discriminator is mandatory; reject messages without it.
    if 'type' not in parsed:
        raise ValueError(f"Wire message missing 'type' field: {parsed!r}")

    return parsed


def decode_event(msg: dict[str, Any]) -> Event | None:
    """Decodes an event from a parsed wire-format dictionary.

    Args:
        msg: A parsed wire-format dictionary with ``event_kind`` and
            ``event_data`` fields.

    Returns:
        The decoded Event, or ``None`` if the ``kind`` is unrecognised.
    """
    kind = msg.get('event_kind')
    payload = msg.get('event_data')

    # Both fields are required to reconstruct the event; a partial wire
    # message is logged and dropped rather than raising.
    if kind is None or payload is None:
        logger.warning('Wire message missing event_kind or event_data: %s', msg)
        return None

    if (event_cls := _KIND_TO_TYPE.get(kind)) is None:
        logger.warning('Unknown event kind in wire message: %s', kind)
        return None

    return event_cls.model_validate(payload)


class DistributedEventQueue(EventQueue):
    """EventQueue subclass that publishes events to SNS for multi-instance delivery.

    When ``enqueue_event`` is called by an agent handler, the event is:

    1. Enqueued locally (for the current instance's SSE stream), **and**
    2. Published asynchronously to SNS (for fan-out to all other instances).

    When the SQS poller on a remote instance receives the SNS notification, it
    calls ``enqueue_local`` directly — bypassing SNS re-publication — to avoid
    infinite broadcast loops.

    Args:
        publish_fn: Async callable ``(message: str) -> None`` that publishes
            the serialised wire message to SNS. Provided by
            :class:`SnsQueueManager` and injected at construction time.
        task_id: The task ID this queue serves.
        instance_id: The unique ID of the local instance (used for dedup).
        max_queue_size: Maximum number of events to buffer locally.
            Defaults to ``DEFAULT_MAX_QUEUE_SIZE``.
    """

    def __init__(
        self,
        publish_fn: Callable[[str], Awaitable[None]],
        task_id: str,
        instance_id: str,
        *,
        max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
    ) -> None:
        """Initialises the DistributedEventQueue."""
        super().__init__(max_queue_size=max_queue_size)
        self._publish_fn = publish_fn
        self._task_id = task_id
        self._instance_id = instance_id
        # Strong references to in-flight publish tasks. The event loop keeps
        # only weak references to tasks, so a bare create_task() result can
        # be garbage-collected mid-flight and the publish silently dropped.
        self._background_tasks: set[asyncio.Task[None]] = set()
        logger.debug(
            'DistributedEventQueue initialised (task_id=%s, instance=%s).',
            task_id,
            instance_id,
        )

    def _spawn_publish(self, coro: Coroutine[Any, Any, None]) -> None:
        """Schedules a fire-and-forget publish, retaining a task reference.

        Args:
            coro: The publish coroutine to run in the background.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        # Drop the reference once the task completes so the set stays small.
        task.add_done_callback(self._background_tasks.discard)

    async def enqueue_event(self, event: Event) -> None:
        """Enqueues the event locally and publishes it to SNS.

        The SNS publish is fire-and-forget (a background task) so that
        local delivery is never delayed by network I/O.

        Args:
            event: The event to enqueue and broadcast.
        """
        await super().enqueue_event(event)
        self._spawn_publish(self._publish_event(event))

    async def enqueue_local(self, event: Event) -> None:
        """Enqueues an event locally without re-publishing to SNS.

        Called by the SQS poller when delivering a remote event to this
        instance. Using this method prevents the event from being
        re-broadcast back to SNS, which would create an infinite loop.

        Args:
            event: The event received from the SQS queue.
        """
        await super().enqueue_event(event)

    async def close(self, immediate: bool = False) -> None:
        """Closes the queue locally and publishes a close signal to SNS.

        The close signal allows other instances to also close their local
        queues for the same task, ensuring clean shutdown across the cluster.

        Args:
            immediate: If ``True``, discard buffered events immediately
                rather than waiting for them to drain.
        """
        # Only broadcast once: a queue already closed has either sent the
        # signal or was closed in response to a remote signal.
        if not self.is_closed():
            self._spawn_publish(self._publish_close())
        await super().close(immediate=immediate)

    async def _publish_event(self, event: Event) -> None:
        """Background coroutine: serialises and publishes one event.

        Publish failures are logged, never raised — local delivery already
        succeeded and must not be rolled back.

        Args:
            event: The event to publish.
        """
        try:
            message = _serialise_event(event, self._task_id, self._instance_id)
            await self._publish_fn(message)
            logger.debug(
                'Event published to SNS (task_id=%s, kind=%s).',
                self._task_id,
                event.kind,
            )
        except Exception:
            logger.exception(
                'Failed to publish event to SNS (task_id=%s).', self._task_id
            )

    async def _publish_close(self) -> None:
        """Background coroutine: publishes the close signal to SNS."""
        try:
            message = _serialise_close(self._task_id, self._instance_id)
            await self._publish_fn(message)
            logger.debug(
                'Close signal published to SNS (task_id=%s).', self._task_id
            )
        except Exception:
            logger.exception(
                'Failed to publish close signal to SNS (task_id=%s).',
                self._task_id,
            )
Loading
Loading