Skip to content

Commit c00b7b8

Browse files
committed
fix(server): wrap task in StreamResponse for push notifications
Signed-off-by: Luca Muscariello <muscariello@ieee.org>
1 parent 6260ea2 commit c00b7b8

4 files changed

Lines changed: 19 additions & 12 deletions

File tree

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
PushNotificationConfigStore,
1010
)
1111
from a2a.server.tasks.push_notification_sender import PushNotificationSender
12-
from a2a.types.a2a_pb2 import PushNotificationConfig, Task
12+
from a2a.types.a2a_pb2 import PushNotificationConfig, StreamResponse, Task
1313

1414

1515
logger = logging.getLogger(__name__)
@@ -59,7 +59,7 @@ async def _dispatch_notification(
5959
headers = {'X-A2A-Notification-Token': push_info.token}
6060
response = await self._client.post(
6161
url,
62-
json=MessageToDict(task),
62+
json=MessageToDict(StreamResponse(task=task)),
6363
headers=headers,
6464
)
6565
response.raise_for_status()

tests/e2e/push_notifications/notifications_app.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from fastapi import FastAPI, HTTPException, Path, Request
66
from pydantic import BaseModel, ConfigDict, ValidationError
77

8-
from a2a.types.a2a_pb2 import Task
8+
from a2a.types.a2a_pb2 import StreamResponse, Task
99
from google.protobuf.json_format import ParseDict, MessageToDict
1010

1111

@@ -35,7 +35,12 @@ async def add_notification(request: Request):
3535
)
3636
try:
3737
json_data = await request.json()
38-
task = ParseDict(json_data, Task())
38+
stream_response = ParseDict(json_data, StreamResponse())
39+
if not stream_response.HasField('task'):
40+
raise HTTPException(
41+
status_code=400, detail='Missing task in StreamResponse'
42+
)
43+
task = stream_response.task
3944
except Exception as e:
4045
raise HTTPException(status_code=400, detail=str(e))
4146

tests/server/tasks/test_inmemory_push_notifications.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from a2a.types.a2a_pb2 import (
1515
PushNotificationConfig,
16+
StreamResponse,
1617
Task,
1718
TaskState,
1819
TaskStatus,
@@ -162,7 +163,7 @@ async def test_send_notification_success(self) -> None:
162163
self.assertEqual(called_args[0], config.url)
163164
self.assertEqual(
164165
called_kwargs['json'],
165-
MessageToDict(task_data),
166+
MessageToDict(StreamResponse(task=task_data)),
166167
)
167168
self.assertNotIn(
168169
'auth', called_kwargs
@@ -189,7 +190,7 @@ async def test_send_notification_with_token_success(self) -> None:
189190
self.assertEqual(called_args[0], config.url)
190191
self.assertEqual(
191192
called_kwargs['json'],
192-
MessageToDict(task_data),
193+
MessageToDict(StreamResponse(task=task_data)),
193194
)
194195
self.assertEqual(
195196
called_kwargs['headers'],
@@ -287,7 +288,7 @@ async def test_send_notification_with_auth(
287288
self.assertEqual(called_args[0], config.url)
288289
self.assertEqual(
289290
called_kwargs['json'],
290-
MessageToDict(task_data),
291+
MessageToDict(StreamResponse(task=task_data)),
291292
)
292293
self.assertNotIn(
293294
'auth', called_kwargs

tests/server/tasks/test_push_notification_sender.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from a2a.types.a2a_pb2 import (
1212
PushNotificationConfig,
13+
StreamResponse,
1314
Task,
1415
TaskState,
1516
TaskStatus,
@@ -65,7 +66,7 @@ async def test_send_notification_success(self) -> None:
6566
# assert httpx_client post method got invoked with right parameters
6667
self.mock_httpx_client.post.assert_awaited_once_with(
6768
config.url,
68-
json=MessageToDict(task_data),
69+
json=MessageToDict(StreamResponse(task=task_data)),
6970
headers=None,
7071
)
7172
mock_response.raise_for_status.assert_called_once()
@@ -89,7 +90,7 @@ async def test_send_notification_with_token_success(self) -> None:
8990
# assert httpx_client post method got invoked with right parameters
9091
self.mock_httpx_client.post.assert_awaited_once_with(
9192
config.url,
92-
json=MessageToDict(task_data),
93+
json=MessageToDict(StreamResponse(task=task_data)),
9394
headers={'X-A2A-Notification-Token': 'unique_token'},
9495
)
9596
mock_response.raise_for_status.assert_called_once()
@@ -126,7 +127,7 @@ async def test_send_notification_http_status_error(
126127
self.mock_config_store.get_info.assert_awaited_once_with(task_id)
127128
self.mock_httpx_client.post.assert_awaited_once_with(
128129
config.url,
129-
json=MessageToDict(task_data),
130+
json=MessageToDict(StreamResponse(task=task_data)),
130131
headers=None,
131132
)
132133
mock_logger.exception.assert_called_once()
@@ -154,13 +155,13 @@ async def test_send_notification_multiple_configs(self) -> None:
154155
# Check calls for config1
155156
self.mock_httpx_client.post.assert_any_call(
156157
config1.url,
157-
json=MessageToDict(task_data),
158+
json=MessageToDict(StreamResponse(task=task_data)),
158159
headers=None,
159160
)
160161
# Check calls for config2
161162
self.mock_httpx_client.post.assert_any_call(
162163
config2.url,
163-
json=MessageToDict(task_data),
164+
json=MessageToDict(StreamResponse(task=task_data)),
164165
headers=None,
165166
)
166167
mock_response.raise_for_status.call_count = 2

0 commit comments

Comments
 (0)