Skip to content

Commit abab1c7

Browse files
authored
Merge pull request #7 from ezmsg-org/deprecate_samplemessage
Replace deprecated SampleMessage with AxisArray with trigger in attrs
2 parents 44ad802 + fb559f4 commit abab1c7

20 files changed

Lines changed: 144 additions & 162 deletions

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ license = "MIT"
99
requires-python = ">=3.10.15"
1010
dynamic = ["version"]
1111
dependencies = [
12-
"ezmsg-baseproc>=1.0.2",
13-
"ezmsg-sigproc>=2.14.0",
12+
"ezmsg-baseproc>=1.3.0",
13+
"ezmsg-sigproc>=2.15.0",
1414
"river>=0.22.0",
1515
"scikit-learn>=1.6.0",
1616
"torch>=2.6.0",

src/ezmsg/learn/process/adaptive_linear_regressor.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
BaseAdaptiveTransformerUnit,
1212
processor_state,
1313
)
14-
from ezmsg.sigproc.sampler import SampleMessage
1514
from ezmsg.util.messages.axisarray import AxisArray, replace
1615

1716
from ..util import AdaptiveLinearRegressor, RegressorType, get_regressor
@@ -78,30 +77,30 @@ def _reset_state(self, message: AxisArray) -> None:
7877
# .template is updated in partial_fit
7978
pass
8079

81-
def partial_fit(self, message: SampleMessage) -> None:
82-
if np.any(np.isnan(message.sample.data)):
80+
def partial_fit(self, message: AxisArray) -> None:
81+
if np.any(np.isnan(message.data)):
8382
return
8483

8584
if self.settings.model_type in [
8685
AdaptiveLinearRegressor.LINEAR,
8786
AdaptiveLinearRegressor.LOGISTIC,
8887
]:
89-
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
88+
x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
9089
y = pd.Series(
91-
data=message.trigger.value.data[:, 0],
92-
name=message.trigger.value.axes["ch"].data[0],
90+
data=message.attrs["trigger"].value.data[:, 0],
91+
name=message.attrs["trigger"].value.axes["ch"].data[0],
9392
)
9493
self.state.model.learn_many(x, y)
9594
else:
96-
X = message.sample.data
97-
if message.sample.get_axis_idx("time") != 0:
98-
X = np.moveaxis(X, message.sample.get_axis_idx("time"), 0)
99-
self.state.model.partial_fit(X, message.trigger.value.data)
95+
X = message.data
96+
if message.get_axis_idx("time") != 0:
97+
X = np.moveaxis(X, message.get_axis_idx("time"), 0)
98+
self.state.model.partial_fit(X, message.attrs["trigger"].value.data)
10099

101100
self.state.template = replace(
102-
message.trigger.value,
103-
data=np.empty_like(message.trigger.value.data),
104-
key=message.trigger.value.key + "_pred",
101+
message.attrs["trigger"].value,
102+
data=np.empty_like(message.attrs["trigger"].value.data),
103+
key=message.attrs["trigger"].value.key + "_pred",
105104
)
106105

107106
def _process(self, message: AxisArray) -> AxisArray | None:

src/ezmsg/learn/process/linear_regressor.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
BaseAdaptiveTransformerUnit,
88
processor_state,
99
)
10-
from ezmsg.sigproc.sampler import SampleMessage
1110
from ezmsg.util.messages.axisarray import AxisArray, replace
1211
from sklearn.linear_model._base import LinearModel
1312

@@ -53,18 +52,18 @@ def _reset_state(self, message: AxisArray) -> None:
5352
# .model and .template are initialized in __init__
5453
pass
5554

56-
def partial_fit(self, message: SampleMessage) -> None:
57-
if np.any(np.isnan(message.sample.data)):
55+
def partial_fit(self, message: AxisArray) -> None:
56+
if np.any(np.isnan(message.data)):
5857
return
5958

60-
X = message.sample.data
61-
y = message.trigger.value.data
59+
X = message.data
60+
y = message.attrs["trigger"].value.data
6261
# TODO: Resample should provide identical durations.
6362
self.state.model = self.state.model.fit(X[: y.shape[0]], y[: X.shape[0]])
6463
self.state.template = replace(
65-
message.trigger.value,
64+
message.attrs["trigger"].value,
6665
data=np.array([[]]),
67-
key=message.trigger.value.key + "_pred",
66+
key=message.attrs["trigger"].value.key + "_pred",
6867
)
6968

7069
def _process(self, message: AxisArray) -> AxisArray:

src/ezmsg/learn/process/mlp_old.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
BaseAdaptiveTransformerUnit,
1010
processor_state,
1111
)
12-
from ezmsg.sigproc.sampler import SampleMessage
1312
from ezmsg.util.messages.axisarray import AxisArray
1413
from ezmsg.util.messages.util import replace
1514

