Skip to content

Commit fa5bc59

Browse files
committed
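Fixed a bug where messages could be dropped: `spin_once` now processes every subscription callback that is ready within the timeout window instead of at most one, and `_recv_any` returns all ready results as a list.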
1 parent c8b4d6f commit fa5bc59

2 files changed

Lines changed: 81 additions & 74 deletions

File tree

src/ezmsg/core/sync.py

Lines changed: 45 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,10 @@ def spin(self, poll_interval: float = 0.1) -> None:
257257

258258
def spin_once(self, timeout: float | None = 0.0) -> bool:
259259
"""
260-
Process at most one subscription callback.
260+
Process any subscription callbacks ready within the timeout window.
261261
262-
:param timeout: Seconds to wait for a callback. Use None to block forever.
263-
:return: True if a callback was processed, False otherwise.
262+
:param timeout: Seconds to wait for callbacks. Use None to block forever.
263+
:return: True if any callback was processed, False otherwise.
264264
"""
265265
self._ensure_started()
266266
if self._shutdown_requested.is_set():
@@ -289,31 +289,32 @@ def spin_once(self, timeout: float | None = 0.0) -> bool:
289289
finally:
290290
if not keep_future:
291291
self._spin_future = None
292-
if result is None:
292+
if not result:
293293
return False
294+
processed = False
295+
for entry, cm, msg in result:
296+
_, callback, zero_copy = entry
294297

295-
entry, cm, msg = result
296-
_, callback, zero_copy = entry
297-
298-
try:
299-
if not zero_copy:
300-
msg = deepcopy(msg)
301-
callback(msg)
302-
except Exception:
303-
logger.exception("Unhandled exception in subscription callback")
304-
finally:
305-
exit_fut = asyncio.run_coroutine_threadsafe(
306-
cm.__aexit__(None, None, None), self._loop
307-
)
308298
try:
309-
_future_result(exit_fut, None)
310-
except CacheMiss:
311-
logger.warning(
312-
"Cache miss while releasing message; publisher likely exited."
313-
)
299+
if not zero_copy:
300+
msg = deepcopy(msg)
301+
callback(msg)
314302
except Exception:
315-
logger.exception("Failed while releasing message backpressure")
316-
return True
303+
logger.exception("Unhandled exception in subscription callback")
304+
finally:
305+
exit_fut = asyncio.run_coroutine_threadsafe(
306+
cm.__aexit__(None, None, None), self._loop
307+
)
308+
try:
309+
_future_result(exit_fut, None)
310+
except CacheMiss:
311+
logger.warning(
312+
"Cache miss while releasing message; publisher likely exited."
313+
)
314+
except Exception:
315+
logger.exception("Failed while releasing message backpressure")
316+
processed = True
317+
return processed
317318

318319
def shutdown(self) -> None:
319320
self._shutdown_requested.set()
@@ -368,30 +369,14 @@ def spin(context: SyncContext, poll_interval: float = 0.1) -> None:
368369

369370

370371
def spin_once(context: SyncContext, timeout: float | None = 0.0) -> bool:
371-
"""Process at most one subscription callback."""
372+
"""Process any subscription callbacks ready within the timeout window."""
372373
return context.spin_once(timeout=timeout)
373374

374375

375376
async def _recv_any(
376377
entries: Iterable[tuple[SyncSubscriber, Callable[[Any], None], bool]],
377378
timeout: float | None,
378-
) -> tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any] | None:
379-
async def _cleanup_result(result: Any) -> None:
380-
if isinstance(result, BaseException):
381-
return
382-
try:
383-
_, cm, _ = result
384-
except Exception:
385-
return
386-
try:
387-
await cm.__aexit__(None, None, None)
388-
except CacheMiss:
389-
logger.warning(
390-
"Cache miss while releasing message; publisher likely exited."
391-
)
392-
except Exception:
393-
logger.exception("Failed while releasing message backpressure")
394-
379+
) -> list[tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any]]:
395380
async def _recv_entry(
396381
entry: tuple[SyncSubscriber, Callable[[Any], None], bool]
397382
) -> tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any]:
@@ -410,20 +395,9 @@ async def _recv_entry(
410395
done, pending = await asyncio.wait(
411396
tasks, timeout=remaining, return_when=asyncio.FIRST_COMPLETED
412397
)
413-
if not done:
414-
for task in pending:
415-
try:
416-
task.cancel()
417-
except RuntimeError:
418-
pass
419-
pending_results = await asyncio.gather(
420-
*pending, return_exceptions=True
421-
)
422-
for result in pending_results:
423-
await _cleanup_result(result)
424-
return None
425-
426-
winner_result = None
398+
results: list[
399+
tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any]
400+
] = []
427401
for task in done:
428402
try:
429403
result = task.result()
@@ -435,28 +409,32 @@ async def _recv_entry(
435409
except Exception:
436410
logger.exception("Sync subscription receive failed")
437411
continue
438-
if winner_result is None:
439-
winner_result = result
440-
else:
441-
await _cleanup_result(result)
412+
results.append(result)
442413

443414
for task in pending:
444415
try:
445416
task.cancel()
446417
except RuntimeError:
447418
pass
448-
pending_results = await asyncio.gather(
449-
*pending, return_exceptions=True
450-
)
419+
pending_results = await asyncio.gather(*pending, return_exceptions=True)
451420
for result in pending_results:
452-
await _cleanup_result(result)
421+
if isinstance(result, CacheMiss):
422+
continue
423+
if isinstance(result, asyncio.CancelledError):
424+
continue
425+
if isinstance(result, BaseException):
426+
logger.exception(
427+
"Sync subscription receive failed", exc_info=result
428+
)
429+
continue
430+
results.append(result)
453431

454-
if winner_result is not None:
455-
return winner_result
432+
if results:
433+
return results
456434

457435
# Only CacheMiss/cancelled/error occurred; continue within timeout window.
458436
if deadline is not None and loop.time() >= deadline:
459-
return None
437+
return []
460438
finally:
461439
for task in tasks:
462440
if not task.done():

tests/test_sync_api.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,34 @@ def on_msg(msg: str) -> None:
6464
spin_thread.join(timeout=1.0)
6565

6666

67+
def test_spin_once_processes_all_ready_callbacks():
68+
host = "127.0.0.1"
69+
port = _free_port()
70+
71+
with ez.sync.init((host, port), auto_start=True) as ctx:
72+
received: list[tuple[str, str]] = []
73+
74+
def on_a(msg: str) -> None:
75+
received.append(("A", msg))
76+
77+
def on_b(msg: str) -> None:
78+
received.append(("B", msg))
79+
80+
ctx.create_subscription("/A", callback=on_a, zero_copy=True)
81+
ctx.create_subscription("/B", callback=on_b, zero_copy=True)
82+
83+
pub_a = ctx.create_publisher("/A", num_buffers=1, force_tcp=True)
84+
pub_b = ctx.create_publisher("/B", num_buffers=1, force_tcp=True)
85+
86+
time.sleep(0.05)
87+
pub_a.publish("one")
88+
pub_b.publish("two")
89+
time.sleep(0.05)
90+
91+
assert ctx.spin_once(timeout=0.2) is True
92+
assert set(received) == {("A", "one"), ("B", "two")}
93+
94+
6795
@pytest.mark.asyncio
6896
async def test_recv_any_cachemiss_does_not_raise():
6997
class DummySub:
@@ -78,11 +106,11 @@ def __init__(self) -> None:
78106

79107
entry = (DummySyncSub(), lambda _: None, True)
80108
result = await sync_mod._recv_any([entry], timeout=0.01)
81-
assert result is None
109+
assert result == []
82110

83111

84112
@pytest.mark.asyncio
85-
async def test_recv_any_cleans_up_non_winner_contexts():
113+
async def test_recv_any_returns_all_ready_contexts():
86114
class DummyCM:
87115
def __init__(self, name: str, gate: asyncio.Event | None) -> None:
88116
self._name = name
@@ -131,12 +159,13 @@ async def _release_pending() -> None:
131159
asyncio.create_task(_release_pending())
132160

133161
result = await sync_mod._recv_any(entries, timeout=0.2)
134-
assert result is not None
162+
assert len(result) == 3
163+
164+
returned_cms = {cm for _, cm, _ in result}
165+
assert returned_cms == set(cms)
135166

136-
_, winner_cm, _ = result
137-
await winner_cm.__aexit__(None, None, None)
167+
for _, cm, _ in result:
168+
await cm.__aexit__(None, None, None)
138169

139170
for cm in cms:
140-
if cm is winner_cm:
141-
continue
142171
assert cm.exited.is_set()

0 commit comments

Comments
 (0)