Skip to content

Commit 2f12712

Browse files
committed
Merge remote-tracking branch 'origin/1.0-dev' into ishymko/716-fix
2 parents d09a6a2 + a149a09 commit 2f12712

13 files changed

Lines changed: 312 additions & 108 deletions

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from a2a.server.request_handlers.request_handler import RequestHandler
2222
from a2a.server.tasks import (
2323
PushNotificationConfigStore,
24+
PushNotificationEvent,
2425
PushNotificationSender,
2526
ResultAggregator,
2627
TaskManager,
@@ -52,7 +53,11 @@
5253
TaskNotFoundError,
5354
UnsupportedOperationError,
5455
)
55-
from a2a.utils.task import apply_history_length
56+
from a2a.utils.task import (
57+
apply_history_length,
58+
validate_history_length,
59+
validate_page_size,
60+
)
5661
from a2a.utils.telemetry import SpanKind, trace_class
5762

5863

@@ -122,6 +127,8 @@ async def on_get_task(
122127
context: ServerCallContext | None = None,
123128
) -> Task | None:
124129
"""Default handler for 'tasks/get'."""
130+
validate_history_length(params)
131+
125132
task_id = params.id
126133
task: Task | None = await self.task_store.get(task_id, context)
127134
if not task:
@@ -135,6 +142,10 @@ async def on_list_tasks(
135142
context: ServerCallContext | None = None,
136143
) -> ListTasksResponse:
137144
"""Default handler for 'tasks/list'."""
145+
validate_history_length(params)
146+
if params.HasField('page_size'):
147+
validate_page_size(params.page_size)
148+
138149
page = await self.task_store.list(params, context)
139150
for task in page.tasks:
140151
if not params.include_artifacts:
@@ -309,13 +320,15 @@ def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
309320
)
310321

311322
async def _send_push_notification_if_needed(
312-
self, task_id: str, result_aggregator: ResultAggregator
323+
self, task_id: str, event: Event
313324
) -> None:
314-
"""Sends push notification if configured and task is available."""
315-
if self._push_sender and task_id:
316-
latest_task = await result_aggregator.current_result
317-
if isinstance(latest_task, Task):
318-
await self._push_sender.send_notification(latest_task)
325+
"""Sends push notification if configured."""
326+
if (
327+
self._push_sender
328+
and task_id
329+
and isinstance(event, PushNotificationEvent)
330+
):
331+
await self._push_sender.send_notification(task_id, event)
319332