@@ -134,14 +133,14 @@ def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
134133
dtype = torch.float32 if self.settings.single_precision else torch.float64
135134
return torch.tensor(data, dtype=dtype, device=self._state.device)
136135

137-
def partial_fit(self, message: SampleMessage) -> None:
136+
def partial_fit(self, message: AxisArray) -> None:
138137
self._state.model.train()
139138

140139
# TODO: loss_fn should be determined by setting
141140
loss_fn = torch.nn.functional.mse_loss
142141

143-
X = self._to_tensor(message.sample.data)
144-
y_targ = self._to_tensor(message.trigger.value)
142+
X = self._to_tensor(message.data)
143+
y_targ = self._to_tensor(message.attrs["trigger"].value)
145144

146145
with torch.set_grad_enabled(True):
147146
self._state.model.train()

src/ezmsg/learn/process/refit_kalman.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
BaseAdaptiveTransformerUnit,
99
processor_state,
1010
)
11-
from ezmsg.sigproc.sampler import SampleMessage
1211
from ezmsg.util.messages.axisarray import AxisArray
1312
from ezmsg.util.messages.util import replace
1413

@@ -284,22 +283,22 @@ def _process(self, message: AxisArray) -> AxisArray:
284283
key=f"{message.key}_filtered" if hasattr(message, "key") else "filtered",
285284
)
286285

287-
def partial_fit(self, message: SampleMessage) -> None:
286+
def partial_fit(self, message: AxisArray) -> None:
288287
"""
289288
Perform refitting using externally provided data.
290289
291-
Expects message.sample.data (neural input) and message.trigger.value as a dict with:
290+
Expects message.data (neural input) and message.attrs["trigger"].value as a dict with:
292291
- Y_state: (n_samples, n_states) array
293292
- intention_velocity_indices: Optional[int]
294293
- target_positions: Optional[np.ndarray]
295294
- cursor_positions: Optional[np.ndarray]
296295
- hold_flags: Optional[list[bool]]
297296
"""
298-
if not hasattr(message, "sample") or not hasattr(message, "trigger"):
297+
if "trigger" not in message.attrs:
299298
raise ValueError("Invalid message format for partial_fit.")
300299

301-
X = np.array(message.sample.data)
302-
values = message.trigger.value
300+
X = np.array(message.data)
301+
values = message.attrs["trigger"].value
303302

304303
if not isinstance(values, dict) or "Y_state" not in values:
305304
raise ValueError("partial_fit expects trigger.value to include at least 'Y_state'.")

src/ezmsg/learn/process/rnn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
77
from ezmsg.baseproc.util.profile import profile_subpub
8-
from ezmsg.sigproc.sampler import SampleMessage
98
from ezmsg.util.messages.axisarray import AxisArray
109
from ezmsg.util.messages.util import replace
1110

@@ -184,18 +183,18 @@ def _train_step(
184183
if self._state.scheduler is not None:
185184
self._state.scheduler.step()
186185

187-
def partial_fit(self, message: SampleMessage) -> None:
186+
def partial_fit(self, message: AxisArray) -> None:
188187
self._state.model.train()
189188

190-
X = self._to_tensor(message.sample.data)
189+
X = self._to_tensor(message.data)
191190

192191
# Add batch dimension if missing
193192
X, batched = self._ensure_batched(X)
194193

195194
batch_size = X.shape[0]
196-
preserve_state = self._maybe_reset_state(message.sample, batch_size)
195+
preserve_state = self._maybe_reset_state(message, batch_size)
197196

198-
y_targ = message.trigger.value
197+
y_targ = message.attrs["trigger"].value
199198
if not isinstance(y_targ, dict):
200199
y_targ = {"output": y_targ}
201200
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}

src/ezmsg/learn/process/sgd.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from ezmsg.baseproc import (
66
BaseAdaptiveTransformer,
77
BaseAdaptiveTransformerUnit,
8-
SampleMessage,
98
processor_state,
109
)
1110
from ezmsg.util.messages.axisarray import AxisArray
@@ -87,23 +86,23 @@ def _process(self, message: AxisArray) -> ClassifierMessage | None:
8786
key=message.key,
8887
)
8988

90-
def partial_fit(self, message: SampleMessage) -> None:
89+
def partial_fit(self, message: AxisArray) -> None:
9190
if self._hash != 0:
92-
self._reset_state(message.sample)
91+
self._reset_state(message)
9392
self._hash = 0
9493

