|
| 1 | +import asyncio |
| 2 | +from contextlib import contextmanager |
| 3 | +from uuid import uuid4 |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from ezmsg.core.subclient import Subscriber |
| 8 | +from ezmsg.core.netprotocol import Command, encode_str |
| 9 | +from ezmsg.core import channelmanager as channelmanager_module |
| 10 | +from ezmsg.core import subclient as subclient_module |
| 11 | + |
| 12 | + |
| 13 | +class DummyChannel: |
| 14 | + """Minimal Channel stand-in for Subscriber tests.""" |
| 15 | + |
| 16 | + def __init__(self): |
| 17 | + self.clients = {} |
| 18 | + self.closed = False |
| 19 | + self.waited = False |
| 20 | + self.topic = "test" |
| 21 | + self.num_buffers = 8 |
| 22 | + |
| 23 | + def register_client(self, client_id, queue, local_backpressure=None): |
| 24 | + self.clients[client_id] = queue |
| 25 | + |
| 26 | + def unregister_client(self, client_id): |
| 27 | + del self.clients[client_id] |
| 28 | + |
| 29 | + def close(self): |
| 30 | + self.closed = True |
| 31 | + |
| 32 | + async def wait_closed(self): |
| 33 | + self.waited = True |
| 34 | + |
| 35 | + @contextmanager |
| 36 | + def get(self, msg_id, client_id): |
| 37 | + yield f"msg-{msg_id}" |
| 38 | + |
| 39 | + |
| 40 | +class DummyWriter: |
| 41 | + """Minimal asyncio.StreamWriter stand-in.""" |
| 42 | + |
| 43 | + def __init__(self): |
| 44 | + self.buffer = [] |
| 45 | + self._closed = False |
| 46 | + |
| 47 | + def write(self, data): |
| 48 | + self.buffer.append(data) |
| 49 | + |
| 50 | + async def drain(self): |
| 51 | + pass |
| 52 | + |
| 53 | + def close(self): |
| 54 | + self._closed = True |
| 55 | + |
| 56 | + async def wait_closed(self): |
| 57 | + pass |
| 58 | + |
| 59 | + |
| 60 | +@pytest.mark.asyncio |
| 61 | +async def test_subscriber_unregisters_removed_publisher(monkeypatch): |
| 62 | + """A graph UPDATE that drops a publisher must remove it from _channels. |
| 63 | +
|
| 64 | + Before the fix (PR #218), Subscriber._cur_pubs was initialised to an |
| 65 | + empty set and never updated, so ``cur_pubs - pub_ids`` was always |
| 66 | + empty and stale publishers were never unregistered. |
| 67 | + """ |
| 68 | + async def fake_create(pub_id, address): |
| 69 | + return DummyChannel() |
| 70 | + |
| 71 | + monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create) |
| 72 | + monkeypatch.setattr( |
| 73 | + subclient_module, "CHANNELS", channelmanager_module.ChannelManager() |
| 74 | + ) |
| 75 | + |
| 76 | + pub_a = uuid4() |
| 77 | + |
| 78 | + reader = asyncio.StreamReader() |
| 79 | + writer = DummyWriter() |
| 80 | + |
| 81 | + sub = Subscriber( |
| 82 | + id=uuid4(), |
| 83 | + topic="test/topic", |
| 84 | + graph_address=None, |
| 85 | + _guard=Subscriber._SENTINEL, |
| 86 | + ) |
| 87 | + |
| 88 | + # Protocol sequence: |
| 89 | + # UPDATE [pub_a] -> subscriber registers pub_a |
| 90 | + # COMPLETE -> _initialized is set |
| 91 | + # UPDATE [] -> subscriber should unregister pub_a |
| 92 | + reader.feed_data( |
| 93 | + Command.UPDATE.value |
| 94 | + + encode_str(str(pub_a)) |
| 95 | + + Command.COMPLETE.value |
| 96 | + + Command.UPDATE.value |
| 97 | + + encode_str("") |
| 98 | + ) |
| 99 | + # No EOF yet — connection stays open after the second UPDATE. |
| 100 | + |
| 101 | + task = asyncio.create_task(sub._graph_connection(reader, writer)) |
| 102 | + |
| 103 | + # Wait until the subscriber has written two COMPLETE responses |
| 104 | + # (one per UPDATE), meaning both UPDATEs have been processed. |
| 105 | + for _ in range(200): |
| 106 | + if len(writer.buffer) >= 2: |
| 107 | + break |
| 108 | + await asyncio.sleep(0) |
| 109 | + else: |
| 110 | + pytest.fail("Subscriber did not process both UPDATEs") |
| 111 | + |
| 112 | + # Connection is still open — this is mid-session state, not post-cleanup. |
| 113 | + # Before the fix, pub_a would still be in _channels here. |
| 114 | + assert pub_a not in sub._channels, ( |
| 115 | + "Publisher should be removed from _channels " |
| 116 | + "when a graph UPDATE no longer includes it" |
| 117 | + ) |
| 118 | + |
| 119 | + reader.feed_eof() |
| 120 | + await task |
| 121 | + |
| 122 | + |
| 123 | +@pytest.mark.asyncio |
| 124 | +async def test_recv_zero_copy_skips_stale_notification(): |
| 125 | + """recv_zero_copy must skip notifications from unregistered publishers. |
| 126 | +
|
| 127 | + Before PR #218 there was no guard; a stale notification would cause |
| 128 | + a KeyError on ``self._channels[pub_id]``. |
| 129 | + """ |
| 130 | + sub = Subscriber( |
| 131 | + id=uuid4(), |
| 132 | + topic="test/topic", |
| 133 | + graph_address=None, |
| 134 | + _guard=Subscriber._SENTINEL, |
| 135 | + ) |
| 136 | + |
| 137 | + stale_pub = uuid4() |
| 138 | + valid_pub = uuid4() |
| 139 | + |
| 140 | + sub._channels[valid_pub] = DummyChannel() |
| 141 | + |
| 142 | + # Stale notification first, then a valid one. |
| 143 | + await sub._incoming.put((stale_pub, 0)) |
| 144 | + await sub._incoming.put((valid_pub, 1)) |
| 145 | + |
| 146 | + # Before the fix this would raise KeyError for stale_pub. |
| 147 | + async with sub.recv_zero_copy() as msg: |
| 148 | + assert msg == "msg-1" |
0 commit comments