Skip to content

Commit d06b7c2

Browse files
committed
Subscribe post.
1 parent cac6f58 commit d06b7c2

9 files changed

Lines changed: 311 additions & 18 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,10 +262,11 @@ async def subscribe(
262262
) -> AsyncGenerator[StreamResponse]:
263263
"""Reconnects to get task updates."""
264264
async for event in self._send_stream_request(
265-
'GET',
265+
'POST',
266266
f'/tasks/{request.id}:subscribe',
267267
request.tenant,
268268
context=context,
269+
json=MessageToDict(request),
269270
):
270271
yield event
271272

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
163163
self._handle_streaming_request,
164164
self.handler.on_subscribe_to_task,
165165
),
166+
('/v1/tasks/{id}:subscribe', 'POST'): functools.partial(
167+
self._handle_streaming_request,
168+
self.handler.on_subscribe_to_task,
169+
),
166170
('/v1/tasks/{id}', 'GET'): functools.partial(
167171
self._handle_request, self.handler.on_get_task
168172
),

src/a2a/compat/v0_3/rest_transport.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __init__(
6868
self.url = url.removesuffix('/')
6969
self.httpx_client = httpx_client
7070
self.agent_card = agent_card
71+
self._subscribe_method = 'POST'
72+
self._subscribe_retry_attempted = False
7173

7274
async def send_message(
7375
self,
@@ -273,13 +275,39 @@ async def subscribe(
273275
*,
274276
context: ClientCallContext | None = None,
275277
) -> AsyncGenerator[StreamResponse]:
276-
"""Reconnects to get task updates."""
277-
async for event in self._send_stream_request(
278-
'GET',
279-
f'/v1/tasks/{request.id}:subscribe',
280-
context=context,
281-
):
282-
yield event
278+
"""Reconnects to get task updates.
279+
280+
This method implements backward compatibility logic for the subscribe
281+
endpoint. It first attempts to use POST, which is the official method
282+
for A2A subscribe endpoint. If the server returns 405 Method Not Allowed,
283+
it falls back to GET and remembers this preference for future calls
284+
on this transport instance. If both fail with 405, it will default back
285+
to POST for next calls but will not retry again.
286+
"""
287+
try:
288+
async for event in self._send_stream_request(
289+
self._subscribe_method,
290+
f'/v1/tasks/{request.id}:subscribe',
291+
context=context,
292+
):
293+
yield event
294+
except A2AClientError as e:
295+
# Check for 405 Method Not Allowed in the cause (httpx.HTTPStatusError)
296+
cause = e.__cause__
297+
if (
298+
isinstance(cause, httpx.HTTPStatusError)
299+
and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED
300+
):
301+
if self._subscribe_retry_attempted:
302+
self._subscribe_method = 'POST'
303+
raise
304+
else:
305+
self._subscribe_method = 'GET'
306+
self._subscribe_retry_attempted = True
307+
async for event in self.subscribe(request, context=context):
308+
yield event
309+
else:
310+
raise
283311

284312
async def get_extended_agent_card(
285313
self,
@@ -311,7 +339,11 @@ async def close(self) -> None:
311339
def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
312340
"""Handles HTTP status errors and raises the appropriate A2AError."""
313341
try:
314-
error_data = e.response.json()
342+
try:
343+
error_data = e.response.json()
344+
except (json.JSONDecodeError, ValueError, httpx.ResponseNotRead):
345+
error_data = {}
346+
315347
error_type = error_data.get('type')
316348
message = error_data.get('message', str(e))
317349

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
237237
self._handle_streaming_request,
238238
self.handler.on_subscribe_to_task,
239239
),
240+
('/tasks/{id}:subscribe', 'POST'): functools.partial(
241+
self._handle_streaming_request,
242+
self.handler.on_subscribe_to_task,
243+
),
240244
('/tasks/{id}', 'GET'): functools.partial(
241245
self._handle_request, self.handler.on_get_task
242246
),

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,15 @@ async def on_subscribe_to_task(
159159
Yields:
160160
JSON serialized objects containing streaming events
161161
"""
162-
task_id = request.path_params['id']
162+
params = SubscribeToTaskRequest()
163+
if request.method == 'POST':
164+
body = await request.body()
165+
if body:
166+
Parse(body, params)
167+
168+
params.id = request.path_params['id']
163169
async for event in self.request_handler.on_subscribe_to_task(
164-
SubscribeToTaskRequest(id=task_id), context
170+
params, context
165171
):
166172
yield MessageToDict(proto_utils.to_stream_response(event))
167173

tests/client/transports/test_rest_client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,15 @@ async def empty_aiter():
716716
async for _ in method(request=request_obj):
717717
pass
718718

719-
# 4. Verify the URL
719+
# 4. Verify the URL and method
720720
mock_aconnect_sse.assert_called_once()
721-
args, _ = mock_aconnect_sse.call_args
721+
args, kwargs = mock_aconnect_sse.call_args
722+
# method is 2nd positional argument
723+
if method_name == 'subscribe':
724+
assert args[1] == 'POST'
725+
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
726+
else:
727+
assert args[1] == 'POST'
728+
722729
# url is 3rd positional argument in aconnect_sse(client, method, url, ...)
723730
assert args[2] == f'http://agent.example.com/api{expected_path}'

tests/compat/v0_3/test_rest_handler.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,44 @@ async def mock_stream(*args, **kwargs):
186186
]
187187

188188

189+
@pytest.mark.anyio
190+
async def test_on_subscribe_to_task_post(
191+
rest_handler, mock_request, mock_context
192+
):
193+
mock_request.path_params = {'id': 'task-1'}
194+
mock_request.method = 'POST'
195+
request_body = {'name': 'tasks/task-1'}
196+
mock_request.body = AsyncMock(
197+
return_value=json.dumps(request_body).encode('utf-8')
198+
)
199+
200+
async def mock_stream(*args, **kwargs):
201+
yield types_v03.SendStreamingMessageSuccessResponse(
202+
id='req-1',
203+
result=types_v03.Message(
204+
message_id='msg-2',
205+
role='agent',
206+
parts=[types_v03.TextPart(text='Update')],
207+
),
208+
)
209+
210+
rest_handler.handler03.on_subscribe_to_task = MagicMock(
211+
side_effect=mock_stream
212+
)
213+
214+
results = [
215+
chunk
216+
async for chunk in rest_handler.on_subscribe_to_task(
217+
mock_request, mock_context
218+
)
219+
]
220+
221+
assert len(results) == 1
222+
rest_handler.handler03.on_subscribe_to_task.assert_called_once()
223+
called_req = rest_handler.handler03.on_subscribe_to_task.call_args[0][0]
224+
assert called_req.params.id == 'task-1'
225+
226+
189227
@pytest.mark.anyio
190228
async def test_get_push_notification(rest_handler, mock_request, mock_context):
191229
mock_request.path_params = {'id': 'task-1', 'push_id': 'push-1'}

tests/compat/v0_3/test_rest_transport.py

Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
23
from unittest.mock import AsyncMock, MagicMock, patch
34

45
import httpx
@@ -232,14 +233,48 @@ async def mock_send_stream_request(*args, **kwargs):
232233
assert events[1] == StreamResponse(message=Message(message_id='msg-123'))
233234

234235

236+
def create_405_error():
237+
mock_response = MagicMock(spec=httpx.Response)
238+
mock_response.status_code = 405
239+
mock_response.json.return_value = {
240+
'type': 'MethodNotAllowed',
241+
'message': 'Method Not Allowed',
242+
}
243+
mock_request = MagicMock(spec=httpx.Request)
244+
mock_request.url = 'http://example.com/v1/tasks/task-123:subscribe'
245+
246+
status_error = httpx.HTTPStatusError(
247+
'405 Method Not Allowed', request=mock_request, response=mock_response
248+
)
249+
raise A2AClientError('HTTP Error 405') from status_error
250+
251+
252+
def create_500_error():
253+
mock_response = MagicMock(spec=httpx.Response)
254+
mock_response.status_code = 500
255+
mock_response.json.return_value = {
256+
'type': 'InternalError',
257+
'message': 'Internal Error',
258+
}
259+
mock_request = MagicMock(spec=httpx.Request)
260+
261+
status_error = httpx.HTTPStatusError(
262+
'500 Internal Error', request=mock_request, response=mock_response
263+
)
264+
raise A2AClientError('HTTP Error 500') from status_error
265+
266+
235267
@pytest.mark.asyncio
236-
async def test_compat_rest_transport_subscribe(transport):
237-
async def mock_send_stream_request(*args, **kwargs):
268+
async def test_compat_rest_transport_subscribe_post_works_no_retry(transport):
269+
"""Scenario: POST works, no retry."""
270+
271+
async def mock_stream(method, path, context=None):
272+
assert method == 'POST'
238273
task = Task(id='task-123')
239274
task.status.message.role = Role.ROLE_AGENT
240275
yield StreamResponse(task=task)
241276

242-
transport._send_stream_request = mock_send_stream_request
277+
transport._send_stream_request = mock_stream
243278

244279
req = SubscribeToTaskRequest(id='task-123')
245280
events = [event async for event in transport.subscribe(req)]
@@ -248,6 +283,101 @@ async def mock_send_stream_request(*args, **kwargs):
248283
expected_task = Task(id='task-123')
249284
expected_task.status.message.role = Role.ROLE_AGENT
250285
assert events[0] == StreamResponse(task=expected_task)
286+
assert transport._subscribe_method == 'POST'
287+
assert transport._subscribe_retry_attempted is False
288+
289+
290+
@pytest.mark.asyncio
291+
async def test_compat_rest_transport_subscribe_post_405_retry_get_success(
292+
transport,
293+
):
294+
"""Scenario: POST returns 405, automatic retry GET. Second call uses GET directly."""
295+
call_count = 0
296+
297+
async def mock_stream(method, path, context=None):
298+
nonlocal call_count
299+
call_count += 1
300+
if method == 'POST':
301+
create_405_error()
302+
if method == 'GET':
303+
task = Task(id='task-123')
304+
task.status.message.role = Role.ROLE_AGENT
305+
yield StreamResponse(task=task)
306+
307+
transport._send_stream_request = mock_stream
308+
309+
req = SubscribeToTaskRequest(id='task-123')
310+
events = [event async for event in transport.subscribe(req)]
311+
312+
assert len(events) == 1
313+
assert call_count == 2
314+
assert transport._subscribe_method == 'GET'
315+
assert transport._subscribe_retry_attempted is True
316+
317+
# Second call should use GET directly
318+
call_count = 0
319+
events = [event async for event in transport.subscribe(req)]
320+
assert len(events) == 1
321+
assert call_count == 1 # Only GET called
322+
assert transport._subscribe_method == 'GET'
323+
324+
325+
@pytest.mark.asyncio
326+
async def test_compat_rest_transport_subscribe_post_405_get_405_fails(
327+
transport,
328+
):
329+
"""Scenario: POST return 405, retry GET, return 405 - error. Second call is just POST."""
330+
call_count = 0
331+
332+
async def mock_stream(method, path, context=None):
333+
nonlocal call_count
334+
call_count += 1
335+
# To make it an async generator even when it raises
336+
if False:
337+
yield
338+
create_405_error()
339+
340+
transport._send_stream_request = mock_stream
341+
342+
req = SubscribeToTaskRequest(id='task-123')
343+
with pytest.raises(A2AClientError) as exc_info:
344+
[event async for event in transport.subscribe(req)]
345+
346+
assert '405' in str(exc_info.value)
347+
assert call_count == 2 # Tried POST then GET
348+
assert transport._subscribe_method == 'POST'
349+
assert transport._subscribe_retry_attempted is True
350+
351+
# Second call should try POST directly and fail without retry
352+
call_count = 0
353+
with pytest.raises(A2AClientError):
354+
[event async for event in transport.subscribe(req)]
355+
assert call_count == 1
356+
assert transport._subscribe_method == 'POST'
357+
358+
359+
@pytest.mark.asyncio
360+
async def test_compat_rest_transport_subscribe_post_500_no_retry(transport):
361+
"""Scenario: POST return 500, no automatic retry."""
362+
call_count = 0
363+
364+
async def mock_stream(method, path, context=None):
365+
nonlocal call_count
366+
call_count += 1
367+
if False:
368+
yield
369+
create_500_error()
370+
371+
transport._send_stream_request = mock_stream
372+
373+
req = SubscribeToTaskRequest(id='task-123')
374+
with pytest.raises(A2AClientError) as exc_info:
375+
[event async for event in transport.subscribe(req)]
376+
377+
assert '500' in str(exc_info.value)
378+
assert call_count == 1 # No retry on 500
379+
assert transport._subscribe_method == 'POST'
380+
assert transport._subscribe_retry_attempted is False
251381

252382

253383
def test_compat_rest_transport_handle_http_error(transport):

0 commit comments

Comments
 (0)