55import logging
66import os
77import signal
8+ from dataclasses import dataclass
89from threading import BrokenBarrierError
910from multiprocessing import Event , Barrier
1011from multiprocessing .synchronize import Event as EventType
1617
1718from .collection import Collection , NetworkDefinition
1819from .component import Component
19- from .stream import Stream
20+ from .stream import Stream , InputRelay , OutputRelay
2021from .unit import Unit , PROCESS_ATTR
22+ from .relay import _CollectionRelayUnit , _RelaySettings
2123
2224from .graphserver import GraphService
2325from .graphcontext import GraphContext
3335logger = logging .getLogger ("ezmsg" )
3436
3537
38+ @dataclass
39+ class _RelayBinding :
40+ kind : str # "input" or "output"
41+ endpoint_topic : str
42+ relay_in_topic : str
43+ relay_out_topic : str
44+ endpoint : InputRelay | OutputRelay
45+ relay_unit : _CollectionRelayUnit
46+
47+
3648class ExecutionContext :
3749 _process_units : list [list [Unit ]]
3850 _processes : list [BackendProcess ] | None
@@ -95,22 +107,32 @@ def setup(
95107 start_participant : bool = False ,
96108 ) -> "ExecutionContext | None" :
97109 graph_connections : list [tuple [str , str ]] = []
110+ relay_bindings : dict [str , _RelayBinding ] = {}
98111
99112 for name , component in components .items ():
100113 component ._set_name (name )
101114 component ._set_location ([root_name ] if root_name is not None else [])
102115
116+ def normalize_topic (endpoint : Stream | str | enum .Enum , where : str ) -> str :
117+ if isinstance (endpoint , Stream ):
118+ return endpoint .address
119+ if isinstance (endpoint , enum .Enum ):
120+ return endpoint .name
121+ if isinstance (endpoint , str ):
122+ return endpoint
123+ raise TypeError (
124+ f"Invalid endpoint type in { where } : { type (endpoint )} . "
125+ "Expected Stream, str, or Enum."
126+ )
127+
103128 if connections is not None :
104129 for from_topic , to_topic in connections :
105- if isinstance (from_topic , Stream ):
106- from_topic = from_topic .address
107- if isinstance (to_topic , Stream ):
108- to_topic = to_topic .address
109- if isinstance (to_topic , enum .Enum ):
110- to_topic = to_topic .name
111- if isinstance (from_topic , enum .Enum ):
112- from_topic = from_topic .name
113- graph_connections .append ((from_topic , to_topic ))
130+ graph_connections .append (
131+ (
132+ normalize_topic (from_topic , "connections" ),
133+ normalize_topic (to_topic , "connections" ),
134+ )
135+ )
114136
115137 def crawl_components (
116138 component : Component , callback : Callable [[Component ], None ]
@@ -121,23 +143,115 @@ def crawl_components(
121143 search += list (comp .components .values ())
122144 callback (comp )
123145
146+ def input_relay_settings (relay : InputRelay ) -> _RelaySettings :
147+ return _RelaySettings (
148+ leaky = relay .leaky ,
149+ max_queue = relay .max_queue ,
150+ copy_on_forward = relay .copy_on_forward ,
151+ )
152+
153+ def output_relay_settings (relay : OutputRelay ) -> _RelaySettings :
154+ return _RelaySettings (
155+ host = relay .host ,
156+ port = relay .port ,
157+ num_buffers = relay .num_buffers ,
158+ buf_size = relay .buf_size ,
159+ force_tcp = relay .force_tcp ,
160+ copy_on_forward = relay .copy_on_forward ,
161+ )
162+
163+ def add_collection_relay_units (comp : Component ) -> None :
164+ if not isinstance (comp , Collection ):
165+ return
166+
167+ for endpoint_name , endpoint in comp .streams .items ():
168+ if isinstance (endpoint , InputRelay ):
169+ relay_name = f"__relay_in_{ endpoint_name } "
170+ if relay_name in comp .components :
171+ raise ValueError (
172+ f"{ comp .address } already defines component '{ relay_name } '."
173+ )
174+
175+ relay_unit = _CollectionRelayUnit (input_relay_settings (endpoint ))
176+ relay_unit ._set_name (relay_name )
177+ relay_unit ._set_location (comp .location + [comp .name ])
178+ comp .components [relay_name ] = relay_unit
179+ setattr (comp , relay_name , relay_unit )
180+
181+ relay_bindings [endpoint .address ] = _RelayBinding (
182+ kind = "input" ,
183+ endpoint_topic = endpoint .address ,
184+ relay_in_topic = relay_unit .INPUT .address ,
185+ relay_out_topic = relay_unit .OUTPUT .address ,
186+ endpoint = endpoint ,
187+ relay_unit = relay_unit ,
188+ )
189+
190+ elif isinstance (endpoint , OutputRelay ):
191+ relay_name = f"__relay_out_{ endpoint_name } "
192+ if relay_name in comp .components :
193+ raise ValueError (
194+ f"{ comp .address } already defines component '{ relay_name } '."
195+ )
196+
197+ relay_unit = _CollectionRelayUnit (output_relay_settings (endpoint ))
198+ relay_unit ._set_name (relay_name )
199+ relay_unit ._set_location (comp .location + [comp .name ])
200+ comp .components [relay_name ] = relay_unit
201+ setattr (comp , relay_name , relay_unit )
202+
203+ relay_bindings [endpoint .address ] = _RelayBinding (
204+ kind = "output" ,
205+ endpoint_topic = endpoint .address ,
206+ relay_in_topic = relay_unit .INPUT .address ,
207+ relay_out_topic = relay_unit .OUTPUT .address ,
208+ endpoint = endpoint ,
209+ relay_unit = relay_unit ,
210+ )
211+
212+ for component in components .values ():
213+ if isinstance (component , Collection ):
214+ crawl_components (component , add_collection_relay_units )
215+
124216 def gather_edges (comp : Component ):
125217 if isinstance (comp , Collection ):
126218 for from_stream , to_stream in comp .network ():
127- if isinstance (from_stream , Stream ):
128- from_stream = from_stream .address
129- if isinstance (to_stream , Stream ):
130- to_stream = to_stream .address
131- if isinstance (to_stream , enum .Enum ):
132- to_stream = to_stream .name
133- if isinstance (from_stream , enum .Enum ):
134- from_stream = from_stream .name
135- graph_connections .append ((from_stream , to_stream ))
219+ graph_connections .append (
220+ (
221+ normalize_topic (from_stream , f"{ comp .address } .network" ),
222+ normalize_topic (to_stream , f"{ comp .address } .network" ),
223+ )
224+ )
136225
137226 for component in components .values ():
138227 if isinstance (component , Collection ):
139228 crawl_components (component , gather_edges )
140229
230+ if relay_bindings :
231+ rewritten_connections : list [tuple [str , str ]] = []
232+ for from_topic , to_topic in graph_connections :
233+ to_binding = relay_bindings .get (to_topic , None )
234+ if to_binding is not None and to_binding .kind == "output" :
235+ to_topic = to_binding .relay_in_topic
236+
237+ from_binding = relay_bindings .get (from_topic , None )
238+ if from_binding is not None and from_binding .kind == "input" :
239+ from_topic = from_binding .relay_out_topic
240+
241+ rewritten_connections .append ((from_topic , to_topic ))
242+
243+ for binding in relay_bindings .values ():
244+ if binding .kind == "input" :
245+ rewritten_connections .append (
246+ (binding .endpoint_topic , binding .relay_in_topic )
247+ )
248+ else :
249+ rewritten_connections .append (
250+ (binding .relay_out_topic , binding .endpoint_topic )
251+ )
252+
253+ graph_connections = rewritten_connections
254+
141255 processes = collect_processes (components .values (), process_components )
142256
143257 for component in components .values ():
@@ -149,6 +263,14 @@ def configure_collections(comp: Component):
149263
150264 crawl_components (component , configure_collections )
151265
266+ for binding in relay_bindings .values ():
267+ if isinstance (binding .endpoint , InputRelay ):
268+ binding .relay_unit .apply_settings (input_relay_settings (binding .endpoint ))
269+ elif isinstance (binding .endpoint , OutputRelay ):
270+ binding .relay_unit .apply_settings (
271+ output_relay_settings (binding .endpoint )
272+ )
273+
152274 if force_single_process :
153275 processes = [[u for pu in processes for u in pu ]]
154276
0 commit comments