Skip to content

Commit f7224c4

Browse files
authored
Merge pull request #8 from ezmsg-org/partial_fit
Directly use `partial_fit` instead of deprecated routing through `__call__`
2 parents abab1c7 + 4b45549 commit f7224c4

9 files changed

Lines changed: 19 additions & 24 deletions

File tree

docs/source/guides/classification.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ For models that support ``partial_fit``, you can update them during streaming:
125125
.. code-block:: python
126126
127127
from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
128-
from ezmsg.sigproc.sampler import SampleMessage
128+
from ezmsg.baseproc import SampleTriggerMessage
129+
from ezmsg.util.messages.util import replace
129130
130131
# Create processor with online learning support
131132
processor = SklearnModelProcessor(
@@ -137,9 +138,9 @@ For models that support ``partial_fit``, you can update them during streaming:
137138
)
138139
139140
# Training with labeled samples
140-
sample_msg = SampleMessage(
141-
sample=feature_array, # AxisArray with features
142-
trigger=label_value, # The class label
141+
sample_msg = replace(
142+
feature_array, # AxisArray with features
143+
attrs={"trigger": SampleTriggerMessage(value=label_value)}
143144
)
144145
processor.partial_fit(sample_msg)
145146

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ license = "MIT"
99
requires-python = ">=3.10.15"
1010
dynamic = ["version"]
1111
dependencies = [
12-
"ezmsg-baseproc>=1.3.0",
12+
"ezmsg-baseproc>=1.4.0",
1313
"ezmsg-sigproc>=2.15.0",
1414
"river>=0.22.0",
1515
"scikit-learn>=1.6.0",

src/ezmsg/learn/process/ssr.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,11 +265,6 @@ def fit(self, X: np.ndarray) -> None:
265265
self._state.weights = self._solve_weights(self._state.cxx)
266266
self._on_weights_updated()
267267

268-
def fit_transform(self, message: AxisArray) -> AxisArray:
269-
"""Convenience: ``partial_fit`` then ``_process``."""
270-
self.partial_fit(message)
271-
return self._process(message)
272-
273268
# -- abstract hooks for subclasses ---------------------------------------
274269

275270
@abstractmethod

tests/unit/test_adaptive_linear_regressor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_adaptive_linear_regressor(model_type: str):
4545
samp = replace(sig_axarr, attrs={"trigger": samp_trig})
4646

4747
proc = AdaptiveLinearRegressorTransformer(model_type=model_type)
48-
_ = proc.send(samp)
49-
preds = proc.send(replace(sig_axarr, data=X + np.random.randn(*X.shape)))
48+
proc.partial_fit(samp)
49+
preds = proc(replace(sig_axarr, data=X + np.random.randn(*X.shape)))
5050
assert isinstance(preds, AxisArray)
5151
assert preds.data.shape == (n_times, 1)

tests/unit/test_linear_regressor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def test_linear_regressor(model_type: str):
4242
)
4343
samp = replace(sig_axarr, attrs={"trigger": samp_trig})
4444

45-
gen = LinearRegressorTransformer(model_type=model_type)
46-
_ = gen.send(samp)
47-
preds = gen.send(replace(sig_axarr, data=X + 0.1 * np.random.randn(*X.shape)))
45+
proc = LinearRegressorTransformer(model_type=model_type)
46+
proc.partial_fit(samp)
47+
preds = proc(replace(sig_axarr, data=X + 0.1 * np.random.randn(*X.shape)))
4848
rss = ((samp_trig.value.data - preds.data) ** 2).sum()
4949
tss = ((samp_trig.value.data - samp_trig.value.data.mean()) ** 2).sum()
5050
rsq = 1 - rss / tss

tests/unit/test_mlp_old.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,7 @@ def xy_gen(set: int = 0):
186186
# Train: This is unrealistic in that we would normally do inference on many axisarray messages throughout
187187
# the trial, and only do training infrequently at the end of a trial if we can infer the labels.
188188
# But I'm too lazy to split the data into many small axarrs and one large SampleMessage per trial.
189-
# Note: We don't have to call `partial_fit` because `__call__` inspects the message type and calls it for us.
190-
proc(sample_msg)
189+
proc.partial_fit(sample_msg)
191190

192191
def eval_test(processor, set: int = 1):
193192
# Run the test inference

tests/unit/test_rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def test_rnn_partial_fit(simple_message):
143143
attrs={**simple_message.attrs, "trigger": SampleTriggerMessage(timestamp=0.0, value=target_value)},
144144
)
145145

146-
proc(sample_message)
146+
proc.partial_fit(sample_message)
147147

148148
assert not proc.state.model.training
149149
updated_weights = [p.detach() for p in proc.state.model.parameters()]

tests/unit/test_sgd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_sgd():
5454

5555
# Now let's try training on all samples
5656
for sample in samples:
57-
decoder(sample) # Should invoke partial_fit
57+
decoder.partial_fit(sample)
5858
# Then doing inference on all multi-wins
5959
probas = [decoder(win) for win in windows]
6060

@@ -72,7 +72,7 @@ def test_sgd():
7272
)
7373
probas = []
7474
for samp_ix, samp in enumerate(samples):
75-
decoder(samp) # Should invoke partial_fit
75+
decoder.partial_fit(samp)
7676
probas.append(decoder(windows[samp_ix * 2]))
7777
probas.append(decoder(windows[samp_ix * 2 + 1]))
7878
class_ids = []

tests/unit/test_ssr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,15 @@ def test_non_last_axis(self):
230230
assert out.data.shape == X.shape
231231

232232

233-
class TestFitTransform:
234-
def test_fit_transform(self):
235-
"""fit_transform matches separate partial_fit + process."""
233+
class TestPartialFitTransform:
234+
def test_partial_fit_transform(self):
235+
"""partial_fit_transform matches separate partial_fit + process."""
236236
rng = np.random.default_rng(10)
237237
X = _random_data(rng=rng)
238238
msg = _make_axisarray(X)
239239

240240
proc1 = LRRTransformer(LRRSettings())
241-
out1 = proc1.fit_transform(msg)
241+
out1 = proc1.partial_fit_transform(msg)
242242

243243
proc2 = LRRTransformer(LRRSettings())
244244
proc2.partial_fit(msg)

0 commit comments

Comments (0)