320333
async def on_message_send(
321334
self,
@@ -327,6 +340,8 @@ async def on_message_send(
327340
Starts the agent execution for the message and waits for the final
328341
result (Task or Message).
329342
"""
343+
validate_history_length(params.configuration)
344+
330345
(
331346
_task_manager,
332347
task_id,
@@ -345,10 +360,8 @@ async def on_message_send(
345360
interrupted_or_non_blocking = False
346361
try:
347362
# Create async callback for push notifications
348-
async def push_notification_callback() -> None:
349-
await self._send_push_notification_if_needed(
350-
task_id, result_aggregator
351-
)
363+
async def push_notification_callback(event: Event) -> None:
364+
await self._send_push_notification_if_needed(task_id, event)
352365

353366
(
354367
result,
@@ -381,8 +394,6 @@ async def push_notification_callback() -> None:
381394
if params.configuration:
382395
result = apply_history_length(result, params.configuration)
383396

384-
await self._send_push_notification_if_needed(task_id, result_aggregator)
385-
386397
return result
387398

388399
async def on_message_send_stream(
@@ -410,9 +421,7 @@ async def on_message_send_stream(
410421
if isinstance(event, Task):
411422
self._validate_task_id_match(task_id, event.id)
412423

413-
await self._send_push_notification_if_needed(
414-
task_id, result_aggregator
415-
)
424+
await self._send_push_notification_if_needed(task_id, event)
416425
yield event
417426
except (asyncio.CancelledError, GeneratorExit):
418427
# Client disconnected: continue consuming and persisting events in the background

src/a2a/server/tasks/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from a2a.server.tasks.push_notification_config_store import (
1313
PushNotificationConfigStore,
1414
)
15-
from a2a.server.tasks.push_notification_sender import PushNotificationSender
15+
from a2a.server.tasks.push_notification_sender import (
16+
PushNotificationEvent,
17+
PushNotificationSender,
18+
)
1619
from a2a.server.tasks.result_aggregator import ResultAggregator
1720
from a2a.server.tasks.task_manager import TaskManager
1821
from a2a.server.tasks.task_store import TaskStore
@@ -72,6 +75,7 @@ def __init__(self, *args, **kwargs):
7275
'InMemoryPushNotificationConfigStore',
7376
'InMemoryTaskStore',
7477
'PushNotificationConfigStore',
78+
'PushNotificationEvent',
7579
'PushNotificationSender',
7680
'ResultAggregator',
7781
'TaskManager',

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from a2a.server.tasks.push_notification_config_store import (
99
PushNotificationConfigStore,
1010
)
11-
from a2a.server.tasks.push_notification_sender import PushNotificationSender
12-
from a2a.types.a2a_pb2 import PushNotificationConfig, StreamResponse, Task
11+
from a2a.server.tasks.push_notification_sender import (
12+
PushNotificationEvent,
13+
PushNotificationSender,
14+
)
15+
from a2a.types.a2a_pb2 import PushNotificationConfig
16+
from a2a.utils.proto_utils import to_stream_response
1317

1418

1519
logger = logging.getLogger(__name__)
@@ -32,44 +36,50 @@ def __init__(
3236
self._client = httpx_client
3337
self._config_store = config_store
3438

35-
async def send_notification(self, task: Task) -> None:
36-
"""Sends a push notification for a task if configuration exists."""
37-
push_configs = await self._config_store.get_info(task.id)
39+
async def send_notification(
40+
self, task_id: str, event: PushNotificationEvent
41+
) -> None:
42+
"""Sends a push notification for an event if configuration exists."""
43+
push_configs = await self._config_store.get_info(task_id)
3844
if not push_configs:
3945
return
4046

4147
awaitables = [
42-
self._dispatch_notification(task, push_info)
48+
self._dispatch_notification(event, push_info, task_id)
4349
for push_info in push_configs
4450
]
4551
results = await asyncio.gather(*awaitables)
4652

4753
if not all(results):
4854
logger.warning(
49-
'Some push notifications failed to send for task_id=%s', task.id
55+
'Some push notifications failed to send for task_id=%s', task_id
5056
)
5157

5258
async def _dispatch_notification(
53-
self, task: Task, push_info: PushNotificationConfig
59+
self,
60+
event: PushNotificationEvent,
61+
push_info: PushNotificationConfig,
62+
task_id: str,
5463
) -> bool:
5564
url = push_info.url
5665
try:
5766
headers = None
5867
if push_info.token:
5968
headers = {'X-A2A-Notification-Token': push_info.token}
69+
6070
response = await self._client.post(
6171
url,
62-
json=MessageToDict(StreamResponse(task=task)),
72+
json=MessageToDict(to_stream_response(event)),
6373
headers=headers,
6474
)
6575
response.raise_for_status()
6676
logger.info(
67-
'Push-notification sent for task_id=%s to URL: %s', task.id, url
77+
'Push-notification sent for task_id=%s to URL: %s', task_id, url
6878
)
6979
except Exception:
7080
logger.exception(
7181
'Error sending push-notification for task_id=%s to URL: %s.',
72-
task.id,
82+
task_id,
7383
url,
7484
)
7585
return False
Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
from abc import ABC, abstractmethod
22

3-
from a2a.types.a2a_pb2 import Task
3+
from a2a.types.a2a_pb2 import (
4+
Task,
5+
TaskArtifactUpdateEvent,
6+
TaskStatusUpdateEvent,
7+
)
8+
9+
10+
PushNotificationEvent = Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
411

512

613
class PushNotificationSender(ABC):
714
"""Interface for sending push notifications for tasks."""
815

916
@abstractmethod
10-
async def send_notification(self, task: Task) -> None:
17+
async def send_notification(
18+
self, task_id: str, event: PushNotificationEvent
19+
) -> None:
1120
"""Sends a push notification containing the latest task state."""

src/a2a/server/tasks/result_aggregator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def consume_and_break_on_interrupt(
9898
self,
9999
consumer: EventConsumer,
100100
blocking: bool = True,
101-
event_callback: Callable[[], Awaitable[None]] | None = None,
101+
event_callback: Callable[[Event], Awaitable[None]] | None = None,
102102
) -> tuple[Task | Message | None, bool]:
103103
"""Processes the event stream until completion or an interruptible state is encountered.
104104
@@ -131,6 +131,9 @@ async def consume_and_break_on_interrupt(
131131
return event, False
132132
await self.task_manager.process(event)
133133

134+
if event_callback:
135+
await event_callback(event)
136+
134137
should_interrupt = False
135138
is_auth_required = (
136139
isinstance(event, Task | TaskStatusUpdateEvent)
@@ -169,7 +172,7 @@ async def consume_and_break_on_interrupt(
169172
async def _continue_consuming(
170173
self,
171174
event_stream: AsyncIterator[Event],
172-
event_callback: Callable[[], Awaitable[None]] | None = None,
175+
event_callback: Callable[[Event], Awaitable[None]] | None = None,
173176
) -> None:
174177
"""Continues processing an event stream in a background task.
175178
@@ -183,4 +186,4 @@ async def _continue_consuming(
183186
async for event in event_stream:
184187
await self.task_manager.process(event)
185188
if event_callback:
186-
await event_callback()
189+
await event_callback(event)

src/a2a/utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
DEFAULT_LIST_TASKS_PAGE_SIZE = 50
1111
"""Default page size for the `tasks/list` method."""
1212

13+
MAX_LIST_TASKS_PAGE_SIZE = 100
14+
"""Maximum page size for the `tasks/list` method."""
15+
1316

1417
class TransportProtocol(str, Enum):
1518
"""Transport protocol string constants."""

src/a2a/utils/task.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
TaskState,
1414
TaskStatus,
1515
)
16+
from a2a.utils.constants import MAX_LIST_TASKS_PAGE_SIZE
17+
from a2a.utils.errors import InvalidParamsError, ServerError
1618

1719

1820
def new_task(request: Message) -> Task:
@@ -96,6 +98,16 @@ def HasField(self, field_name: Literal['history_length']) -> bool: # noqa: N802
9698
...
9799

98100

101+
def validate_history_length(config: HistoryLengthConfig | None) -> None:
102+
"""Validates that history_length is non-negative."""
103+
if config and config.history_length < 0:
104+
raise ServerError(
105+
error=InvalidParamsError(
106+
message='history length must be non-negative'
107+
)
108+
)
109+
110+
99111
def apply_history_length(
100112
task: Task, config: HistoryLengthConfig | None
101113
) -> Task:
@@ -136,6 +148,24 @@ def apply_history_length(
136148
return task
137149

138150

151+
def validate_page_size(page_size: int) -> None:
152+
"""Validates that page_size is in range [1, 100].
153+
154+
See Also:
155+
https://a2a-protocol.org/latest/specification/#314-list-tasks
156+
"""
157+
if page_size < 1:
158+
raise ServerError(
159+
error=InvalidParamsError(message='minimum page size is 1')
160+
)
161+
if page_size > MAX_LIST_TASKS_PAGE_SIZE:
162+
raise ServerError(
163+
error=InvalidParamsError(
164+
message=f'maximum page size is {MAX_LIST_TASKS_PAGE_SIZE}'
165+
)
166+
)
167+
168+
139169
_ENCODING = 'utf-8'
140170

141171

tck/sut_agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AgentSkill,
2424
Message,
2525
Part,
26+
Role,
2627
TaskState,
2728
TaskStatus,
2829
TaskStatusUpdateEvent,
@@ -87,7 +88,7 @@ async def execute(
8788
status=TaskStatus(
8889
state=TaskState.TASK_STATE_WORKING,
8990
message=Message(
90-
role='agent',
91+
role=Role.ROLE_AGENT,
9192
message_id=str(uuid.uuid4()),
9293
parts=[Part(text='Processing your question')],
9394
task_id=task_id,
@@ -108,7 +109,7 @@ async def execute(
108109
logger.info('[SUTAgentExecutor] Response: %s', agent_reply_text)
109110

110111
agent_message = Message(
111-
role='agent',
112+
role=Role.ROLE_AGENT,
112113
message_id=str(uuid.uuid4()),
113114
parts=[Part(text=agent_reply_text)],
114115
task_id=task_id,

tests/e2e/push_notifications/notifications_app.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class Notification(BaseModel):
1313
"""Encapsulates default push notification data."""
1414

15-
task: dict[str, Any]
15+
event: dict[str, Any]
1616
token: str
1717

1818

@@ -36,20 +36,33 @@ async def add_notification(request: Request):
3636
try:
3737
json_data = await request.json()
3838
stream_response = ParseDict(json_data, StreamResponse())
39-
if not stream_response.HasField('task'):
39+
40+
payload_name = stream_response.WhichOneof('payload')
41+
task_id = None
42+
if payload_name:
43+
event_payload = getattr(stream_response, payload_name)
44+
# The 'Task' message uses 'id', while event messages use 'task_id'.
45+
task_id = getattr(
46+
event_payload, 'task_id', getattr(event_payload, 'id', None)
47+
)
48+
49+
if not task_id:
4050
raise HTTPException(
41-
status_code=400, detail='Missing task in StreamResponse'
51+
status_code=400,
52+
detail='Missing "task_id" in push notification.',
4253
)
43-
task = stream_response.task
54+
4455
except Exception as e:
4556
raise HTTPException(status_code=400, detail=str(e))
4657

4758
async with store_lock:
48-
if task.id not in store:
49-
store[task.id] = []
50-
store[task.id].append(
59+
if task_id not in store:
60+
store[task_id] = []
61+
store[task_id].append(
5162
Notification(
52-
task=MessageToDict(task, preserving_proto_field_name=True),
63+
event=MessageToDict(
64+
stream_response, preserving_proto_field_name=True
65+
),
5366
token=token,
5467
)
5568
)

0 commit comments

Comments
 (0)