Skip to content

Commit bbea28c

Browse files
authored
Merge pull request #218 from ezmsg-org/217-subscriber-never-updates-its-set-of-current-publishers-after-processing-update-preventing-disconnection-of-old-publishers
hot fix 217 - recalculate Subscriber's pub keys on demand
2 parents 5c17e7c + 911d7bb commit bbea28c

2 files changed

Lines changed: 158 additions & 5 deletions

File tree

src/ezmsg/core/subclient.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class Subscriber:
3939

4040
_graph_address: AddressType | None
4141
_graph_task: asyncio.Task[None]
42-
_cur_pubs: set[UUID]
4342
_incoming: NotificationQueue
4443

4544
# FIXME: This event allows Subscriber.create to block until
@@ -128,7 +127,6 @@ def __init__(
128127
self.leaky = leaky
129128
self._graph_address = graph_address
130129

131-
self._cur_pubs = set()
132130
self._channels = dict()
133131
if self.leaky:
134132
self._incoming = LeakyQueue(
@@ -223,8 +221,10 @@ async def _graph_connection(
223221
pub_ids = (
224222
set([UUID(id) for id in update.split(",")]) if update else set()
225223
)
224+
cur_pubs = set(self._channels.keys())
226225

227-
for pub_id in set(pub_ids - self._cur_pubs):
226+
# Register new channels
227+
for pub_id in set(pub_ids - cur_pubs):
228228
channel = await CHANNELS.register(
229229
pub_id, self.id, self._incoming, self._graph_address
230230
)
@@ -239,7 +239,8 @@ async def _graph_connection(
239239

240240
self._channels[pub_id] = channel
241241

242-
for pub_id in set(self._cur_pubs - pub_ids):
242+
# Unregister expired channels
243+
for pub_id in set(cur_pubs - pub_ids):
243244
await CHANNELS.unregister(pub_id, self.id, self._graph_address)
244245
del self._channels[pub_id]
245246

@@ -288,7 +289,11 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]:
288289
:return: Context manager yielding the received message.
289290
:rtype: collections.abc.AsyncGenerator[typing.Any, None]
290291
"""
291-
pub_id, msg_id = await self._incoming.get()
292+
while True:
293+
pub_id, msg_id = await self._incoming.get()
294+
if pub_id in self._channels:
295+
break
296+
# Stale notification from an unregistered publisher — skip.
292297

293298
with self._channels[pub_id].get(msg_id, self.id) as msg:
294299
yield msg

tests/test_subclient.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import asyncio
2+
from contextlib import contextmanager
3+
from uuid import uuid4
4+
5+
import pytest
6+
7+
from ezmsg.core.subclient import Subscriber
8+
from ezmsg.core.netprotocol import Command, encode_str
9+
from ezmsg.core import channelmanager as channelmanager_module
10+
from ezmsg.core import subclient as subclient_module
11+
12+
13+
class DummyChannel:
    """Minimal Channel stand-in for Subscriber tests.

    Keeps only in-memory bookkeeping: which clients are registered and
    whether ``close`` / ``wait_closed`` have been called.
    """

    def __init__(self):
        # client_id -> queue passed to register_client
        self.clients = {}
        # Flipped by close() and wait_closed() respectively so tests can
        # observe that shutdown was requested.
        self.closed = False
        self.waited = False
        self.topic = "test"
        self.num_buffers = 8

    def register_client(self, client_id, queue, local_backpressure=None):
        """Remember *queue* for *client_id*.

        ``local_backpressure`` is accepted but ignored by this stub.
        """
        self.clients[client_id] = queue

    def unregister_client(self, client_id):
        """Forget *client_id*; raises ``KeyError`` if it was never registered."""
        del self.clients[client_id]

    def close(self):
        """Record that the channel was asked to close."""
        self.closed = True

    async def wait_closed(self):
        """Record that the caller awaited channel shutdown."""
        self.waited = True

    @contextmanager
    def get(self, msg_id, client_id):
        """Yield a deterministic placeholder message for *msg_id*.

        ``client_id`` is unused; the yielded value is ``f"msg-{msg_id}"``.
        """
        yield f"msg-{msg_id}"
38+
39+
40+
class DummyWriter:
    """Minimal asyncio.StreamWriter stand-in.

    Collects everything passed to ``write`` in ``buffer`` so tests can
    inspect what was sent.
    """

    def __init__(self):
        # Each write() call appends one entry, in call order.
        self.buffer = []
        self._closed = False

    def write(self, data):
        """Append *data* to ``buffer``; nothing is actually transmitted."""
        self.buffer.append(data)

    async def drain(self):
        """No-op: there is no transport to flush."""
        pass

    def close(self):
        """Mark the writer as closed."""
        self._closed = True

    async def wait_closed(self):
        """No-op: closing completes immediately."""
        pass
58+
59+
60+
@pytest.mark.asyncio
async def test_subscriber_unregisters_removed_publisher(monkeypatch):
    """A graph UPDATE that drops a publisher must remove it from _channels.

    Before the fix (PR #218), Subscriber._cur_pubs was initialised to an
    empty set and never updated, so ``self._cur_pubs - pub_ids`` was always
    empty and stale publishers were never unregistered.
    """
    async def fake_create(pub_id, address):
        # Replace real channel creation with the in-memory stub.
        return DummyChannel()

    monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create)
    # Give the subscriber module a fresh ChannelManager for this test.
    monkeypatch.setattr(
        subclient_module, "CHANNELS", channelmanager_module.ChannelManager()
    )

    pub_a = uuid4()

    reader = asyncio.StreamReader()
    writer = DummyWriter()

    sub = Subscriber(
        id=uuid4(),
        topic="test/topic",
        graph_address=None,
        _guard=Subscriber._SENTINEL,
    )

    # Protocol sequence:
    #   UPDATE [pub_a]   -> subscriber registers pub_a
    #   COMPLETE         -> _initialized is set
    #   UPDATE []        -> subscriber should unregister pub_a
    reader.feed_data(
        Command.UPDATE.value
        + encode_str(str(pub_a))
        + Command.COMPLETE.value
        + Command.UPDATE.value
        + encode_str("")
    )
    # No EOF yet — connection stays open after the second UPDATE.

    task = asyncio.create_task(sub._graph_connection(reader, writer))

    # Wait until the subscriber has written two COMPLETE responses
    # (one per UPDATE), meaning both UPDATEs have been processed.
    # The ``else`` belongs to the ``for``: it runs only if the loop
    # finished without hitting ``break``.
    for _ in range(200):
        if len(writer.buffer) >= 2:
            break
        await asyncio.sleep(0)
    else:
        pytest.fail("Subscriber did not process both UPDATEs")

    # Connection is still open — this is mid-session state, not post-cleanup.
    # Before the fix, pub_a would still be in _channels here.
    assert pub_a not in sub._channels, (
        "Publisher should be removed from _channels "
        "when a graph UPDATE no longer includes it"
    )

    # Signal EOF so the graph-connection task can finish, then wait for it.
    reader.feed_eof()
    await task
121+
122+
123+
@pytest.mark.asyncio
async def test_recv_zero_copy_skips_stale_notification():
    """recv_zero_copy must skip notifications from unregistered publishers.

    Before PR #218 there was no guard; a stale notification would cause
    a KeyError on ``self._channels[pub_id]``.
    """
    sub = Subscriber(
        id=uuid4(),
        topic="test/topic",
        graph_address=None,
        _guard=Subscriber._SENTINEL,
    )

    stale_pub = uuid4()
    valid_pub = uuid4()

    # Only valid_pub gets a channel; stale_pub is deliberately left out.
    sub._channels[valid_pub] = DummyChannel()

    # Stale notification first, then a valid one.
    await sub._incoming.put((stale_pub, 0))
    await sub._incoming.put((valid_pub, 1))

    # Before the fix this would raise KeyError for stale_pub.
    async with sub.recv_zero_copy() as msg:
        assert msg == "msg-1"

0 commit comments

Comments
 (0)