Skip to content

Commit 7367d37

Browse files
authored
Deemphasize ezmsg.util.generator (#210)
* Convert modify_axis to a factory function on a ModifyAxisTransformer class. * Deemphasize ezmsg.util.generator; not deprecated but use docstrings to discourage generator usage.
1 parent 819770e commit 7367d37

5 files changed

Lines changed: 146 additions & 93 deletions

File tree

examples/ezmsg_generator.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
"""
2+
.. note::
3+
This example demonstrates legacy generator-based patterns from
4+
``ezmsg.util.generator``. New code should use ``ezmsg.baseproc``
5+
(BaseTransformer, BaseProducer, etc.) instead.
6+
27
This ezmsg example showcases a design pattern where core computational
38
logic is encapsulated within a Python generator. This approach is
49
beneficial for several reasons:
@@ -22,14 +27,61 @@
2227
"""
2328

2429
import asyncio
30+
import traceback
2531
import ezmsg.core as ez
2632
import numpy as np
2733
from typing import Any
28-
from collections.abc import Generator
34+
from collections.abc import AsyncGenerator, Callable, Generator
35+
from functools import wraps, reduce
2936
from ezmsg.util.messages.axisarray import AxisArray, replace
3037
from ezmsg.util.debuglog import DebugLog
3138
from ezmsg.util.gen_to_unit import gen_to_unit
32-
from ezmsg.util.generator import consumer, compose, Gen
39+
40+
41+
# --- Inlined helpers (previously from ezmsg.util.generator, now deprecated) ---
42+
43+
def consumer(func):
44+
"""Prime a generator by advancing it to the first yield statement."""
45+
@wraps(func)
46+
def wrapper(*args, **kwargs):
47+
gen = func(*args, **kwargs)
48+
next(gen)
49+
return gen
50+
return wrapper
51+
52+
53+
def compose(*funcs):
54+
"""Compose a chain of generator functions into a single callable."""
55+
return lambda x: reduce(lambda f, g: g.send(f), list(funcs), x)
56+
57+
58+
class GenState(ez.State):
59+
gen: Generator[Any, Any, None]
60+
61+
62+
class Gen(ez.Unit):
63+
STATE = GenState
64+
65+
INPUT = ez.InputStream(Any)
66+
OUTPUT = ez.OutputStream(Any)
67+
68+
async def initialize(self) -> None:
69+
self.construct_generator()
70+
71+
def construct_generator(self):
72+
raise NotImplementedError
73+
74+
@ez.subscriber(INPUT)
75+
@ez.publisher(OUTPUT)
76+
async def on_message(self, message: Any) -> AsyncGenerator:
77+
try:
78+
ret = self.STATE.gen.send(message)
79+
if ret is not None:
80+
yield self.OUTPUT, ret
81+
except (StopIteration, GeneratorExit):
82+
ez.logger.debug(f"Generator closed in {self.address}")
83+
except Exception:
84+
ez.logger.info(traceback.format_exc())
3385

3486

3587
@consumer

src/ezmsg/util/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
Key modules:
88
- :mod:`ezmsg.util.debuglog`: Debug logging utilities
9-
- :mod:`ezmsg.util.generator`: Generator-based message processing decorators
9+
- :mod:`ezmsg.util.generator`: Generator-based message processing decorators (for newer code, use ``ezmsg.baseproc`` instead)
1010
- :mod:`ezmsg.util.messagecodec`: JSON encoding/decoding for message logging
1111
- :mod:`ezmsg.util.messagegate`: Message flow control and gating
1212
- :mod:`ezmsg.util.messagelogger`: File-based message logging

src/ezmsg/util/generator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
"""
2+
Generator-based message processing decorators (legacy).
3+
4+
.. note::
5+
For new code, prefer ``ezmsg.baseproc`` (BaseTransformer, BaseProducer, etc.)
6+
over the generator pattern in this module. The baseproc classes provide a cleaner
7+
separation of state management, hashing, and processing, and are the recommended
8+
approach for all new ezmsg units.
9+
10+
This module is maintained for backward compatibility and simple stateless use cases.
11+
"""
12+
113
import ezmsg.core as ez
214
import traceback
315
from collections.abc import AsyncGenerator, Callable, Generator

src/ezmsg/util/messages/modify.py

Lines changed: 76 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,136 +1,122 @@
1-
from collections.abc import AsyncGenerator, Generator
2-
import traceback
1+
import warnings
2+
from collections.abc import AsyncGenerator
33

44
import numpy as np
55

66
import ezmsg.core as ez
77
from .axisarray import AxisArray, replace
8-
from ..generator import consumer, GenState
98

109

11-
@consumer
12-
def modify_axis(
13-
name_map: dict[str, str | None] | None = None,
14-
) -> Generator[AxisArray, AxisArray, None]:
10+
class ModifyAxisSettings(ez.Settings):
1511
"""
16-
Modify an AxisArray's axes and dims according to a name_map.
12+
Settings for ModifyAxisTransformer and ModifyAxis unit.
1713
1814
:param name_map: A dictionary where the keys are the names of the old dims and the values are the new names.
1915
Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised.
2016
:type name_map: dict[str, str | None] | None
21-
:return: A primed generator object ready to yield an AxisArray with modified axes for each .send(axis_array).
22-
:rtype: collections.abc.Generator[AxisArray, AxisArray, None]
2317
"""
24-
# State variables
25-
axis_arr_in = AxisArray(np.array([]), dims=[""])
26-
axis_arr_out = AxisArray(np.array([]), dims=[""])
27-
28-
while True:
29-
axis_arr_in = yield axis_arr_out
30-
31-
if name_map is not None:
32-
new_dims = [name_map.get(old_k, old_k) for old_k in axis_arr_in.dims]
33-
new_axes = {
34-
name_map.get(old_k, old_k): v for old_k, v in axis_arr_in.axes.items()
35-
}
36-
drop_ax_ix = [
37-
ix
38-
for ix, old_dim in enumerate(axis_arr_in.dims)
39-
if new_dims[ix] is None
40-
]
41-
if len(drop_ax_ix) > 0:
42-
# Rewrite new_dims and new_axes without the dropped axes
43-
new_dims = [_ for _ in new_dims if _ is not None]
44-
new_axes.pop(None, None)
45-
# Reshape data
46-
# np.squeeze will raise ValueError if not len==1
47-
axis_arr_out = replace(
48-
axis_arr_in,
49-
data=np.squeeze(axis_arr_in.data, axis=tuple(drop_ax_ix)),
50-
dims=new_dims,
51-
axes=new_axes,
52-
)
53-
else:
54-
axis_arr_out = replace(axis_arr_in, dims=new_dims, axes=new_axes)
55-
else:
56-
axis_arr_out = axis_arr_in
5718

19+
name_map: dict[str, str | None] | None = None
5820

59-
class ModifyAxisSettings(ez.Settings):
21+
22+
class ModifyAxisTransformer:
23+
"""
24+
Modify an AxisArray's axes and dims according to a name_map.
25+
26+
This is a stateless transformer that renames dimensions and axes, with support
27+
for dropping length-1 dimensions.
28+
29+
:param settings: Settings for the transformer.
30+
:type settings: ModifyAxisSettings
31+
"""
32+
33+
def __init__(self, settings: ModifyAxisSettings = ModifyAxisSettings()):
34+
self.settings = settings
35+
36+
def __call__(self, message: AxisArray) -> AxisArray:
37+
name_map = self.settings.name_map
38+
if name_map is None:
39+
return message
40+
41+
new_dims = [name_map.get(old_k, old_k) for old_k in message.dims]
42+
new_axes = {
43+
name_map.get(old_k, old_k): v for old_k, v in message.axes.items()
44+
}
45+
drop_ax_ix = [
46+
ix
47+
for ix, old_dim in enumerate(message.dims)
48+
if new_dims[ix] is None
49+
]
50+
if len(drop_ax_ix) > 0:
51+
new_dims = [d for d in new_dims if d is not None]
52+
new_axes.pop(None, None)
53+
return replace(
54+
message,
55+
data=np.squeeze(message.data, axis=tuple(drop_ax_ix)),
56+
dims=new_dims,
57+
axes=new_axes,
58+
)
59+
return replace(message, dims=new_dims, axes=new_axes)
60+
61+
_send_warned: bool = False
62+
63+
def send(self, message: AxisArray) -> AxisArray:
64+
"""Alias for __call__. Deprecated — use ``__call__`` directly instead."""
65+
if not ModifyAxisTransformer._send_warned:
66+
ModifyAxisTransformer._send_warned = True
67+
warnings.warn(
68+
"ModifyAxisTransformer.send() is deprecated. Use __call__() instead.",
69+
DeprecationWarning,
70+
stacklevel=2,
71+
)
72+
return self(message)
73+
74+
75+
def modify_axis(
76+
name_map: dict[str, str | None] | None = None,
77+
) -> ModifyAxisTransformer:
6078
"""
61-
Settings for ModifyAxis unit.
79+
Create a :class:`ModifyAxisTransformer` instance.
6280
63-
Configuration for modifying axis names and dimensions of AxisArray messages.
81+
This function exists for backward compatibility with code that previously used
82+
the generator-based ``modify_axis``. The returned object supports ``.send(msg)``.
6483
6584
:param name_map: A dictionary where the keys are the names of the old dims and the values are the new names.
6685
Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised.
6786
:type name_map: dict[str, str | None] | None
87+
:return: A ModifyAxisTransformer instance.
88+
:rtype: ModifyAxisTransformer
6889
"""
69-
70-
name_map: dict[str, str | None] | None = None
90+
return ModifyAxisTransformer(ModifyAxisSettings(name_map=name_map))
7191

7292

7393
class ModifyAxis(ez.Unit):
7494
"""
7595
Unit for modifying axis names and dimensions of AxisArray messages.
7696
7797
Renames dimensions and axes according to a name mapping, with support for
78-
dropping dimensions. Uses zero-copy operations for efficient processing.
98+
dropping dimensions.
7999
"""
80100

81-
STATE = GenState
82101
SETTINGS = ModifyAxisSettings
83102

84103
INPUT_SIGNAL = ez.InputStream(AxisArray)
85104
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
86105
INPUT_SETTINGS = ez.InputStream(ModifyAxisSettings)
87106

88-
async def initialize(self) -> None:
89-
"""
90-
Initialize the ModifyAxis unit.
91-
92-
Sets up the generator for axis modification operations.
93-
"""
94-
self.construct_generator()
107+
_transformer: ModifyAxisTransformer
95108

96-
def construct_generator(self):
97-
"""
98-
Construct the axis-modifying generator with current settings.
99-
100-
Creates a new modify_axis generator instance using the unit's name mapping.
101-
"""
102-
self.STATE.gen = modify_axis(name_map=self.SETTINGS.name_map)
109+
async def initialize(self) -> None:
110+
self._transformer = ModifyAxisTransformer(self.SETTINGS)
103111

104112
@ez.subscriber(INPUT_SETTINGS)
105113
async def on_settings(self, msg: ez.Settings) -> None:
106-
"""
107-
Handle incoming settings updates.
108-
109-
:param msg: New settings to apply.
110-
:type msg: ez.Settings
111-
"""
112114
self.apply_settings(msg)
113-
self.construct_generator()
115+
self._transformer = ModifyAxisTransformer(self.SETTINGS)
114116

115117
@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
116118
@ez.publisher(OUTPUT_SIGNAL)
117119
async def on_message(self, message: AxisArray) -> AsyncGenerator:
118-
"""
119-
Process incoming AxisArray messages and modify their axes.
120-
121-
Uses zero-copy operations to efficiently modify axis names and dimensions
122-
while preserving data integrity.
123-
124-
:param message: Input AxisArray to modify.
125-
:type message: AxisArray
126-
:return: Async generator yielding AxisArray with modified axes.
127-
:rtype: collections.abc.AsyncGenerator
128-
"""
129-
try:
130-
ret = self.STATE.gen.send(message)
131-
if ret is not None:
132-
yield self.OUTPUT_SIGNAL, ret
133-
except (StopIteration, GeneratorExit):
134-
ez.logger.debug(f"ModifyAxis closed in {self.address}")
135-
except Exception:
136-
ez.logger.info(traceback.format_exc())
120+
ret = self._transformer(message)
121+
if ret is not None:
122+
yield self.OUTPUT_SIGNAL, ret

tests/test_generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# NOTE: This file tests the legacy ezmsg.util.generator module.
2+
# While this pattern is typically not recommended, it has some narrow use cases.
3+
# The @consumer usages here are the subject of the tests and should remain.
14
from collections.abc import AsyncGenerator, Generator
25
import copy
36
import json

0 commit comments

Comments
 (0)