22
33import ezmsg .core as ez
44import numpy as np
5- from ezmsg .baseproc import GenAxisArray
6- from ezmsg .sigproc .sampler import SampleMessage
7- from ezmsg .util .generator import consumer
5+ from ezmsg .baseproc import (
6+ BaseAdaptiveTransformer ,
7+ BaseAdaptiveTransformerUnit ,
8+ SampleMessage ,
9+ processor_state ,
10+ )
811from ezmsg .util .messages .axisarray import AxisArray
912from ezmsg .util .messages .util import replace
1013from sklearn .exceptions import NotFittedError
1316from ..util import ClassifierMessage
1417
1518
@consumer
def sgd_decoder(
    alpha: float = 1.5e-5,
    eta0: float = 1e-7,  # Lower than what you'd use for offline training.
    loss: str = "squared_hinge",
    label_weights: dict[str, float] | None = None,
    settings_path: str | None = None,
) -> typing.Generator[ClassifierMessage | None, AxisArray | SampleMessage, None]:
    """
    Passive Aggressive Classifier
    Online Passive-Aggressive Algorithms <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

    Args:
        alpha: Maximum step size (regularization)
        eta0: The initial learning rate for the 'adaptive' schedules.
        loss: The loss function to be used:
            hinge: equivalent to PA-I in the reference paper.
            squared_hinge: equivalent to PA-II in the reference paper.
        label_weights: An optional dictionary of label names and their relative weight.
            e.g., {'Go': 31.0, 'Stop': 0.5}
            If this is None then settings_path must be provided and the pre-trained model
        settings_path: Path to the stored sklearn model pkl file.

    Returns:
        Generator that accepts `SampleMessage` for incremental training (`partial_fit`) and yields None,
        or `AxisArray` for inference (`predict`) and yields a `ClassifierMessage`.
    """
    # pre-init inputs and outputs.
    # The priming `send(None)` returns this empty placeholder message.
    msg_out = ClassifierMessage(data=np.array([]), dims=[""])

    # State variables:
    if settings_path is not None:
        # Resume from a pre-trained pickled model.
        # NOTE(review): pickle.load is unsafe for untrusted files -- assumed
        # to be a locally produced, trusted artifact.
        import pickle

        with open(settings_path, "rb") as f:
            model = pickle.load(f)
        if label_weights is not None:
            model.class_weight = label_weights
        # Overwrite eta0, probably with a value lower than what was used online.
        model.eta0 = eta0
    else:
        # Fresh model; label_weights may be None here, but then the first
        # partial_fit below will fail (no class list) -- see Args docstring.
        model = SGDClassifier(
            loss=loss,
            alpha=alpha,
            penalty="elasticnet",
            learning_rate="adaptive",
            eta0=eta0,
            early_stopping=False,
            class_weight=label_weights,
        )

    # The very first partial_fit must receive the full set of classes.
    b_first_train = True
    # TODO: template_out

    while True:
        msg_in: AxisArray | SampleMessage = yield msg_out

        # Default to yielding None (training messages produce no output).
        msg_out = None
        if type(msg_in) is SampleMessage:
            # SampleMessage used for training.
            if not np.any(np.isnan(msg_in.sample.data)):
                # Flatten the sample into a single feature row: (1, n_features).
                train_sample = msg_in.sample.data.reshape(1, -1)
                if b_first_train:
                    # First call must declare all classes up front.
                    model.partial_fit(
                        train_sample,
                        [msg_in.trigger.value],
                        classes=list(label_weights.keys()),
                    )
                    b_first_train = False
                else:
                    model.partial_fit(train_sample, [msg_in.trigger.value])
        elif msg_in.data.size:
            # AxisArray used for inference
            if not np.any(np.isnan(msg_in.data)):
                try:
                    # Axis 0 is the batch/time axis; all remaining dims are features.
                    X = msg_in.data.reshape((msg_in.data.shape[0], -1))
                    # sklearn-private helper: maps decision_function scores
                    # through a logistic to pseudo-probabilities.
                    result = model._predict_proba_lr(X)
                except NotFittedError:
                    # Model has not seen any training data yet; emit nothing.
                    result = None
                if result is not None:
                    out_axes = {}
                    if msg_in.dims[0] in msg_in.axes:
                        # Copy the leading axis (replace() with an identical
                        # offset produces a fresh axis object).
                        out_axes[msg_in.dims[0]] = replace(
                            msg_in.axes[msg_in.dims[0]],
                            offset=msg_in.axes[msg_in.dims[0]].offset,
                        )
                    msg_out = ClassifierMessage(
                        data=result,
                        dims=msg_in.dims[:1] + ["labels"],
                        axes=out_axes,
                        labels=list(model.class_weight.keys()),
                        key=msg_in.key,
                    )
11319class SGDDecoderSettings (ez .Settings ):
11420 alpha : float = 1e-5
11521 eta0 : float = 3e-4
@@ -118,14 +24,96 @@ class SGDDecoderSettings(ez.Settings):
11824 settings_path : str | None = None
11925
12026
121- class SGDDecoder (GenAxisArray ):
122- SETTINGS = SGDDecoderSettings
123- INPUT_SAMPLE = ez .InputStream (SampleMessage )
@processor_state
class SGDDecoderState:
    # Underlying estimator: an SGDClassifier or an unpickled pre-trained
    # model. Created lazily on first reset/training; None until then.
    model: typing.Any = None
    # True until the first partial_fit call, which must pass `classes=`.
    b_first_train: bool = True
