|
1 | | -from collections.abc import AsyncGenerator, Generator |
2 | | -import traceback |
| 1 | +import warnings |
| 2 | +from collections.abc import AsyncGenerator |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 |
|
6 | 6 | import ezmsg.core as ez |
7 | 7 | from .axisarray import AxisArray, replace |
8 | | -from ..generator import consumer, GenState |
9 | 8 |
|
10 | 9 |
|
11 | | -@consumer |
12 | | -def modify_axis( |
13 | | - name_map: dict[str, str | None] | None = None, |
14 | | -) -> Generator[AxisArray, AxisArray, None]: |
| 10 | +class ModifyAxisSettings(ez.Settings): |
15 | 11 | """ |
16 | | - Modify an AxisArray's axes and dims according to a name_map. |
| 12 | + Settings for ModifyAxisTransformer and ModifyAxis unit. |
17 | 13 |
|
18 | 14 | :param name_map: A dictionary where the keys are the names of the old dims and the values are the new names. |
19 | 15 | Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised. |
20 | 16 | :type name_map: dict[str, str | None] | None |
21 | | - :return: A primed generator object ready to yield an AxisArray with modified axes for each .send(axis_array). |
22 | | - :rtype: collections.abc.Generator[AxisArray, AxisArray, None] |
23 | 17 | """ |
24 | | - # State variables |
25 | | - axis_arr_in = AxisArray(np.array([]), dims=[""]) |
26 | | - axis_arr_out = AxisArray(np.array([]), dims=[""]) |
27 | | - |
28 | | - while True: |
29 | | - axis_arr_in = yield axis_arr_out |
30 | | - |
31 | | - if name_map is not None: |
32 | | - new_dims = [name_map.get(old_k, old_k) for old_k in axis_arr_in.dims] |
33 | | - new_axes = { |
34 | | - name_map.get(old_k, old_k): v for old_k, v in axis_arr_in.axes.items() |
35 | | - } |
36 | | - drop_ax_ix = [ |
37 | | - ix |
38 | | - for ix, old_dim in enumerate(axis_arr_in.dims) |
39 | | - if new_dims[ix] is None |
40 | | - ] |
41 | | - if len(drop_ax_ix) > 0: |
42 | | - # Rewrite new_dims and new_axes without the dropped axes |
43 | | - new_dims = [_ for _ in new_dims if _ is not None] |
44 | | - new_axes.pop(None, None) |
45 | | - # Reshape data |
46 | | - # np.squeeze will raise ValueError if not len==1 |
47 | | - axis_arr_out = replace( |
48 | | - axis_arr_in, |
49 | | - data=np.squeeze(axis_arr_in.data, axis=tuple(drop_ax_ix)), |
50 | | - dims=new_dims, |
51 | | - axes=new_axes, |
52 | | - ) |
53 | | - else: |
54 | | - axis_arr_out = replace(axis_arr_in, dims=new_dims, axes=new_axes) |
55 | | - else: |
56 | | - axis_arr_out = axis_arr_in |
57 | 18 |
|
| 19 | + name_map: dict[str, str | None] | None = None |
58 | 20 |
|
59 | | -class ModifyAxisSettings(ez.Settings): |
| 21 | + |
| 22 | +class ModifyAxisTransformer: |
| 23 | + """ |
| 24 | + Modify an AxisArray's axes and dims according to a name_map. |
| 25 | +
|
| 26 | + This is a stateless transformer that renames dimensions and axes, with support |
| 27 | + for dropping length-1 dimensions. |
| 28 | +
|
| 29 | + :param settings: Settings for the transformer. |
| 30 | + :type settings: ModifyAxisSettings |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self, settings: ModifyAxisSettings = ModifyAxisSettings()): |
| 34 | + self.settings = settings |
| 35 | + |
| 36 | + def __call__(self, message: AxisArray) -> AxisArray: |
| 37 | + name_map = self.settings.name_map |
| 38 | + if name_map is None: |
| 39 | + return message |
| 40 | + |
| 41 | + new_dims = [name_map.get(old_k, old_k) for old_k in message.dims] |
| 42 | + new_axes = { |
| 43 | + name_map.get(old_k, old_k): v for old_k, v in message.axes.items() |
| 44 | + } |
| 45 | + drop_ax_ix = [ |
| 46 | + ix |
| 47 | + for ix, old_dim in enumerate(message.dims) |
| 48 | + if new_dims[ix] is None |
| 49 | + ] |
| 50 | + if len(drop_ax_ix) > 0: |
| 51 | + new_dims = [d for d in new_dims if d is not None] |
| 52 | + new_axes.pop(None, None) |
| 53 | + return replace( |
| 54 | + message, |
| 55 | + data=np.squeeze(message.data, axis=tuple(drop_ax_ix)), |
| 56 | + dims=new_dims, |
| 57 | + axes=new_axes, |
| 58 | + ) |
| 59 | + return replace(message, dims=new_dims, axes=new_axes) |
| 60 | + |
| 61 | + _send_warned: bool = False |
| 62 | + |
| 63 | + def send(self, message: AxisArray) -> AxisArray: |
| 64 | + """Alias for __call__. Deprecated — use ``__call__`` directly instead.""" |
| 65 | + if not ModifyAxisTransformer._send_warned: |
| 66 | + ModifyAxisTransformer._send_warned = True |
| 67 | + warnings.warn( |
| 68 | + "ModifyAxisTransformer.send() is deprecated. Use __call__() instead.", |
| 69 | + DeprecationWarning, |
| 70 | + stacklevel=2, |
| 71 | + ) |
| 72 | + return self(message) |
| 73 | + |
| 74 | + |
| 75 | +def modify_axis( |
| 76 | + name_map: dict[str, str | None] | None = None, |
| 77 | +) -> ModifyAxisTransformer: |
60 | 78 | """ |
61 | | - Settings for ModifyAxis unit. |
| 79 | + Create a :class:`ModifyAxisTransformer` instance. |
62 | 80 |
|
63 | | - Configuration for modifying axis names and dimensions of AxisArray messages. |
| 81 | + This function exists for backward compatibility with code that previously used |
| 82 | + the generator-based ``modify_axis``. The returned object supports ``.send(msg)``. |
64 | 83 |
|
65 | 84 | :param name_map: A dictionary where the keys are the names of the old dims and the values are the new names. |
66 | 85 | Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised. |
67 | 86 | :type name_map: dict[str, str | None] | None |
| 87 | + :return: A ModifyAxisTransformer instance. |
| 88 | + :rtype: ModifyAxisTransformer |
68 | 89 | """ |
69 | | - |
70 | | - name_map: dict[str, str | None] | None = None |
| 90 | + return ModifyAxisTransformer(ModifyAxisSettings(name_map=name_map)) |
71 | 91 |
|
72 | 92 |
|
73 | 93 | class ModifyAxis(ez.Unit): |
74 | 94 | """ |
75 | 95 | Unit for modifying axis names and dimensions of AxisArray messages. |
76 | 96 |
|
77 | 97 | Renames dimensions and axes according to a name mapping, with support for |
78 | | - dropping dimensions. Uses zero-copy operations for efficient processing. |
| 98 | + dropping dimensions. |
79 | 99 | """ |
80 | 100 |
|
81 | | - STATE = GenState |
82 | 101 | SETTINGS = ModifyAxisSettings |
83 | 102 |
|
84 | 103 | INPUT_SIGNAL = ez.InputStream(AxisArray) |
85 | 104 | OUTPUT_SIGNAL = ez.OutputStream(AxisArray) |
86 | 105 | INPUT_SETTINGS = ez.InputStream(ModifyAxisSettings) |
87 | 106 |
|
88 | | - async def initialize(self) -> None: |
89 | | - """ |
90 | | - Initialize the ModifyAxis unit. |
91 | | -
|
92 | | - Sets up the generator for axis modification operations. |
93 | | - """ |
94 | | - self.construct_generator() |
| 107 | + _transformer: ModifyAxisTransformer |
95 | 108 |
|
96 | | - def construct_generator(self): |
97 | | - """ |
98 | | - Construct the axis-modifying generator with current settings. |
99 | | -
|
100 | | - Creates a new modify_axis generator instance using the unit's name mapping. |
101 | | - """ |
102 | | - self.STATE.gen = modify_axis(name_map=self.SETTINGS.name_map) |
| 109 | + async def initialize(self) -> None: |
| 110 | + self._transformer = ModifyAxisTransformer(self.SETTINGS) |
103 | 111 |
|
104 | 112 | @ez.subscriber(INPUT_SETTINGS) |
105 | 113 | async def on_settings(self, msg: ez.Settings) -> None: |
106 | | - """ |
107 | | - Handle incoming settings updates. |
108 | | -
|
109 | | - :param msg: New settings to apply. |
110 | | - :type msg: ez.Settings |
111 | | - """ |
112 | 114 | self.apply_settings(msg) |
113 | | - self.construct_generator() |
| 115 | + self._transformer = ModifyAxisTransformer(self.SETTINGS) |
114 | 116 |
|
115 | 117 | @ez.subscriber(INPUT_SIGNAL, zero_copy=True) |
116 | 118 | @ez.publisher(OUTPUT_SIGNAL) |
117 | 119 | async def on_message(self, message: AxisArray) -> AsyncGenerator: |
118 | | - """ |
119 | | - Process incoming AxisArray messages and modify their axes. |
120 | | -
|
121 | | - Uses zero-copy operations to efficiently modify axis names and dimensions |
122 | | - while preserving data integrity. |
123 | | -
|
124 | | - :param message: Input AxisArray to modify. |
125 | | - :type message: AxisArray |
126 | | - :return: Async generator yielding AxisArray with modified axes. |
127 | | - :rtype: collections.abc.AsyncGenerator |
128 | | - """ |
129 | | - try: |
130 | | - ret = self.STATE.gen.send(message) |
131 | | - if ret is not None: |
132 | | - yield self.OUTPUT_SIGNAL, ret |
133 | | - except (StopIteration, GeneratorExit): |
134 | | - ez.logger.debug(f"ModifyAxis closed in {self.address}") |
135 | | - except Exception: |
136 | | - ez.logger.info(traceback.format_exc()) |
| 120 | + ret = self._transformer(message) |
| 121 | + if ret is not None: |
| 122 | + yield self.OUTPUT_SIGNAL, ret |
0 commit comments