Skip to content

Commit c93ca9e

Browse files
committed
Update deprecated API usage (@consumer, GenAxisArray)
1 parent 9f02f2a commit c93ca9e

2 files changed

Lines changed: 110 additions & 118 deletions

File tree

src/ezmsg/learn/process/sgd.py

Lines changed: 97 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
import ezmsg.core as ez
44
import numpy as np
5-
from ezmsg.baseproc import GenAxisArray
6-
from ezmsg.sigproc.sampler import SampleMessage
7-
from ezmsg.util.generator import consumer
5+
from ezmsg.baseproc import (
6+
BaseAdaptiveTransformer,
7+
BaseAdaptiveTransformerUnit,
8+
SampleMessage,
9+
processor_state,
10+
)
811
from ezmsg.util.messages.axisarray import AxisArray
912
from ezmsg.util.messages.util import replace
1013
from sklearn.exceptions import NotFittedError
@@ -13,103 +16,6 @@
1316
from ..util import ClassifierMessage
1417

1518

16-
@consumer
def sgd_decoder(
    alpha: float = 1.5e-5,
    eta0: float = 1e-7,  # Lower than what you'd use for offline training.
    loss: str = "squared_hinge",
    label_weights: dict[str, float] | None = None,
    settings_path: str | None = None,
) -> typing.Generator[AxisArray | SampleMessage, ClassifierMessage | None, None]:
    """
    Passive Aggressive Classifier.

    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshet, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

    Args:
        alpha: Maximum step size (regularization)
        eta0: The initial learning rate for the 'adaptive' schedules.
        loss: The loss function to be used:
            hinge: equivalent to PA-I in the reference paper.
            squared_hinge: equivalent to PA-II in the reference paper.
        label_weights: An optional dictionary of label names and their relative weight.
            e.g., {'Go': 31.0, 'Stop': 0.5}
            If this is None then settings_path must be provided and the pre-trained model
        settings_path: Path to the stored sklearn model pkl file.

    Returns:
        Generator that accepts `SampleMessage` for incremental training (`partial_fit`) and yields None,
        or `AxisArray` for inference (`predict`) and yields a `ClassifierMessage`.
    """
    # Build the model: restore a pickled pre-trained model, or construct a fresh one.
    if settings_path is not None:
        import pickle

        # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
        with open(settings_path, "rb") as f:
            model = pickle.load(f)
        if label_weights is not None:
            model.class_weight = label_weights
        # Overwrite eta0, probably with a value lower than what was used online.
        model.eta0 = eta0
    else:
        model = SGDClassifier(
            loss=loss,
            alpha=alpha,
            penalty="elasticnet",
            learning_rate="adaptive",
            eta0=eta0,
            early_stopping=False,
            class_weight=label_weights,
        )

    # True until the first partial_fit, which must declare the full class list.
    needs_class_init = True
    # Pre-initialized output; this first value is consumed by @consumer's priming next().
    msg_out: ClassifierMessage | None = ClassifierMessage(data=np.array([]), dims=[""])
    # TODO: template_out

    while True:
        msg_in: AxisArray | SampleMessage = yield msg_out
        msg_out = None

        if type(msg_in) is SampleMessage:
            # SampleMessage used for training; skip samples containing NaNs.
            if np.any(np.isnan(msg_in.sample.data)):
                continue
            X = msg_in.sample.data.reshape(1, -1)
            y = [msg_in.trigger.value]
            if needs_class_init:
                model.partial_fit(X, y, classes=list(label_weights.keys()))
                needs_class_init = False
            else:
                model.partial_fit(X, y)
        elif msg_in.data.size and not np.any(np.isnan(msg_in.data)):
            # AxisArray used for inference.
            try:
                flat = msg_in.data.reshape((msg_in.data.shape[0], -1))
                probs = model._predict_proba_lr(flat)
            except NotFittedError:
                probs = None
            if probs is not None:
                zero_dim = msg_in.dims[0]
                out_axes = {}
                if zero_dim in msg_in.axes:
                    out_axes[zero_dim] = replace(
                        msg_in.axes[zero_dim],
                        offset=msg_in.axes[zero_dim].offset,
                    )
                msg_out = ClassifierMessage(
                    data=probs,
                    dims=msg_in.dims[:1] + ["labels"],
                    axes=out_axes,
                    labels=list(model.class_weight.keys()),
                    key=msg_in.key,
                )
