Skip to content

Commit 85e8dbb

Browse files
committed
SGD - fix extra model refresh when sample comes before data.
1 parent c93ca9e commit 85e8dbb

3 files changed

Lines changed: 6 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ requires-python = ">=3.10.15"
1010
dynamic = ["version"]
1111
dependencies = [
1212
"ezmsg-baseproc>=1.0.2",
13-
"ezmsg-sigproc",
13+
"ezmsg-sigproc>=2.10.0",
1414
"river>=0.22.0",
1515
"scikit-learn>=1.6.0",
1616
"torch>=2.6.0",

src/ezmsg/learn/process/sgd.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def _refreshed_model(self):
6262

6363
def _reset_state(self, message: AxisArray) -> None:
6464
self._state.model = self._refreshed_model()
65-
self._state.b_first_train = True
6665

6766
def _process(self, message: AxisArray) -> ClassifierMessage | None:
6867
if self._state.model is None or not message.data.size:
@@ -89,10 +88,9 @@ def _process(self, message: AxisArray) -> ClassifierMessage | None:
8988
)
9089

9190
def partial_fit(self, message: SampleMessage) -> None:
92-
if self._state.model is None:
93-
# Initialize model on first training sample
94-
self._state.model = self._refreshed_model()
95-
self._state.b_first_train = True
91+
if self._hash != 0:
92+
self._reset_state(message.sample)
93+
self._hash = 0
9694

9795
if np.any(np.isnan(message.sample.data)):
9896
return

tests/unit/test_sgd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_sgd():
5353

5454
# Now let's try training on all samples
5555
for sample in samples:
56-
decoder(sample)
56+
decoder(sample) # Should invoke partial_fit
5757
# Then doing inference on all multi-wins
5858
probas = [decoder(win) for win in windows]
5959

@@ -71,7 +71,7 @@ def test_sgd():
7171
)
7272
probas = []
7373
for samp_ix, samp in enumerate(samples):
74-
decoder(samp)
74+
decoder(samp) # Should invoke partial_fit
7575
probas.append(decoder(windows[samp_ix * 2]))
7676
probas.append(decoder(windows[samp_ix * 2 + 1]))
7777
class_ids = []

0 commit comments

Comments
 (0)