Skip to content

Commit b283f79

Browse files
Merge pull request #228 from ezmsg-org/fix/collection-topics-relays
Fix: Input/OutputTopic/Relay for Collections
2 parents 198de90 + 70c0101 commit b283f79

9 files changed

Lines changed: 484 additions & 30 deletions

File tree

examples/ezmsg_configs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ async def listen(self, msg: int) -> None:
8282

8383

8484
class PassthroughCollection(ez.Collection):
85-
INPUT = ez.InputStream(int)
86-
OUTPUT = ez.OutputStream(int)
85+
INPUT = ez.InputTopic(int)
86+
OUTPUT = ez.OutputTopic(int)
8787

8888
def network(self) -> ez.NetworkDefinition:
8989
return ((self.INPUT, self.OUTPUT),)
@@ -136,7 +136,7 @@ def configure(self) -> None:
136136

137137

138138
class PubNoSubCollection(ez.Collection):
139-
OUTPUT = ez.OutputStream(int)
139+
OUTPUT = ez.OutputTopic(int)
140140
GENERATE = Generator()
141141
LOG = DebugLog()
142142

@@ -148,7 +148,7 @@ def network(self) -> ez.NetworkDefinition:
148148

149149

150150
class SubNoPubCollection(ez.Collection):
151-
INPUT = ez.InputStream(int)
151+
INPUT = ez.InputTopic(int)
152152
LISTEN = Listener()
153153

154154
def network(self) -> ez.NetworkDefinition:
@@ -175,7 +175,7 @@ class PubNoSubPassthroughCollection(ez.Collection):
175175
COLLECTION = PubNoSubCollection()
176176
PASSTHROUGH = PassthroughCollection()
177177

178-
OUTPUT = ez.OutputStream(int)
178+
OUTPUT = ez.OutputTopic(int)
179179