111-
112-
11319
class SGDDecoderSettings(ez.Settings):
11420
alpha: float = 1e-5
11521
eta0: float = 3e-4
@@ -118,14 +24,96 @@ class SGDDecoderSettings(ez.Settings):
11824
settings_path: str | None = None
11925

12026

121-
class SGDDecoder(GenAxisArray):
122-
SETTINGS = SGDDecoderSettings
123-
INPUT_SAMPLE = ez.InputStream(SampleMessage)
27+
@processor_state
class SGDDecoderState:
    """Mutable runtime state for the SGD decoder transformer."""

    # The underlying classifier (a fresh SGDClassifier or a model restored
    # from a pickle); None until first initialized.
    model: typing.Any = None
    # True until the first partial_fit call, which must pass `classes=`.
    b_first_train: bool = True
12431

125-
# Method to be implemented by subclasses to construct the specific generator
126-
def construct_generator(self):
127-
self.STATE.gen = sgd_decoder(**self.SETTINGS.__dict__)
12832

129-
@ez.subscriber(INPUT_SAMPLE)
130-
async def on_sample(self, msg: SampleMessage) -> None:
131-
_ = self.STATE.gen.send(msg)
33+
class SGDDecoderTransformer(
    BaseAdaptiveTransformer[SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderState]
):
    """
    SGD-based online classifier.

    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshet, S. Shalev-Shwartz, Y. Singer - JMLR (2006)
    """

    def _refreshed_model(self):
        """Build a fresh classifier, or restore a pickled pre-trained one.

        Returns:
            An estimator ready for ``partial_fit`` / inference: either the model
            unpickled from ``settings.settings_path`` (with ``class_weight`` and
            ``eta0`` overridden from settings) or a new ``SGDClassifier``.
        """
        if self.settings.settings_path is not None:
            import pickle

            # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
            with open(self.settings.settings_path, "rb") as f:
                model = pickle.load(f)
            if self.settings.label_weights is not None:
                model.class_weight = self.settings.label_weights
            # Overwrite eta0, probably with a value lower than what was used offline.
            model.eta0 = self.settings.eta0
        else:
            model = SGDClassifier(
                loss=self.settings.loss,
                alpha=self.settings.alpha,
                penalty="elasticnet",
                learning_rate="adaptive",
                eta0=self.settings.eta0,
                early_stopping=False,
                class_weight=self.settings.label_weights,
            )
        return model

    def _reset_state(self, message: AxisArray) -> None:
        """Re-initialize the model and the first-train flag."""
        self._state.model = self._refreshed_model()
        self._state.b_first_train = True

    def _process(self, message: AxisArray) -> ClassifierMessage | None:
        """Run inference on `message` and return class probabilities.

        Returns None when the model is uninitialized or not yet fitted, when the
        input is empty, or when it contains NaNs.
        """
        if self._state.model is None or not message.data.size:
            return None
        if np.any(np.isnan(message.data)):
            return None
        try:
            # Flatten all trailing dims; dims[0] is assumed to be the sample
            # axis — TODO confirm against upstream windowing.
            X = message.data.reshape((message.data.shape[0], -1))
            result = self._state.model._predict_proba_lr(X)
        except NotFittedError:
            return None
        out_axes = {}
        if message.dims[0] in message.axes:
            out_axes[message.dims[0]] = replace(
                message.axes[message.dims[0]],
                offset=message.axes[message.dims[0]].offset,
            )
        # Labels come from model.classes_: sklearn sorts classes internally, so
        # the columns of _predict_proba_lr follow classes_ order, which may
        # differ from the insertion order of settings.label_weights. Using
        # class_weight.keys() here could misalign labels with probability
        # columns and crashed when class_weight was None (pickled model).
        return ClassifierMessage(
            data=result,
            dims=message.dims[:1] + ["labels"],
            axes=out_axes,
            labels=list(self._state.model.classes_),
            key=message.key,
        )

    def partial_fit(self, message: SampleMessage) -> None:
        """Incrementally train the model on one labeled sample.

        Samples containing NaNs are silently skipped.
        """
        if self._state.model is None:
            # Initialize model on first training sample (no data message seen yet).
            self._state.model = self._refreshed_model()
            self._state.b_first_train = True

        if np.any(np.isnan(message.sample.data)):
            return
        train_sample = message.sample.data.reshape(1, -1)
        y = [message.trigger.value]
        if self._state.b_first_train:
            if self.settings.label_weights is not None:
                # sklearn requires the full class list on the first partial_fit
                # of an unfitted estimator.
                self._state.model.partial_fit(
                    train_sample, y, classes=list(self.settings.label_weights.keys())
                )
            else:
                # label_weights is None: the model is presumably pre-trained
                # (loaded from settings_path), so `classes` must not be passed.
                # Previously this path crashed on `None.keys()`.
                self._state.model.partial_fit(train_sample, y)
            self._state.b_first_train = False
        else:
            self._state.model.partial_fit(train_sample, y)
