Commit 9f02f2a

Add classification docs -- I don't remember where this comes from.
1 parent 4be23cf commit 9f02f2a

2 files changed

Lines changed: 268 additions & 3 deletions

Lines changed: 267 additions & 0 deletions
@@ -0,0 +1,267 @@
Real-Time Classification
========================

This guide shows how to use ezmsg-learn for real-time classification in streaming pipelines.

.. contents:: On this page
   :local:
   :depth: 2


Overview
--------

ezmsg-learn provides machine learning components that integrate with ezmsg pipelines.
Key features include:

- **Pre-trained models**: Load and apply existing classifiers
- **Online learning**: Update models incrementally with streaming data
- **Flexible backends**: Support for scikit-learn, PyTorch, and River models


Available Classifiers
---------------------

ezmsg-learn includes several classifier types:

.. list-table::
   :header-rows: 1
   :widths: 25 40 35

   * - Classifier
     - Description
     - Use Case
   * - ``SLDA``
     - Shrinkage Linear Discriminant Analysis
     - BCI, small datasets
   * - ``SklearnModelUnit``
     - Wrapper for any scikit-learn model
     - General ML tasks
   * - ``SGDClassifier``
     - Stochastic Gradient Descent
     - Online learning
   * - ``MLPUnit``
     - Multi-layer Perceptron (PyTorch)
     - Complex patterns


Using a Pre-Trained SLDA Classifier
-----------------------------------

The simplest approach is to use a pre-trained model:

.. code-block:: python

   from ezmsg.learn.process.slda import SLDA, SLDASettings

   classifier = SLDA(
       SLDASettings(
           settings_path="path/to/trained_model.pkl",
           axis="time",  # Axis containing samples
       )
   )

**Input format**: ``AxisArray[time, features]`` where features are flattened from your pipeline.

**Output format**: ``ClassifierMessage[time, classes]`` with class probabilities.

Training an SLDA model (offline):

.. code-block:: python

   import pickle
   from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

   # Train offline with your data
   X_train = ...  # shape: (n_samples, n_features)
   y_train = ...  # shape: (n_samples,)

   lda = LDA(solver="lsqr", shrinkage="auto")
   lda.fit(X_train, y_train)

   # Save for use in ezmsg
   with open("trained_model.pkl", "wb") as f:
       pickle.dump(lda, f)

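Before pointing ``settings_path`` at the pickle, it can be worth confirming that the file round-trips and that the reloaded model returns one probability per class. The sketch below uses only scikit-learn and the standard library. The synthetic data, the 16-feature shape, and the temporary file location are assumptions chosen for illustration.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

# Synthetic two-class data standing in for real training features (assumption).
rng = np.random.default_rng(0)
X_train = np.vstack([
    rng.normal(loc=-1.0, size=(50, 16)),
    rng.normal(loc=1.0, size=(50, 16)),
])
y_train = np.repeat([0, 1], 50)

lda = LDA(solver="lsqr", shrinkage="auto").fit(X_train, y_train)

# Write the model to a temporary file, then load it back.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "trained_model.pkl")
    with open(model_path, "wb") as f:
        pickle.dump(lda, f)
    with open(model_path, "rb") as f:
        restored = pickle.load(f)

# The reloaded model should reproduce the original probabilities.
proba = restored.predict_proba(X_train[:5])
same_as_original = np.allclose(proba, lda.predict_proba(X_train[:5]))
print(proba.shape, same_as_original)
```

Pickles are tied to the library version that wrote them, so it is safest to run the check in the same environment as the streaming pipeline.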

Using Scikit-Learn Models
-------------------------

``SklearnModelUnit`` wraps any scikit-learn compatible model:

.. code-block:: python

   from ezmsg.learn.process.sklearn import SklearnModelUnit, SklearnModelSettings
   import numpy as np

   classifier = SklearnModelUnit(
       SklearnModelSettings(
           model_class="sklearn.linear_model.SGDClassifier",
           model_kwargs={
               "loss": "log_loss",  # For probability outputs
               "warm_start": True,
           },
           partial_fit_classes=np.array([0, 1]),  # Required for online learning
       )
   )

Loading a pre-trained model:

.. code-block:: python

   from ezmsg.learn.process.sklearn import SklearnModelUnit, SklearnModelSettings

   classifier = SklearnModelUnit(
       SklearnModelSettings(
           model_class="sklearn.linear_model.SGDClassifier",
           checkpoint_path="path/to/saved_model.pkl",
       )
   )

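The ``model_class`` setting is a dotted import path given as a string. A common way to turn such a string into a class is ``importlib``, sketched below. Whether ezmsg-learn resolves the path this way is an assumption, and ``resolve_dotted_path`` is a hypothetical helper, not part of the library.

```python
import importlib


def resolve_dotted_path(path: str):
    """Import and return the object named by a dotted path such as 'pkg.module.Name'."""
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attribute', got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# A standard-library class is used here so the sketch has no third-party dependency.
counter_cls = resolve_dotted_path("collections.Counter")
counts = counter_cls("abca")
print(counts["a"])
```

With scikit-learn installed, the same helper would accept ``"sklearn.linear_model.SGDClassifier"``. A misspelled module raises ``ImportError`` and a misspelled class raises ``AttributeError``.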

Online Learning
---------------

For models that support ``partial_fit``, you can update them during streaming:

.. code-block:: python

   import numpy as np

   from ezmsg.learn.process.sklearn import SklearnModelProcessor, SklearnModelSettings
   from ezmsg.sigproc.sampler import SampleMessage

   # Create processor with online learning support
   processor = SklearnModelProcessor(
       settings=SklearnModelSettings(
           model_class="sklearn.linear_model.SGDClassifier",
           model_kwargs={"loss": "log_loss"},
           partial_fit_classes=np.array([0, 1]),
       )
   )

   # Training with labeled samples
   # (feature_array and label_value are placeholders for your own data)
   sample_msg = SampleMessage(
       sample=feature_array,  # AxisArray with features
       trigger=label_value,  # The class label
   )
   processor.partial_fit(sample_msg)

   # Prediction (after training); input_features is a placeholder AxisArray
   prediction = processor(input_features)


Complete Pipeline Example
-------------------------

Here's a complete BCI classification pipeline:

.. code-block:: python

   import ezmsg.core as ez
   from ezmsg.lsl.inlet import LSLInletUnit, LSLInletSettings, LSLInfo
   from ezmsg.lsl.outlet import LSLOutletUnit, LSLOutletSettings
   from ezmsg.sigproc.butterworthfilter import ButterworthFilter, ButterworthFilterSettings
   from ezmsg.sigproc.window import Window, WindowSettings
   from ezmsg.sigproc.spectrum import Spectrum, SpectrumSettings
   from ezmsg.sigproc.aggregate import RangedAggregate, RangedAggregateSettings, AggregationFunction
   from ezmsg.learn.process.slda import SLDA, SLDASettings

   components = {
       # Data acquisition
       "LSL_IN": LSLInletUnit(
           LSLInletSettings(info=LSLInfo(name="EEG", type="EEG"))
       ),

       # Signal processing
       "FILTER": ButterworthFilter(
           ButterworthFilterSettings(order=4, cuton=8.0, cutoff=30.0)
       ),
       "WINDOW": Window(
           WindowSettings(window_dur=1.0, window_shift=0.5)
       ),
       "SPECTRUM": Spectrum(SpectrumSettings(window="hann")),
       "BANDPOWER": RangedAggregate(
           RangedAggregateSettings(
               axis="freq",
               bands=[(8.0, 12.0), (18.0, 25.0)],
               operation=AggregationFunction.MEAN,
           )
       ),

       # Classification
       "CLASSIFIER": SLDA(
           SLDASettings(settings_path="model.pkl", axis="time")
       ),

       # Output
       "LSL_OUT": LSLOutletUnit(
           LSLOutletSettings(stream_name="Predictions", stream_type="Markers")
       ),
   }

   connections = (
       (components["LSL_IN"].OUTPUT_SIGNAL, components["FILTER"].INPUT_SIGNAL),
       (components["FILTER"].OUTPUT_SIGNAL, components["WINDOW"].INPUT_SIGNAL),
       (components["WINDOW"].OUTPUT_SIGNAL, components["SPECTRUM"].INPUT_SIGNAL),
       (components["SPECTRUM"].OUTPUT_SIGNAL, components["BANDPOWER"].INPUT_SIGNAL),
       (components["BANDPOWER"].OUTPUT_SIGNAL, components["CLASSIFIER"].INPUT_SIGNAL),
       (components["CLASSIFIER"].OUTPUT_SIGNAL, components["LSL_OUT"].INPUT_SIGNAL),
   )

   if __name__ == "__main__":
       ez.run(components=components, connections=connections)

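As a rough picture of the features this pipeline feeds the classifier, the NumPy sketch below applies a Hann taper to a one-second window, takes the power spectrum, and averages it over the two bands used in the example. It is not the ezmsg-sigproc implementation. The 250 Hz sampling rate, the test signals, and the ``band_power`` helper are assumptions.

```python
import numpy as np


def band_power(window, fs, bands):
    """Mean spectral power per band for a (n_samples, n_channels) window."""
    n_samples = window.shape[0]
    taper = np.hanning(n_samples)[:, None]  # Hann taper, applied to every channel
    spectrum = np.abs(np.fft.rfft(window * taper, axis=0)) ** 2
    freqs = np.fft.rfftfreq(n_samples, d=1.0 / fs)
    rows = []
    for low, high in bands:
        in_band = (freqs >= low) & (freqs <= high)
        rows.append(spectrum[in_band].mean(axis=0))
    return np.stack(rows)  # shape: (n_bands, n_channels)


fs = 250.0  # assumed sampling rate in Hz
t = np.arange(int(fs)) / fs  # one second of samples
# Channel 0 oscillates at 10 Hz, channel 1 at 20 Hz.
window = np.stack([np.sin(2 * np.pi * 10 * t), np.sin(2 * np.pi * 20 * t)], axis=1)

power = band_power(window, fs, bands=[(8.0, 12.0), (18.0, 25.0)])
print(power.shape)
```

Each channel should show most of its power in the band that contains its oscillation, which is the contrast a classifier can learn from.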

Feature Preparation
-------------------

Classifiers expect 2D input ``[samples, features]``. Arrays with more than two dimensions
are flattened automatically: every dimension other than the sample axis is merged into a
single feature dimension.

For example, if your bandpower output is ``[time=1, band=2, ch=8]``:

- The classifier receives shape ``[1, 16]`` (2 bands × 8 channels)
- Features are flattened in C-order (row-major), so the last dimension (``ch``) varies fastest

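The flattening can be reproduced with a plain NumPy reshape, which makes it easy to work out which feature column corresponds to which band and channel. This sketch only mirrors the shapes from the example. The variable names are illustrative.

```python
import numpy as np

n_time, n_band, n_ch = 1, 2, 8
# Fill the array so each value encodes its own (band, channel) position.
bandpower = np.arange(n_time * n_band * n_ch, dtype=float).reshape(n_time, n_band, n_ch)

# Keep the sample axis and merge every remaining axis into one feature axis.
features = bandpower.reshape(n_time, -1)  # C-order is NumPy's default

# In C-order, feature index = band_index * n_ch + channel_index.
band_index, channel_index = 1, 3
feature_index = band_index * n_ch + channel_index
print(features.shape, feature_index)
```

The same index arithmetic applies when interpreting model weights: column ``band_index * n_ch + channel_index`` belongs to that band and channel.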

Output Format
-------------

Classification outputs use ``ClassifierMessage``, which extends ``AxisArray`` with:

- **dims**: ``["time", "classes"]``
- **data**: Probability scores for each class
- **labels**: List of class names/identifiers

Example output shape: ``[time=1, classes=2]`` with probabilities for each class.

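Downstream code usually wants a single decision per time step rather than the full probability row. The sketch below works on a plain NumPy array and a plain list, standing in for the message's ``data`` and ``labels``. The label names and probability values are made up for illustration.

```python
import numpy as np

# Hypothetical class labels and a [time=3, classes=2] probability array.
labels = ["rest", "move"]
probabilities = np.array([
    [0.9, 0.1],
    [0.4, 0.6],
    [0.2, 0.8],
])

# Index of the most probable class at each time step.
winner_index = probabilities.argmax(axis=1)
decoded = [labels[i] for i in winner_index]
# Probability of the winning class, usable as a confidence threshold.
confidence = probabilities.max(axis=1)
print(decoded)
```

Thresholding on ``confidence`` is one simple way to suppress uncertain predictions before they trigger an action.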

Tips for Better Performance
---------------------------

1. **Normalize features**: Use ``Scaler`` from ezmsg-sigproc before classification

   .. code-block:: python

      from ezmsg.sigproc.scaler import Scaler, ScalerSettings
      scaler = Scaler(ScalerSettings(mode="zscore"))

2. **Match training conditions**: Ensure online features match offline training preprocessing

3. **Window size**: Larger windows give more stable features but higher latency

4. **Feature selection**: Start with relevant frequency bands for your application

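Tips 1 and 2 interact: whatever normalization is used during training has to be applied identically online. One simple way to guarantee this is to estimate z-score statistics once from the training set and reuse them, as sketched below with NumPy. This is a fixed-statistics variant written for illustration, not the behavior of the ezmsg-sigproc ``Scaler``. The data are synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic offline training features with a large offset and scale (assumption).
X_train = rng.normal(loc=50.0, scale=10.0, size=(500, 16))

# Estimate the statistics once, from training data only.
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
std[std == 0.0] = 1.0  # guard against constant features


def zscore(features):
    """Apply the stored training statistics to a [samples, features] array."""
    return (features - mean) / std


X_train_scaled = zscore(X_train)

# Online: reuse the same statistics instead of re-estimating them per window.
online_window = rng.normal(loc=50.0, scale=10.0, size=(1, 16))
online_scaled = zscore(online_window)
print(online_scaled.shape)
```

If the online stage uses an adaptive scaler instead, the training features should pass through the same kind of scaler so that both see comparable value ranges.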

Troubleshooting
---------------

**"Model has not been fit yet"**:
The model needs training data before prediction. Either:

- Provide a ``checkpoint_path`` with a pre-trained model
- Call ``fit()`` or ``partial_fit()`` before processing

**Shape mismatch errors**:

- Verify that the input feature dimensions match the trained model
- Check the ``n_features_in_`` attribute of loaded models

**NaN in predictions**:

- Ensure input features don't contain NaN values
- Check for numerical stability in preprocessing
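
The shape and NaN checks can be combined into one small guard that runs before prediction. ``check_features`` below is a hypothetical helper, not part of ezmsg-learn. For a fitted scikit-learn estimator, the expected count can be read from its ``n_features_in_`` attribute.

```python
import numpy as np


def check_features(features, n_features_expected):
    """Raise ValueError if a [samples, features] array is unusable for prediction."""
    features = np.asarray(features, dtype=float)
    if features.ndim != 2:
        raise ValueError(f"expected a 2D array, got {features.ndim}D")
    if features.shape[1] != n_features_expected:
        raise ValueError(
            f"model expects {n_features_expected} features, got {features.shape[1]}"
        )
    if not np.isfinite(features).all():
        raise ValueError("features contain NaN or infinite values")
    return features


good = check_features(np.zeros((1, 16)), n_features_expected=16)
print(good.shape)
```

Failing early with a specific message is usually easier to debug than a NaN that surfaces several stages downstream.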

docs/source/index.rst

Lines changed: 1 addition & 3 deletions
@@ -49,13 +49,11 @@ Quick Start
 
 For general ezmsg tutorials and guides, visit `ezmsg.org <https://www.ezmsg.org>`_.
 
-For package-specific examples and usage, see the :doc:`api/index` documentation.
-
 .. toctree::
    :maxdepth: 2
-   :hidden:
    :caption: Contents:
 
+   guides/classification
    api/index
 
 