Skip to content

Commit c8b4d6f

Browse files
committed
addressed copilot review
1 parent 4d1b9c2 commit c8b4d6f

5 files changed

Lines changed: 127 additions & 19 deletions

File tree

examples/simple_async_subscriber.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import ezmsg.core as ez
44

5-
PORT = 12345
65
TOPIC = "/TEST"
76

87

examples/simple_subscriber.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import time
21
import ezmsg.core as ez
32

43
TOPIC = "/TEST"
@@ -10,6 +9,7 @@ def main(host: str = "127.0.0.1", port: int = 12345) -> None:
109

1110
def on_message(msg: str) -> None:
1211
# Uncomment if you want to witness backpressure!
12+
# import time
1313
# time.sleep(1.0)
1414
print(msg)
1515

src/ezmsg/core/graphserver.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class GraphServer(threading.Thread):
5656
_shutdown: threading.Event
5757

5858
_sock: socket.socket
59+
_address: Address | None
5960
_loop: asyncio.AbstractEventLoop
6061

6162
graph: DAG
@@ -78,11 +79,13 @@ def __init__(self, **kwargs) -> None:
7879
self.clients = {}
7980
self._client_tasks = {}
8081
self.shms = {}
82+
self._address = None
8183

8284
@property
8385
def address(self) -> Address:
84-
return Address(*self._sock.getsockname())
85-
86+
assert self._address is not None, "GraphServer not up yet"
87+
return self._address
88+
8689
def start(self, address: AddressType | None = None) -> None: # type: ignore[override]
8790
if address is not None:
8891
self._sock = create_socket(*address)
@@ -92,10 +95,13 @@ def start(self, address: AddressType | None = None) -> None: # type: ignore[ove
9295
)
9396
self._sock = create_socket(start_port=start_port)
9497

98+
# Cache address immediately to avoid touching a possibly-closed socket later.
99+
self._address = Address(*self._sock.getsockname())
100+
95101
self._loop = asyncio.new_event_loop()
96102
super().start()
97103
self._server_up.wait()
98-
logger.info(f'Started GraphServer at {address}')
104+
logger.info(f'Started GraphServer at {self.address}')
99105

100106
def stop(self) -> None:
101107
self._shutdown.set()

src/ezmsg/core/sync.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ def sync(self, timeout: float | None = None) -> None:
4444
_future_result(fut, timeout)
4545

4646
def pause(self) -> None:
47-
self._pub.pause()
47+
self._loop.call_soon_threadsafe(self._pub.pause)
4848

4949
def resume(self) -> None:
50-
self._pub.resume()
50+
self._loop.call_soon_threadsafe(self._pub.resume)
5151

5252
def close(self) -> None:
53-
self._pub.close()
53+
self._loop.call_soon_threadsafe(self._pub.close)
5454

5555
def wait_closed(self, timeout: float | None = None) -> None:
5656
fut = asyncio.run_coroutine_threadsafe(self._pub.wait_closed(), self._loop)
@@ -101,7 +101,7 @@ def recv_zero_copy(self, timeout: float | None = None) -> _SyncZeroCopy:
101101
return _SyncZeroCopy(self._sub, self._loop, timeout)
102102

103103
def close(self) -> None:
104-
self._sub.close()
104+
self._loop.call_soon_threadsafe(self._sub.close)
105105

106106
def wait_closed(self, timeout: float | None = None) -> None:
107107
fut = asyncio.run_coroutine_threadsafe(self._sub.wait_closed(), self._loop)
@@ -134,6 +134,14 @@ def graph_address(self) -> AddressType | None:
134134
return self._graph_context.graph_address
135135

136136
def __enter__(self) -> "SyncContext":
137+
138+
# SyncContext instances are single-use: they cannot be re-entered after shutdown.
139+
if self._closed:
140+
raise RuntimeError(
141+
"SyncContext instances cannot be reused after shutdown; "
142+
"create a new SyncContext instead."
143+
)
144+
137145
if self._loop_cm is not None:
138146
return self
139147