12431
125- # Method to be implemented by subclasses to construct the specific generator
126- def construct_generator (self ):
127- self .STATE .gen = sgd_decoder (** self .SETTINGS .__dict__ )
12832
129- @ez .subscriber (INPUT_SAMPLE )
130- async def on_sample (self , msg : SampleMessage ) -> None :
131- _ = self .STATE .gen .send (msg )
class SGDDecoderTransformer(
    BaseAdaptiveTransformer[
        SGDDecoderSettings, AxisArray, ClassifierMessage, SGDDecoderState
    ]
):
    """
    SGD-based online classifier.

    Online Passive-Aggressive Algorithms
    <http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf>
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006)

    Inference: `_process` consumes an `AxisArray` of features and returns a
    `ClassifierMessage` of per-class probabilities (or None when it cannot).
    Training: `partial_fit` consumes a `SampleMessage` and incrementally
    updates the model.
    """

    def _refreshed_model(self):
        """Construct the estimator from settings.

        If ``settings_path`` is set, unpickle a pre-trained model and override
        its ``class_weight`` (when ``label_weights`` is given) and ``eta0``;
        otherwise build a fresh ``SGDClassifier``.
        """
        if self.settings.settings_path is not None:
            import pickle

            # NOTE(review): pickle.load on a user-supplied path is unsafe for
            # untrusted input -- assumed to be a local, trusted artifact.
            with open(self.settings.settings_path, "rb") as f:
                model = pickle.load(f)
            if self.settings.label_weights is not None:
                model.class_weight = self.settings.label_weights
            # Override eta0, typically with a value lower than offline training.
            model.eta0 = self.settings.eta0
        else:
            model = SGDClassifier(
                loss=self.settings.loss,
                alpha=self.settings.alpha,
                penalty="elasticnet",
                learning_rate="adaptive",
                eta0=self.settings.eta0,
                early_stopping=False,
                class_weight=self.settings.label_weights,
            )
        return model

    def _reset_state(self, message: AxisArray) -> None:
        # Rebuild the model and require `classes=` on the next partial_fit.
        self._state.model = self._refreshed_model()
        self._state.b_first_train = True

    def _process(self, message: AxisArray) -> ClassifierMessage | None:
        """Run inference; return class pseudo-probabilities or None.

        Returns None when the model is uninitialized, the input is empty,
        the input contains NaNs, or the model has not been fitted yet.
        """
        if self._state.model is None or not message.data.size:
            return None
        if np.any(np.isnan(message.data)):
            return None
        try:
            # Axis 0 is the batch/time axis; all remaining dims are features.
            X = message.data.reshape((message.data.shape[0], -1))
            # sklearn-private helper: logistic mapping of decision_function
            # scores to pseudo-probabilities.
            result = self._state.model._predict_proba_lr(X)
        except NotFittedError:
            return None
        out_axes = {}
        if message.dims[0] in message.axes:
            # Copy the leading axis (replace() with an identical offset
            # produces a fresh axis object).
            out_axes[message.dims[0]] = replace(
                message.axes[message.dims[0]],
                offset=message.axes[message.dims[0]].offset,
            )
        return ClassifierMessage(
            data=result,
            dims=message.dims[:1] + ["labels"],
            axes=out_axes,
            labels=list(self._state.model.class_weight.keys()),
            key=message.key,
        )

    def partial_fit(self, message: SampleMessage) -> None:
        """Incrementally train on one labeled sample.

        Samples containing NaNs are silently skipped. The first fit declares
        the class list from ``settings.label_weights``.
        """
        if self._state.model is None:
            # Initialize model on first training sample
            self._state.model = self._refreshed_model()
            self._state.b_first_train = True

        if np.any(np.isnan(message.sample.data)):
            return
        # Flatten the sample into a single feature row: (1, n_features).
        train_sample = message.sample.data.reshape(1, -1)
        if self._state.b_first_train:
            # NOTE(review): assumes settings.label_weights is not None here;
            # a fresh model with label_weights=None cannot declare classes --
            # per the legacy docstring, settings_path must be provided then.
            self._state.model.partial_fit(
                train_sample,
                [message.trigger.value],
                classes=list(self.settings.label_weights.keys()),
            )
            self._state.b_first_train = False
        else:
            self._state.model.partial_fit(train_sample, [message.trigger.value])
110+
class SGDDecoder(
    BaseAdaptiveTransformerUnit[
        SGDDecoderSettings,
        AxisArray,
        ClassifierMessage,
        SGDDecoderTransformer,
    ]
):
    """ezmsg unit binding :class:`SGDDecoderTransformer` into a pipeline:
    subscribes to `AxisArray` feature messages for inference and to
    `SampleMessage`s for incremental (adaptive) training."""

    SETTINGS = SGDDecoderSettings
0 commit comments