Skip to content

Commit 0174559

Browse files
authored
Merge pull request #216 from ezmsg-org/fix/209-zero-copy
Deprecate `zero-copy` keyword argument in `@ez.subscriber`
2 parents f7705b4 + bbea28c commit 0174559

8 files changed

Lines changed: 187 additions & 22 deletions

File tree

src/ezmsg/core/backendprocess.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any
2020

2121
from .stream import Stream, InputStream, OutputStream
22-
from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR
22+
from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR
2323
from .messagechannel import LeakyQueue
2424

2525
from .graphcontext import GraphContext
@@ -374,8 +374,6 @@ async def wrapped_task(msg: Any = None) -> None:
374374
result = call_fn(msg)
375375
if inspect.isasyncgen(result):
376376
async for stream, obj in result:
377-
if obj and getattr(task, ZERO_COPY_ATTR, False) and obj is msg:
378-
obj = deepcopy(obj)
379377
await pub_fn(stream, obj)
380378

381379
elif asyncio.iscoroutine(result):

src/ezmsg/core/subclient.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class Subscriber:
3939

4040
_graph_address: AddressType | None
4141
_graph_task: asyncio.Task[None]
42-
_cur_pubs: set[UUID]
4342
_incoming: NotificationQueue
4443

4544
# FIXME: This event allows Subscriber.create to block until
@@ -128,7 +127,6 @@ def __init__(
128127
self.leaky = leaky
129128
self._graph_address = graph_address
130129

131-
self._cur_pubs = set()
132130
self._channels = dict()
133131
if self.leaky:
134132
self._incoming = LeakyQueue(
@@ -223,8 +221,10 @@ async def _graph_connection(
223221
pub_ids = (
224222
set([UUID(id) for id in update.split(",")]) if update else set()
225223
)
224+
cur_pubs = set(self._channels.keys())
226225

227-
for pub_id in set(pub_ids - self._cur_pubs):
226+
# Register new channels
227+
for pub_id in set(pub_ids - cur_pubs):
228228
channel = await CHANNELS.register(
229229
pub_id, self.id, self._incoming, self._graph_address
230230
)
@@ -239,7 +239,8 @@ async def _graph_connection(
239239

240240
self._channels[pub_id] = channel
241241

242-
for pub_id in set(self._cur_pubs - pub_ids):
242+
# Unregister expired channels
243+
for pub_id in set(cur_pubs - pub_ids):
243244
await CHANNELS.unregister(pub_id, self.id, self._graph_address)
244245
del self._channels[pub_id]
245246

@@ -288,7 +289,11 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]:
288289
:return: Context manager yielding the received message.
289290
:rtype: collections.abc.AsyncGenerator[typing.Any, None]
290291
"""
291-
pub_id, msg_id = await self._incoming.get()
292+
while True:
293+
pub_id, msg_id = await self._incoming.get()
294+
if pub_id in self._channels:
295+
break
296+
# Stale notification from an unregistered publisher — skip.
292297

293298
with self._channels[pub_id].get(msg_id, self.id) as msg:
294299
yield msg

src/ezmsg/core/unit.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import time
22
import inspect
33
import functools
4+
import warnings
45
from .stream import InputStream, OutputStream
56
from .component import ComponentMeta, Component
67
from .settings import Settings
@@ -23,6 +24,8 @@
2324
LEAKY_ATTR = "__ez_leaky__"
2425
MAX_QUEUE_ATTR = "__ez_max_queue__"
2526

27+
_ZERO_COPY_SENTINEL = object()
28+
2629

2730
class UnitMeta(ComponentMeta):
2831
def __init__(
@@ -162,17 +165,18 @@ def pub_factory(func):
162165
return pub_factory
163166

164167

165-
def subscriber(stream: InputStream, zero_copy: bool = False):
168+
def subscriber(stream: InputStream, zero_copy: Any = _ZERO_COPY_SENTINEL):
166169
"""
167170
A decorator for a method that subscribes to a stream in the task/messaging thread.
168171
169172
An async function will run once per message received from the :obj:`InputStream`
170173
it subscribes to. A function can have both ``@subscriber`` and ``@publisher`` decorators.
171174
175+
The ``zero_copy`` argument is deprecated and ignored. Subscribers always receive
176+
zero-copy messages, so callers can omit it.
177+
172178
:param stream: The input stream to subscribe to
173179
:type stream: InputStream
174-
:param zero_copy: Whether to use zero-copy message passing (default: False)
175-
:type zero_copy: bool
176180
:return: Decorated function that can subscribe to the stream
177181
:rtype: collections.abc.Callable
178182
:raises ValueError: If stream is not an InputStream
@@ -183,20 +187,30 @@ def subscriber(stream: InputStream, zero_copy: bool = False):
183187
184188
INPUT = ez.InputStream(Message)
185189
186-
@subscriber(INPUT)
187-
async def print_message(self, message: Message) -> None:
188-
print(message)
190+
@subscriber(INPUT)
191+
async def print_message(self, message: Message) -> None:
192+
print(message)
189193
"""
190194

191195
if not isinstance(stream, InputStream):
192196
raise ValueError(f"Cannot subscribe to object of type {type(stream)}")
193197

198+
if zero_copy is not _ZERO_COPY_SENTINEL:
199+
warnings.warn(
200+
"The `zero_copy` argument to @subscriber is deprecated and ignored. "
201+
"Zero-copy behavior is now determined by the InputStream's `leaky` property "
202+
"(non-leaky subscribers use zero-copy; leaky subscribers receive deep-copied "
203+
"messages). Remove any explicit `zero_copy=...` usage.",
204+
DeprecationWarning,
205+
stacklevel=2,
206+
)
207+
194208
def sub_factory(func):
195209
subscribed_streams: InputStream | None = getattr(func, SUBSCRIBES_ATTR, None)
196210
if subscribed_streams is not None:
197211
raise Exception(f"{func} cannot subscribe to more than one stream")
198212
setattr(func, SUBSCRIBES_ATTR, stream)
199-
setattr(func, ZERO_COPY_ATTR, zero_copy)
213+
setattr(func, ZERO_COPY_ATTR, True)
200214
setattr(func, LEAKY_ATTR, stream.leaky)
201215
setattr(func, MAX_QUEUE_ATTR, stream.max_queue)
202216
return task(func)

src/ezmsg/util/debuglog.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class DebugLog(ez.Unit):
3232
OUTPUT = ez.OutputStream(Any)
3333
"""Send messages back out to continue through the graph."""
3434

35-
@ez.subscriber(INPUT, zero_copy=True)
35+
@ez.subscriber(INPUT)
3636
@ez.publisher(OUTPUT)
3737
async def log(self, msg: Any) -> AsyncGenerator:
3838
logstr = f"{self.SETTINGS.name} - {msg=}"

src/ezmsg/util/messages/key.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ async def on_settings(self, msg: ez.Settings) -> None:
8383
self.apply_settings(msg)
8484
self.construct_generator()
8585

86-
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
86+
@ez.subscriber(INPUT_SIGNAL)
8787
@ez.publisher(OUTPUT_SIGNAL)
8888
async def on_message(self, message: AxisArray) -> AsyncGenerator:
8989
"""
@@ -125,7 +125,7 @@ class FilterOnKey(ez.Unit):
125125
INPUT_SIGNAL = ez.InputStream(AxisArray)
126126
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
127127

128-
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
128+
@ez.subscriber(INPUT_SIGNAL)
129129
@ez.publisher(OUTPUT_SIGNAL)
130130
async def on_message(self, message: AxisArray) -> AsyncGenerator:
131131
"""

src/ezmsg/util/messages/modify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def on_settings(self, msg: ez.Settings) -> None:
114114
self.apply_settings(msg)
115115
self._transformer = ModifyAxisTransformer(self.SETTINGS)
116116

117-
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
117+
@ez.subscriber(INPUT_SIGNAL)
118118
@ez.publisher(OUTPUT_SIGNAL)
119119
async def on_message(self, message: AxisArray) -> AsyncGenerator:
120120
ret = self._transformer(message)

src/ezmsg/util/perf/impl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class LoadTestRelay(ez.Unit):
130130
INPUT = ez.InputStream(LoadTestSample)
131131
OUTPUT = ez.OutputStream(LoadTestSample)
132132

133-
@ez.subscriber(INPUT, zero_copy=True)
133+
@ez.subscriber(INPUT)
134134
@ez.publisher(OUTPUT)
135135
async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator:
136136
yield self.OUTPUT, msg
@@ -152,7 +152,7 @@ class LoadTestReceiver(ez.Unit):
152152
async def initialize(self) -> None:
153153
ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})")
154154

155-
@ez.subscriber(INPUT, zero_copy=True)
155+
@ez.subscriber(INPUT)
156156
async def receive(self, sample: LoadTestSample) -> None:
157157
counter = self.STATE.counters.get(sample.key, -1)
158158
if sample.counter != counter + 1:
@@ -166,7 +166,7 @@ async def receive(self, sample: LoadTestSample) -> None:
166166
class LoadTestSink(LoadTestReceiver):
167167
INPUT = ez.InputStream(LoadTestSample)
168168

169-
@ez.subscriber(INPUT, zero_copy=True)
169+
@ez.subscriber(INPUT)
170170
async def receive(self, sample: LoadTestSample) -> None:
171171
await super().receive(sample)
172172
if len(self.STATE.received_data) == self.SETTINGS.num_msgs:

tests/test_subclient.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import asyncio
2+
from contextlib import contextmanager
3+
from uuid import uuid4
4+
5+
import pytest
6+
7+
from ezmsg.core.subclient import Subscriber
8+
from ezmsg.core.netprotocol import Command, encode_str
9+
from ezmsg.core import channelmanager as channelmanager_module
10+
from ezmsg.core import subclient as subclient_module
11+
12+
13+
class DummyChannel:
14+
"""Minimal Channel stand-in for Subscriber tests."""
15+
16+
def __init__(self):
17+
self.clients = {}
18+
self.closed = False
19+
self.waited = False
20+
self.topic = "test"
21+
self.num_buffers = 8
22+
23+
def register_client(self, client_id, queue, local_backpressure=None):
24+
self.clients[client_id] = queue
25+
26+
def unregister_client(self, client_id):
27+
del self.clients[client_id]
28+
29+
def close(self):
30+
self.closed = True
31+
32+
async def wait_closed(self):
33+
self.waited = True
34+
35+
@contextmanager
36+
def get(self, msg_id, client_id):
37+
yield f"msg-{msg_id}"
38+
39+
40+
class DummyWriter:
41+
"""Minimal asyncio.StreamWriter stand-in."""
42+
43+
def __init__(self):
44+
self.buffer = []
45+
self._closed = False
46+
47+
def write(self, data):
48+
self.buffer.append(data)
49+
50+
async def drain(self):
51+
pass
52+
53+
def close(self):
54+
self._closed = True
55+
56+
async def wait_closed(self):
57+
pass
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_subscriber_unregisters_removed_publisher(monkeypatch):
62+
"""A graph UPDATE that drops a publisher must remove it from _channels.
63+
64+
Before the fix (PR #218), Subscriber._cur_pubs was initialised to an
65+
empty set and never updated, so ``cur_pubs - pub_ids`` was always
66+
empty and stale publishers were never unregistered.
67+
"""
68+
async def fake_create(pub_id, address):
69+
return DummyChannel()
70+
71+
monkeypatch.setattr(channelmanager_module.Channel, "create", fake_create)
72+
monkeypatch.setattr(
73+
subclient_module, "CHANNELS", channelmanager_module.ChannelManager()
74+
)
75+
76+
pub_a = uuid4()
77+
78+
reader = asyncio.StreamReader()
79+
writer = DummyWriter()
80+
81+
sub = Subscriber(
82+
id=uuid4(),
83+
topic="test/topic",
84+
graph_address=None,
85+
_guard=Subscriber._SENTINEL,
86+
)
87+
88+
# Protocol sequence:
89+
# UPDATE [pub_a] -> subscriber registers pub_a
90+
# COMPLETE -> _initialized is set
91+
# UPDATE [] -> subscriber should unregister pub_a
92+
reader.feed_data(
93+
Command.UPDATE.value
94+
+ encode_str(str(pub_a))
95+
+ Command.COMPLETE.value
96+
+ Command.UPDATE.value
97+
+ encode_str("")
98+
)
99+
# No EOF yet — connection stays open after the second UPDATE.
100+
101+
task = asyncio.create_task(sub._graph_connection(reader, writer))
102+
103+
# Wait until the subscriber has written two COMPLETE responses
104+
# (one per UPDATE), meaning both UPDATEs have been processed.
105+
for _ in range(200):
106+
if len(writer.buffer) >= 2:
107+
break
108+
await asyncio.sleep(0)
109+
else:
110+
pytest.fail("Subscriber did not process both UPDATEs")
111+
112+
# Connection is still open — this is mid-session state, not post-cleanup.
113+
# Before the fix, pub_a would still be in _channels here.
114+
assert pub_a not in sub._channels, (
115+
"Publisher should be removed from _channels "
116+
"when a graph UPDATE no longer includes it"
117+
)
118+
119+
reader.feed_eof()
120+
await task
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_recv_zero_copy_skips_stale_notification():
125+
"""recv_zero_copy must skip notifications from unregistered publishers.
126+
127+
Before PR #218 there was no guard; a stale notification would cause
128+
a KeyError on ``self._channels[pub_id]``.
129+
"""
130+
sub = Subscriber(
131+
id=uuid4(),
132+
topic="test/topic",
133+
graph_address=None,
134+
_guard=Subscriber._SENTINEL,
135+
)
136+
137+
stale_pub = uuid4()
138+
valid_pub = uuid4()
139+
140+
sub._channels[valid_pub] = DummyChannel()
141+
142+
# Stale notification first, then a valid one.
143+
await sub._incoming.put((stale_pub, 0))
144+
await sub._incoming.put((valid_pub, 1))
145+
146+
# Before the fix this would raise KeyError for stale_pub.
147+
async with sub.recv_zero_copy() as msg:
148+
assert msg == "msg-1"

0 commit comments

Comments
 (0)