@@ -368,6 +376,22 @@ async def _recv_any(
368376
entries: Iterable[tuple[SyncSubscriber, Callable[[Any], None], bool]],
369377
timeout: float | None,
370378
) -> tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any] | None:
379+
async def _cleanup_result(result: Any) -> None:
380+
if isinstance(result, BaseException):
381+
return
382+
try:
383+
_, cm, _ = result
384+
except Exception:
385+
return
386+
try:
387+
await cm.__aexit__(None, None, None)
388+
except CacheMiss:
389+
logger.warning(
390+
"Cache miss while releasing message; publisher likely exited."
391+
)
392+
except Exception:
393+
logger.exception("Failed while releasing message backpressure")
394+
371395
async def _recv_entry(
372396
entry: tuple[SyncSubscriber, Callable[[Any], None], bool]
373397
) -> tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any]:
@@ -392,19 +416,17 @@ async def _recv_entry(
392416
task.cancel()
393417
except RuntimeError:
394418
pass
395-
await asyncio.gather(*pending, return_exceptions=True)
419+
pending_results = await asyncio.gather(
420+
*pending, return_exceptions=True
421+
)
422+
for result in pending_results:
423+
await _cleanup_result(result)
396424
return None
397425

398-
for task in pending:
399-
try:
400-
task.cancel()
401-
except RuntimeError:
402-
pass
403-
await asyncio.gather(*pending, return_exceptions=True)
404-
426+
winner_result = None
405427
for task in done:
406428
try:
407-
return task.result()
429+
result = task.result()
408430
except CacheMiss:
409431
# Likely stale notification after publisher exit; keep waiting.
410432
continue
@@ -413,6 +435,24 @@ async def _recv_entry(
413435
except Exception:
414436
logger.exception("Sync subscription receive failed")
415437
continue
438+
if winner_result is None:
439+
winner_result = result
440+
else:
441+
await _cleanup_result(result)
442+
443+
for task in pending:
444+
try:
445+
task.cancel()
446+
except RuntimeError:
447+
pass
448+
pending_results = await asyncio.gather(
449+
*pending, return_exceptions=True
450+
)
451+
for result in pending_results:
452+
await _cleanup_result(result)
453+
454+
if winner_result is not None:
455+
return winner_result
416456

417457
# Only CacheMiss/cancelled/error occurred; continue within timeout window.
418458
if deadline is not None and loop.time() >= deadline:

tests/test_sync_api.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import socket
23
import threading
34
import time
@@ -56,7 +57,8 @@ def on_msg(msg: str) -> None:
5657
t2 = time.monotonic()
5758

5859
assert t2 - t1 >= 0.15
59-
done.wait(2.0)
60+
assert done.wait(2.0), "Timed out waiting for messages to be received"
61+
assert received == ["one", "two"]
6062

6163
ctx.shutdown()
6264
spin_thread.join(timeout=1.0)
@@ -77,3 +79,64 @@ def __init__(self) -> None:
7779
entry = (DummySyncSub(), lambda _: None, True)
7880
result = await sync_mod._recv_any([entry], timeout=0.01)
7981
assert result is None
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_recv_any_cleans_up_non_winner_contexts():
86+
class DummyCM:
87+
def __init__(self, name: str, gate: asyncio.Event | None) -> None:
88+
self._name = name
89+
self._gate = gate
90+
self.exited = asyncio.Event()
91+
92+
async def __aenter__(self):
93+
if self._gate is not None:
94+
try:
95+
await self._gate.wait()
96+
except asyncio.CancelledError:
97+
await self._gate.wait()
98+
return f"msg-{self._name}"
99+
100+
async def __aexit__(self, exc_type, exc, tb):
101+
self.exited.set()
102+
103+
class DummySub:
104+
def __init__(self, cm: DummyCM) -> None:
105+
self._cm = cm
106+
107+
def recv_zero_copy(self):
108+
return self._cm
109+
110+
class DummySyncSub:
111+
def __init__(self, cm: DummyCM) -> None:
112+
self._sub = DummySub(cm)
113+
114+
release = asyncio.Event()
115+
cms: list[DummyCM] = []
116+
entries = []
117+
118+
def add_entry(name: str, gate: asyncio.Event | None) -> None:
119+
cm = DummyCM(name, gate)
120+
cms.append(cm)
121+
entries.append((DummySyncSub(cm), lambda _: None, True))
122+
123+
add_entry("fast", None)
124+
add_entry("slow1", release)
125+
add_entry("slow2", release)
126+
127+
async def _release_pending() -> None:
128+
await asyncio.sleep(0.01)
129+
release.set()
130+
131+
asyncio.create_task(_release_pending())
132+
133+
result = await sync_mod._recv_any(entries, timeout=0.2)
134+
assert result is not None
135+
136+
_, winner_cm, _ = result
137+
await winner_cm.__aexit__(None, None, None)
138+
139+
for cm in cms:
140+
if cm is winner_cm:
141+
continue
142+
assert cm.exited.is_set()

0 commit comments

Comments
 (0)