95-
if np.any(np.isnan(message.sample.data)):
94+
if np.any(np.isnan(message.data)):
9695
return
97-
train_sample = message.sample.data.reshape(1, -1)
96+
train_sample = message.data.reshape(1, -1)
9897
if self._state.b_first_train:
9998
self._state.model.partial_fit(
10099
train_sample,
101-
[message.trigger.value],
100+
[message.attrs["trigger"].value],
102101
classes=list(self.settings.label_weights.keys()),
103102
)
104103
self._state.b_first_train = False
105104
else:
106-
self._state.model.partial_fit(train_sample, [message.trigger.value])
105+
self._state.model.partial_fit(train_sample, [message.attrs["trigger"].value])
107106

108107

109108
class SGDDecoder(

src/ezmsg/learn/process/sklearn.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
BaseAdaptiveTransformerUnit,
1111
processor_state,
1212
)
13-
from ezmsg.sigproc.sampler import SampleMessage
1413
from ezmsg.util.messages.axisarray import AxisArray
1514
from ezmsg.util.messages.util import replace
1615

@@ -116,25 +115,25 @@ def _reset_state(self, message: AxisArray) -> None:
116115
# No checkpoint, initialize from scratch
117116
self._init_model()
118117

119-
def partial_fit(self, message: SampleMessage) -> None:
120-
X = message.sample.data
121-
y = message.trigger.value
118+
def partial_fit(self, message: AxisArray) -> None:
119+
X = message.data
120+
y = message.attrs["trigger"].value
122121
if self._state.model is None:
123-
self._reset_state(message.sample)
122+
self._reset_state(message)
124123
if hasattr(self._state.model, "partial_fit"):
125124
kwargs = {}
126125
if self.settings.partial_fit_classes is not None:
127126
kwargs["classes"] = self.settings.partial_fit_classes
128127
self._state.model.partial_fit(X, y, **kwargs)
129128
elif hasattr(self._state.model, "learn_many"):
130-
df_X = pd.DataFrame({k: v for k, v in zip(message.sample.axes["ch"].data, message.sample.data.T)})
129+
df_X = pd.DataFrame({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
131130
name = (
132-
message.trigger.value.axes["ch"].data[0]
133-
if hasattr(message.trigger.value, "axes") and "ch" in message.trigger.value.axes
131+
message.attrs["trigger"].value.axes["ch"].data[0]
132+
if hasattr(message.attrs["trigger"].value, "axes") and "ch" in message.attrs["trigger"].value.axes
134133
else "target"
135134
)
136135
ser_y = pd.Series(
137-
data=np.asarray(message.trigger.value.data).flatten(),
136+
data=np.asarray(message.attrs["trigger"].value.data).flatten(),
138137
name=name,
139138
)
140139
self._state.model.learn_many(df_X, ser_y)

src/ezmsg/learn/process/torch.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
processor_state,
1313
)
1414
from ezmsg.baseproc.util.profile import profile_subpub
15-
from ezmsg.sigproc.sampler import SampleMessage
1615
from ezmsg.util.messages.axisarray import AxisArray
1716
from ezmsg.util.messages.util import replace
1817

@@ -294,13 +293,13 @@ def _reset_state(self, message: AxisArray) -> None:
294293
def _process(self, message: AxisArray) -> list[AxisArray]:
295294
return self._common_process(message)
296295

297-
def partial_fit(self, message: SampleMessage) -> None:
296+
def partial_fit(self, message: AxisArray) -> None:
298297
self._state.model.train()
299298

300-
X = self._to_tensor(message.sample.data)
299+
X = self._to_tensor(message.data)
301300
X, batched = self._ensure_batched(X)
302301

303-
y_targ = message.trigger.value
302+
y_targ = message.attrs["trigger"].value
304303
if not isinstance(y_targ, dict):
305304
y_targ = {"output": y_targ}
306305
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}

src/ezmsg/learn/process/transformer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55
from ezmsg.baseproc import BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit
66
from ezmsg.baseproc.util.profile import profile_subpub
7-
from ezmsg.sigproc.sampler import SampleMessage
87
from ezmsg.util.messages.axisarray import AxisArray
98
from ezmsg.util.messages.util import replace
109

@@ -125,13 +124,13 @@ def _process(self, message: AxisArray) -> list[AxisArray]:
125124
)
126125
]
127126

128-
def partial_fit(self, message: SampleMessage) -> None:
127+
def partial_fit(self, message: AxisArray) -> None:
129128
self._state.model.train()
130129

131-
X = self._to_tensor(message.sample.data)
130+
X = self._to_tensor(message.data)
132131
X, batched = self._ensure_batched(X)
133132

134-
y_targ = message.trigger.value
133+
y_targ = message.attrs["trigger"].value
135134
if not isinstance(y_targ, dict):
136135
y_targ = {"output": y_targ}
137136
y_targ = {key: self._to_tensor(value) for key, value in y_targ.items()}

0 commit comments

Comments
 (0)