|
11 | 11 | BaseAdaptiveTransformerUnit, |
12 | 12 | processor_state, |
13 | 13 | ) |
14 | | -from ezmsg.sigproc.sampler import SampleMessage |
15 | 14 | from ezmsg.util.messages.axisarray import AxisArray, replace |
16 | 15 |
|
17 | 16 | from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor |
@@ -78,30 +77,30 @@ def _reset_state(self, message: AxisArray) -> None: |
78 | 77 | # .template is updated in partial_fit |
79 | 78 | pass |
80 | 79 |
|
81 | | - def partial_fit(self, message: SampleMessage) -> None: |
82 | | - if np.any(np.isnan(message.sample.data)): |
| 80 | + def partial_fit(self, message: AxisArray) -> None: |
| 81 | + if np.any(np.isnan(message.data)): |
83 | 82 | return |
84 | 83 |
|
85 | 84 | if self.settings.model_type in [ |
86 | 85 | AdaptiveLinearRegressor.LINEAR, |
87 | 86 | AdaptiveLinearRegressor.LOGISTIC, |
88 | 87 | ]: |
89 | | - x = pd.DataFrame.from_dict({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)}) |
| 88 | + x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)}) |
90 | 89 | y = pd.Series( |
91 | | - data=message.trigger.value.data[:, 0], |
92 | | - name=message.trigger.value.axes["ch"].data[0], |
| 90 | + data=message.attrs["trigger"].value.data[:, 0], |
| 91 | + name=message.attrs["trigger"].value.axes["ch"].data[0], |
93 | 92 | ) |
94 | 93 | self.state.model.learn_many(x, y) |
95 | 94 | else: |
96 | | - X = message.sample.data |
97 | | - if message.sample.get_axis_idx("time") != 0: |
98 | | - X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0) |
99 | | - self.state.model.partial_fit(X, message.trigger.value.data) |
| 95 | + X = message.data |
| 96 | + if message.get_axis_idx("time") != 0: |
| 97 | + X = np.moveaxis(X, message.get_axis_idx("time"), 0) |
| 98 | + self.state.model.partial_fit(X, message.attrs["trigger"].value.data) |
100 | 99 |
|
101 | 100 | self.state.template = replace( |
102 | | - message.trigger.value, |
103 | | - data=np.empty_like(message.trigger.value.data), |
104 | | - key=message.trigger.value.key + "_pred", |
| 101 | + message.attrs["trigger"].value, |
| 102 | + data=np.empty_like(message.attrs["trigger"].value.data), |
| 103 | + key=message.attrs["trigger"].value.key + "_pred", |
105 | 104 | ) |
106 | 105 |
|
107 | 106 | def _process(self, message: AxisArray) -> AxisArray | None: |
|
0 commit comments