109+
110+
111+
class SGDDecoder(
    BaseAdaptiveTransformerUnit[
        SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderTransformer
    ]
):
    """ezmsg unit exposing `SGDDecoderTransformer` as a pipeline component."""

    SETTINGS = SGDDecoderSettings

tests/unit/test_sgd.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from ezmsg.sigproc.sampler import SampleMessage, SampleTriggerMessage
33
from ezmsg.util.messages.axisarray import AxisArray
44

5-
from ezmsg.learn.process.sgd import sgd_decoder
5+
from ezmsg.learn.process.sgd import SGDDecoderSettings, SGDDecoderTransformer
66

77

88
def test_sgd():
@@ -46,14 +46,16 @@ def test_sgd():
4646
"""
4747
label_weights = {k: 1.0 for k in time_idx.keys()}
4848
# Sending an axis array before it has seen any training samples should yield None
49-
gen = sgd_decoder(alpha=1e-3, loss="squared_hinge", label_weights=label_weights)
50-
assert gen.send(windows[0]) is None
49+
decoder = SGDDecoderTransformer(
50+
settings=SGDDecoderSettings(alpha=1e-3, loss="squared_hinge", label_weights=label_weights)
51+
)
52+
assert decoder(windows[0]) is None
5153

5254
# Now let's try training on all samples
5355
for sample in samples:
54-
gen.send(sample)
56+
decoder(sample)
5557
# Then doing inference on all multi-wins
56-
probas = [gen.send(win) for win in windows]
58+
probas = [decoder(win) for win in windows]
5759

5860
# With this easy-to-classify data, accuracy should be 100%
5961
# when we fit all training before predicting any test.
@@ -64,12 +66,14 @@ def test_sgd():
6466
assert np.array_equal(class_ids, expected_ids)
6567

6668
# Try again (new model) but alternate 1 train, 2 test.
67-
gen = sgd_decoder(alpha=1e-3, loss="squared_hinge", label_weights=label_weights)
69+
decoder = SGDDecoderTransformer(
70+
settings=SGDDecoderSettings(alpha=1e-3, loss="squared_hinge", label_weights=label_weights)
71+
)
6872
probas = []
6973
for samp_ix, samp in enumerate(samples):
70-
gen.send(samp)
71-
probas.append(gen.send(windows[samp_ix * 2]))
72-
probas.append(gen.send(windows[samp_ix * 2 + 1]))
74+
decoder(samp)
75+
probas.append(decoder(windows[samp_ix * 2]))
76+
probas.append(decoder(windows[samp_ix * 2 + 1]))
7377
class_ids = []
7478
for cm in probas:
7579
class_ids.extend(np.argmax(cm.data, axis=1).tolist())

0 commit comments

Comments
 (0)