Skip to content

Commit 81358a3

Browse files
committed
Catch QueueShutDown in event_consumer.py.
1 parent 1ac0a1c commit 81358a3

2 files changed

Lines changed: 41 additions & 28 deletions

File tree

src/a2a/server/events/event_consumer.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import asyncio
22
import logging
3-
import sys
43

54
from collections.abc import AsyncGenerator
65

76
from pydantic import ValidationError
87

9-
from a2a.server.events.event_queue import Event, EventQueue
8+
from a2a.server.events.event_queue import Event, EventQueue, QueueShutDown
109
from a2a.types.a2a_pb2 import (
1110
Message,
1211
Task,
@@ -17,13 +16,6 @@
1716
from a2a.utils.telemetry import SpanKind, trace_class
1817

1918

20-
# This is an alias to the exception for closed queue
21-
QueueClosed: type[Exception] = asyncio.QueueEmpty
22-
23-
# When using python 3.13 or higher, the closed queue signal is QueueShutdown
24-
if sys.version_info >= (3, 13):
25-
QueueClosed = asyncio.QueueShutDown
26-
2719
logger = logging.getLogger(__name__)
2820

2921

@@ -143,7 +135,7 @@ async def consume_all(self) -> AsyncGenerator[Event]:
143135
except asyncio.TimeoutError: # pyright: ignore [reportUnusedExcept]
144136
# This class was made an alias of built-in TimeoutError after 3.11
145137
continue
146-
except (QueueClosed, asyncio.QueueEmpty):
138+
except (QueueShutDown, asyncio.QueueEmpty):
147139
# Confirm that the queue is closed, e.g. we aren't on
148140
# python 3.12 and get a queue empty error on an open queue
149141
if self.queue.is_closed():

tests/server/events/test_event_consumer.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
from pydantic import ValidationError
99

10-
from a2a.server.events.event_consumer import EventConsumer, QueueClosed
10+
from a2a.server.events.event_consumer import EventConsumer
11+
from a2a.server.events.event_queue import QueueShutDown
1112
from a2a.server.events.event_queue import EventQueue
1213
from a2a.server.jsonrpc_models import JSONRPCError
1314
from a2a.types import (
@@ -256,9 +257,9 @@ async def test_consume_all_raises_stored_exception(
256257
async def test_consume_all_stops_on_queue_closed_and_confirmed_closed(
257258
event_consumer: EventConsumer, mock_event_queue: AsyncMock
258259
):
259-
"""Test consume_all stops if QueueClosed is raised and queue.is_closed() is True."""
260-
# Simulate the queue raising QueueClosed (which is asyncio.QueueEmpty or QueueShutdown)
261-
mock_event_queue.dequeue_event.side_effect = QueueClosed(
260+
"""Test consume_all stops if QueueShutDown is raised and queue.is_closed() is True."""
261+
# Simulate the queue raising QueueShutDown (which is asyncio.QueueEmpty or QueueShutdown)
262+
mock_event_queue.dequeue_event.side_effect = QueueShutDown(
262263
'Queue is empty/closed'
263264
)
264265
# Simulate the queue confirming it's closed
@@ -270,7 +271,7 @@ async def test_consume_all_stops_on_queue_closed_and_confirmed_closed(
270271

271272
assert (
272273
len(consumed_events) == 0
273-
) # No events should be consumed as it breaks on QueueClosed
274+
) # No events should be consumed as it breaks on QueueShutDown
274275
mock_event_queue.dequeue_event.assert_called_once() # Should attempt to dequeue once
275276
mock_event_queue.is_closed.assert_called_once() # Should check if closed
276277

@@ -279,28 +280,28 @@ async def test_consume_all_stops_on_queue_closed_and_confirmed_closed(
279280
async def test_consume_all_continues_on_queue_empty_if_not_really_closed(
280281
event_consumer: EventConsumer, mock_event_queue: AsyncMock
281282
):
282-
"""Test that QueueClosed with is_closed=False allows loop to continue via timeout."""
283+
"""Test that QueueShutDown with is_closed=False allows loop to continue via timeout."""
283284
final_event = create_sample_message(message_id='final_event_id')
284285

285286
# Setup dequeue_event behavior:
286-
# 1. Raise QueueClosed (e.g., asyncio.QueueEmpty)
287+
# 1. Raise QueueShutDown (e.g., asyncio.QueueEmpty)
287288
# 2. Return the final_event
288-
# 3. Raise QueueClosed again (to terminate after final_event)
289+
# 3. Raise QueueShutDown again (to terminate after final_event)
289290
dequeue_effects = [
290-
QueueClosed('Simulated temporary empty'),
291+
QueueShutDown('Simulated temporary empty'),
291292
final_event,
292-
QueueClosed('Queue closed after final event'),
293+
QueueShutDown('Queue closed after final event'),
293294
]
294295
mock_event_queue.dequeue_event.side_effect = dequeue_effects
295296

296297
# Setup is_closed behavior:
297-
# 1. False when QueueClosed is first raised (so loop doesn't break)
298-
# 2. True after final_event is processed and QueueClosed is raised again
298+
# 1. False when QueueShutDown is first raised (so loop doesn't break)
299+
# 2. True after final_event is processed and QueueShutDown is raised again
299300
is_closed_effects = [False, True]
300301
mock_event_queue.is_closed.side_effect = is_closed_effects
301302

302303
# Patch asyncio.wait_for used inside consume_all
303-
# The goal is that the first QueueClosed leads to a TimeoutError inside consume_all,
304+
# The goal is that the first QueueShutDown leads to a TimeoutError inside consume_all,
304305
# the loop continues, and then the final_event is fetched.
305306

306307
# To reliably test the timeout behavior within consume_all, we adjust the consumer's
@@ -315,15 +316,15 @@ async def test_consume_all_continues_on_queue_empty_if_not_really_closed(
315316
assert consumed_events[0] == final_event
316317

317318
# Dequeue attempts:
318-
# 1. Raises QueueClosed (is_closed=False, leads to TimeoutError, loop continues)
319+
# 1. Raises QueueShutDown (is_closed=False, leads to TimeoutError, loop continues)
319320
# 2. Returns final_event (which is a Message, causing consume_all to break)
320321
assert (
321322
mock_event_queue.dequeue_event.call_count == 2
322323
) # Only two calls needed
323324

324325
# is_closed calls:
325-
# 1. After first QueueClosed (returns False)
326-
# The second QueueClosed is not reached because Message breaks the loop.
326+
# 1. After first QueueShutDown (returns False)
327+
# The second QueueShutDown is not reached because Message breaks the loop.
327328
assert mock_event_queue.is_closed.call_count == 1
328329

329330

@@ -332,13 +333,13 @@ async def test_consume_all_handles_queue_empty_when_closed_python_version_agnost
332333
event_consumer: EventConsumer, mock_event_queue: AsyncMock, monkeypatch
333334
):
334335
"""Ensure consume_all stops with no events when queue is closed and dequeue_event raises asyncio.QueueEmpty (Python version-agnostic)."""
335-
# Make QueueClosed a distinct exception (not QueueEmpty) to emulate py3.13 semantics
336+
# Make QueueShutDown a distinct exception (not QueueEmpty) to emulate py3.13 semantics
336337
from a2a.server.events import event_consumer as ec
337338

338339
class QueueShutDown(Exception):
339340
pass
340341

341-
monkeypatch.setattr(ec, 'QueueClosed', QueueShutDown, raising=True)
342+
monkeypatch.setattr(ec, 'QueueShutDown', QueueShutDown, raising=True)
342343

343344
# Simulate queue reporting closed while dequeue raises QueueEmpty
344345
mock_event_queue.dequeue_event.side_effect = asyncio.QueueEmpty(
@@ -538,3 +539,23 @@ async def test_background_close_deadlocks_on_trailing_events() -> None:
538539
await asyncio.wait_for(consumer._close_task, timeout=0.1)
539540
except asyncio.TimeoutError:
540541
pytest.fail('Background close task deadlocked on trailing events!')
542+
543+
544+
@pytest.mark.asyncio
545+
async def test_consume_all_handles_actual_queue_shutdown(
546+
event_consumer: EventConsumer, mock_event_queue: AsyncMock
547+
):
548+
"""Ensure consume_all stops when queue is closed and dequeue_event raises the actual QueueShutDown from event_queue."""
549+
from a2a.server.events.event_queue import QueueShutDown
550+
551+
mock_event_queue.dequeue_event.side_effect = QueueShutDown(
552+
'Queue is closed'
553+
)
554+
mock_event_queue.is_closed.return_value = True
555+
556+
consumed_events = []
557+
# This should exit cleanly because consume_all correctly catches the QueueShutDown exception.
558+
async for event in event_consumer.consume_all():
559+
consumed_events.append(event)
560+
561+
assert len(consumed_events) == 0

0 commit comments

Comments
 (0)