180180
def network(self) -> ez.NetworkDefinition:
181181
return (
@@ -188,7 +188,7 @@ class SubNoPubPassthroughCollection(ez.Collection):
188188
COLLECTION = SubNoPubCollection()
189189
PASSTHROUGH = PassthroughCollection()
190190

191-
INPUT = ez.InputStream(int)
191+
INPUT = ez.InputTopic(int)
192192

193193
def network(self) -> ez.NetworkDefinition:
194194
return (

examples/ezmsg_toy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ class ModifierCollection(ez.Collection):
123123
"""This collection will subscribe to messages
124124
and append the most recent LFO output"""
125125

126-
INPUT = ez.InputStream(str)
127-
OUTPUT = ez.OutputStream(str)
126+
INPUT = ez.InputTopic(str)
127+
OUTPUT = ez.OutputTopic(str)
128128

129129
SIN = LFO()
130130
# SIN2 = LFO()

src/ezmsg/core/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
"Settings",
1515
"Collection",
1616
"NetworkDefinition",
17+
"Topic",
18+
"InputTopic",
19+
"OutputTopic",
20+
"InputRelay",
21+
"OutputRelay",
1722
"InputStream",
1823
"OutputStream",
1924
"Unit",
@@ -44,7 +49,15 @@
4449
from .settings import Settings
4550
from .collection import Collection, NetworkDefinition
4651
from .unit import Unit, task, publisher, subscriber, main, timeit, process, thread
47-
from .stream import InputStream, OutputStream
52+
from .stream import (
53+
Topic,
54+
InputTopic,
55+
OutputTopic,
56+
InputRelay,
57+
OutputRelay,
58+
InputStream,
59+
OutputStream,
60+
)
4861
from .backend import run, GraphRunner, GraphRunnerStartError
4962
from .backendprocess import Complete, NormalTermination
5063
from .graphserver import GraphServer

src/ezmsg/core/backend.py

Lines changed: 141 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import os
77
import signal
8+
from dataclasses import dataclass
89
from threading import BrokenBarrierError
910
from multiprocessing import Event, Barrier
1011
from multiprocessing.synchronize import Event as EventType
@@ -16,8 +17,9 @@
1617

1718
from .collection import Collection, NetworkDefinition
1819
from .component import Component
19-
from .stream import Stream
20+
from .stream import Stream, InputRelay, OutputRelay
2021
from .unit import Unit, PROCESS_ATTR
22+
from .relay import _CollectionRelayUnit, _RelaySettings
2123

2224
from .graphserver import GraphService
2325
from .graphcontext import GraphContext
@@ -33,6 +35,16 @@
3335
logger = logging.getLogger("ezmsg")
3436

3537

38+
@dataclass
39+
class _RelayBinding:
40+
kind: str # "input" or "output"
41+
endpoint_topic: str
42+
relay_in_topic: str
43+
relay_out_topic: str
44+
endpoint: InputRelay | OutputRelay
45+
relay_unit: _CollectionRelayUnit
46+
47+
3648
class ExecutionContext:
3749
_process_units: list[list[Unit]]
3850
_processes: list[BackendProcess] | None
@@ -95,22 +107,32 @@ def setup(
95107
start_participant: bool = False,
96108
) -> "ExecutionContext | None":
97109
graph_connections: list[tuple[str, str]] = []
110+
relay_bindings: dict[str, _RelayBinding] = {}
98111

99112
for name, component in components.items():
100113
component._set_name(name)
101114
component._set_location([root_name] if root_name is not None else [])
102115

116+
def normalize_topic(endpoint: Stream | str | enum.Enum, where: str) -> str:
117+
if isinstance(endpoint, Stream):
118+
return endpoint.address
119+
if isinstance(endpoint, enum.Enum):
120+
return endpoint.name
121+
if isinstance(endpoint, str):
122+
return endpoint
123+
raise TypeError(
124+
f"Invalid endpoint type in {where}: {type(endpoint)}. "
125+
"Expected Stream, str, or Enum."
126+
)
127+
103128
if connections is not None:
104129
for from_topic, to_topic in connections:
105-
if isinstance(from_topic, Stream):
106-
from_topic = from_topic.address
107-
if isinstance(to_topic, Stream):
108-
to_topic = to_topic.address
109-
if isinstance(to_topic, enum.Enum):
110-
to_topic = to_topic.name
111-
if isinstance(from_topic, enum.Enum):
112-
from_topic = from_topic.name
113-
graph_connections.append((from_topic, to_topic))
130+
graph_connections.append(
131+
(
132+
normalize_topic(from_topic, "connections"),
133+
normalize_topic(to_topic, "connections"),
134+
)
135+
)
114136

115137
def crawl_components(
116138
component: Component, callback: Callable[[Component], None]
@@ -121,23 +143,115 @@ def crawl_components(
121143
search += list(comp.components.values())
122144
callback(comp)
123145

146+
def input_relay_settings(relay: InputRelay) -> _RelaySettings:
147+
return _RelaySettings(
148+
leaky=relay.leaky,
149+
max_queue=relay.max_queue,
150+
copy_on_forward=relay.copy_on_forward,
151+
)
152+
153+
def output_relay_settings(relay: OutputRelay) -> _RelaySettings:
154+
return _RelaySettings(
155+
host=relay.host,
156+
port=relay.port,
157+
num_buffers=relay.num_buffers,
158+
buf_size=relay.buf_size,
159+
force_tcp=relay.force_tcp,
160+
copy_on_forward=relay.copy_on_forward,
161+
)
162+
163+
def add_collection_relay_units(comp: Component) -> None:
164+
if not isinstance(comp, Collection):
165+
return
166+
167+
for endpoint_name, endpoint in comp.streams.items():
168+
if isinstance(endpoint, InputRelay):
169+
relay_name = f"__relay_in_{endpoint_name}"
170+
if relay_name in comp.components:
171+
raise ValueError(
172+
f"{comp.address} already defines component '{relay_name}'."
173+
)
174+
175+
relay_unit = _CollectionRelayUnit(input_relay_settings(endpoint))
176+
relay_unit._set_name(relay_name)
177+
relay_unit._set_location(comp.location + [comp.name])
178+
comp.components[relay_name] = relay_unit
179+
setattr(comp, relay_name, relay_unit)
180+
181+
relay_bindings[endpoint.address] = _RelayBinding(
182+
kind="input",
183+
endpoint_topic=endpoint.address,
184+
relay_in_topic=relay_unit.INPUT.address,
185+
relay_out_topic=relay_unit.OUTPUT.address,
186+
endpoint=endpoint,
187+
relay_unit=relay_unit,
188+
)
189+
190+
elif isinstance(endpoint, OutputRelay):
191+
relay_name = f"__relay_out_{endpoint_name}"
192+
if relay_name in comp.components:
193+
raise ValueError(
194+
f"{comp.address} already defines component '{relay_name}'."
195+
)
196+
197+
relay_unit = _CollectionRelayUnit(output_relay_settings(endpoint))
198+
relay_unit._set_name(relay_name)
199+
relay_unit._set_location(comp.location + [comp.name])
200+
comp.components[relay_name] = relay_unit
201+
setattr(comp, relay_name, relay_unit)
202+
203+
relay_bindings[endpoint.address] = _RelayBinding(
204+
kind="output",
205+
endpoint_topic=endpoint.address,
206+
relay_in_topic=relay_unit.INPUT.address,
207+
relay_out_topic=relay_unit.OUTPUT.address,
208+
endpoint=endpoint,
209+
relay_unit=relay_unit,
210+
)
211+
212+
for component in components.values():
213+
if isinstance(component, Collection):
214+
crawl_components(component, add_collection_relay_units)
215+
124216
def gather_edges(comp: Component):
125217
if isinstance(comp, Collection):
126218
for from_stream, to_stream in comp.network():
127-
if isinstance(from_stream, Stream):
128-
from_stream = from_stream.address
129-
if isinstance(to_stream, Stream):
130-
to_stream = to_stream.address
131-
if isinstance(to_stream, enum.Enum):
132-
to_stream = to_stream.name
133-
if isinstance(from_stream, enum.Enum):
134-
from_stream = from_stream.name
135-
graph_connections.append((from_stream, to_stream))
219+
graph_connections.append(
220+
(
221+
normalize_topic(from_stream, f"{comp.address}.network"),
222+
normalize_topic(to_stream, f"{comp.address}.network"),
223+
)
224+
)
136225

137226
for component in components.values():
138227
if isinstance(component, Collection):
139228
crawl_components(component, gather_edges)
140229

230+
if relay_bindings:
231+
rewritten_connections: list[tuple[str, str]] = []
232+
for from_topic, to_topic in graph_connections:
233+
to_binding = relay_bindings.get(to_topic, None)
234+
if to_binding is not None and to_binding.kind == "output":
235+
to_topic = to_binding.relay_in_topic
236+
237+
from_binding = relay_bindings.get(from_topic, None)
238+
if from_binding is not None and from_binding.kind == "input":
239+
from_topic = from_binding.relay_out_topic
240+
241+
rewritten_connections.append((from_topic, to_topic))
242+
243+
for binding in relay_bindings.values():
244+
if binding.kind == "input":
245+
rewritten_connections.append(
246+
(binding.endpoint_topic, binding.relay_in_topic)
247+
)
248+
else:
249+
rewritten_connections.append(
250+
(binding.relay_out_topic, binding.endpoint_topic)
251+
)
252+
253+
graph_connections = rewritten_connections
254+
141255
processes = collect_processes(components.values(), process_components)
142256

143257
for component in components.values():
@@ -149,6 +263,14 @@ def configure_collections(comp: Component):
149263

150264
crawl_components(component, configure_collections)
151265

266+
for binding in relay_bindings.values():
267+
if isinstance(binding.endpoint, InputRelay):
268+
binding.relay_unit.apply_settings(input_relay_settings(binding.endpoint))
269+
elif isinstance(binding.endpoint, OutputRelay):
270+
binding.relay_unit.apply_settings(
271+
output_relay_settings(binding.endpoint)
272+
)
273+
152274
if force_single_process:
153275
processes = [[u for pu in processes for u in pu]]
154276

src/ezmsg/core/collection.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from collections.abc import Collection as AbstractCollection
33
import typing
44
from copy import deepcopy
5+
import warnings
56

6-
from .stream import Stream
7+
from .stream import Stream, InputStream, OutputStream
78
from .component import ComponentMeta, Component
89
from .settings import Settings
910

@@ -34,6 +35,16 @@ def __init__(
3435
if isinstance(field_value, Component):
3536
field_value._set_name(field_name)
3637
cls.__components__[field_name] = field_value
38+
elif isinstance(field_value, (InputStream, OutputStream)):
39+
warnings.warn(
40+
f"{name}.{field_name} uses {type(field_value).__name__} as a "
41+
"Collection boundary endpoint. This behavior is deprecated and "
42+
"will change in a future release. Use InputTopic / OutputTopic "
43+
"for zero-overhead topic shortcuts, or InputRelay / OutputRelay "
44+
"for explicit boundary republishers.",
45+
FutureWarning,
46+
stacklevel=2,
47+
)
3748

3849

3950
class Collection(Component, metaclass=CollectionMeta):

src/ezmsg/core/relay.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from copy import deepcopy
2+
from collections.abc import AsyncGenerator
3+
from typing import Any
4+
5+
from .settings import Settings
6+
from .netprotocol import DEFAULT_SHM_SIZE
7+
from .stream import InputStream, OutputStream
8+
from .unit import Unit, publisher, subscriber
9+
10+
11+
class _RelaySettings(Settings):
12+
leaky: bool = False
13+
max_queue: int | None = None
14+
host: str | None = None
15+
port: int | None = None
16+
num_buffers: int = 32
17+
buf_size: int = DEFAULT_SHM_SIZE
18+
force_tcp: bool = False
19+
copy_on_forward: bool = True
20+
21+
22+
class _CollectionRelayUnit(Unit):
23+
SETTINGS = _RelaySettings
24+
25+
INPUT = InputStream(Any)
26+
OUTPUT = OutputStream(Any)
27+
28+
async def initialize(self) -> None:
29+
self.INPUT.leaky = self.SETTINGS.leaky
30+
self.INPUT.max_queue = self.SETTINGS.max_queue
31+
self.OUTPUT.host = self.SETTINGS.host
32+
self.OUTPUT.port = self.SETTINGS.port
33+
self.OUTPUT.num_buffers = self.SETTINGS.num_buffers
34+
self.OUTPUT.buf_size = self.SETTINGS.buf_size
35+
self.OUTPUT.force_tcp = self.SETTINGS.force_tcp
36+
37+
@subscriber(INPUT)
38+
@publisher(OUTPUT)
39+
async def relay(self, msg: Any) -> AsyncGenerator:
40+
if self.SETTINGS.copy_on_forward:
41+
msg = deepcopy(msg)
42+
yield self.OUTPUT, msg

0 commit comments

Comments
 (0)