Skip to content

Commit 84cabed

Browse files
Aniketsytsbinns
andauthored
ENH: Add native save and read support for SSD (#13718)
Co-authored-by: Thomas S. Binns <t.s.binns@outlook.com>
1 parent ba12118 commit 84cabed

10 files changed

Lines changed: 516 additions & 14 deletions

File tree

doc/api/decoding.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,7 @@ Functions that assist with decoding and model fitting:
4141
cross_val_multiscore
4242
get_coef
4343
get_spatial_filter_from_estimator
44+
read_csp
45+
read_spoc
46+
read_ssd
47+
read_xdawn_transformer

doc/changes/dev/13718.other.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add native save and read functionality for :class:`mne.decoding.SSD`, :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, and :class:`mne.decoding.XdawnTransformer` objects, by `Aniket Singh Yadav`_.

mne/decoding/__init__.pyi

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ __all__ = [
1010
"SPoC",
1111
"SSD",
1212
"Scaler",
13+
"read_csp",
14+
"read_spoc",
15+
"read_ssd",
16+
"read_xdawn_transformer",
1317
"SlidingEstimator",
1418
"SpatialFilter",
1519
"TemporalFilter",
@@ -31,12 +35,12 @@ from .base import (
3135
cross_val_multiscore,
3236
get_coef,
3337
)
34-
from .csp import CSP, SPoC
38+
from .csp import CSP, SPoC, read_csp, read_spoc
3539
from .ems import EMS, compute_ems
3640
from .receptive_field import ReceptiveField
3741
from .search_light import GeneralizingEstimator, SlidingEstimator
3842
from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator
39-
from .ssd import SSD
43+
from .ssd import SSD, read_ssd
4044
from .time_delaying_ridge import TimeDelayingRidge
4145
from .time_frequency import TimeFrequency
4246
from .transformer import (
@@ -47,4 +51,4 @@ from .transformer import (
4751
UnsupervisedSpatialFilter,
4852
Vectorizer,
4953
)
50-
from .xdawn import XdawnTransformer
54+
from .xdawn import XdawnTransformer, read_xdawn_transformer

mne/decoding/base.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,20 @@
2525
from sklearn.utils import indexable
2626
from sklearn.utils.validation import check_is_fitted
2727

28+
from .._fiff.meas_info import Info
2829
from ..parallel import parallel_func
2930
from ..utils import (
31+
_check_fname,
3032
_check_option,
33+
_import_h5io_funcs,
3134
_pl,
3235
_validate_type,
3336
logger,
3437
pinv,
3538
verbose,
3639
warn,
3740
)
41+
from ..utils.check import check_fname
3842
from ._fixes import validate_data
3943
from ._ged import (
4044
_handle_restr_mat,
@@ -135,6 +139,65 @@ def __init_subclass__(cls, **kwargs):
135139
super().__init_subclass__(**kwargs)
136140
cls._is_base_ged = False
137141

142+
def __getstate__(self):
143+
"""Prepare state for serialization."""
144+
state = self.__dict__.copy()
145+
state.pop("cov_callable", None)
146+
state.pop("mod_ged_callable", None)
147+
state.pop("R_func", None)
148+
return state
149+
150+
def _restore_callables(self):
151+
"""Restore callables after loading serialized state."""
152+
# Expected to be implemented in child classes
153+
pass
154+
155+
def __setstate__(self, state):
156+
"""Restore state from serialization."""
157+
required_state_keys = getattr(self, "_required_state_keys", ())
158+
missing = [k for k in required_state_keys if k not in state]
159+
if missing:
160+
raise ValueError(
161+
f"Cannot read file as type {type(self).__name__}, "
162+
f"it is missing required keys: {missing}. "
163+
"Please report this to MNE developers "
164+
"(https://github.com/mne-tools/mne-python/issues/new) "
165+
"and include a copy of this entire traceback and a link "
166+
"to the problematic file with your report."
167+
)
168+
if "info" in state and isinstance(state["info"], dict):
169+
state["info"] = Info(**state["info"])
170+
self.__dict__.update(state)
171+
self._restore_callables()
172+
173+
@verbose
174+
def save(self, fname, *, overwrite=False, verbose=None):
175+
"""Save the object to disk (in HDF5 format).
176+
177+
Parameters
178+
----------
179+
fname : path-like
180+
The file path to save to. Should end with ``'.h5'`` or
181+
``'.hdf5'``.
182+
%(overwrite)s
183+
%(verbose)s
184+
185+
Notes
186+
-----
187+
.. versionadded:: 1.12
188+
"""
189+
_, write_hdf5 = _import_h5io_funcs()
190+
class_name = getattr(self, "_save_fname_type", type(self).__name__.lower())
191+
check_fname(fname, class_name, (".h5", ".hdf5"))
192+
fname = _check_fname(fname, overwrite=overwrite, verbose=verbose)
193+
write_hdf5(
194+
fname,
195+
self.__getstate__(),
196+
overwrite=overwrite,
197+
title="mnepython",
198+
slash="replace",
199+
)
200+
138201
def fit(self, X, y=None):
139202
"""..."""
140203
# Let the inheriting transformers check data by themselves
@@ -338,6 +401,18 @@ def __sklearn_tags__(self):
338401
return tags
339402

340403

404+
@verbose
405+
def _read_ged(fname, ged_class, *, verbose=None):
406+
"""Load a saved GED transformer object from disk."""
407+
read_hdf5, _ = _import_h5io_funcs()
408+
_validate_type(fname, "path-like", "fname")
409+
fname = _check_fname(fname, overwrite=True, must_exist=True, verbose=verbose)
410+
state = read_hdf5(fname, title="mnepython", slash="replace")
411+
inst = object.__new__(ged_class)
412+
inst.__setstate__(state)
413+
return inst
414+
415+
341416
class LinearModel(MetaEstimatorMixin, BaseEstimator):
342417
"""Compute and store patterns from linear models.
343418

mne/decoding/csp.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,16 @@
99

1010
from .._fiff.meas_info import Info
1111
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
12-
from ..utils import _check_option, _validate_type, fill_doc, legacy
12+
from ..utils import (
13+
_check_option,
14+
_validate_type,
15+
fill_doc,
16+
legacy,
17+
verbose,
18+
)
1319
from ._covs_ged import _csp_estimate, _spoc_estimate
1420
from ._mod_ged import _csp_mod, _spoc_mod
15-
from .base import _GEDTransformer
21+
from .base import _GEDTransformer, _read_ged
1622
from .spatial_filter import get_spatial_filter_from_estimator
1723

1824

@@ -160,13 +166,43 @@ def __init__(
160166
R_func=sum,
161167
)
162168

169+
_save_fname_type = "csp"
170+
163171
def __sklearn_tags__(self):
164172
"""Tag the transformer."""
165173
tags = super().__sklearn_tags__()
166174
tags.target_tags.required = True
167175
tags.target_tags.multi_output = True
168176
return tags
169177

178+
_required_state_keys = (
179+
"component_order",
180+
"cov_est",
181+
"cov_method_params",
182+
"info",
183+
"log",
184+
"n_components",
185+
"norm_trace",
186+
"rank",
187+
"reg",
188+
"restr_type",
189+
"transform_into",
190+
)
191+
192+
def _restore_callables(self):
193+
"""Restore CSP-specific callables after loading state."""
194+
self.cov_callable = partial(
195+
_csp_estimate,
196+
reg=self.reg,
197+
cov_method_params=self.cov_method_params,
198+
cov_est=self.cov_est,
199+
info=self.info,
200+
rank=self.rank,
201+
norm_trace=self.norm_trace,
202+
)
203+
self.mod_ged_callable = partial(_csp_mod, evecs_order=self.component_order)
204+
self.R_func = sum
205+
170206
def _validate_params(self, *, y):
171207
_validate_type(self.n_components, int, "n_components")
172208
if hasattr(self, "cov_est"):
@@ -766,12 +802,37 @@ def __init__(
766802
delattr(self, "cov_est")
767803
delattr(self, "norm_trace")
768804

805+
_save_fname_type = "spoc"
806+
769807
def __sklearn_tags__(self):
770808
"""Tag the transformer."""
771809
tags = super().__sklearn_tags__()
772810
tags.target_tags.multi_output = False
773811
return tags
774812

813+
_required_state_keys = (
814+
"cov_method_params",
815+
"info",
816+
"log",
817+
"n_components",
818+
"rank",
819+
"reg",
820+
"restr_type",
821+
"transform_into",
822+
)
823+
824+
def _restore_callables(self):
825+
"""Restore SPoC-specific callables after loading state."""
826+
self.cov_callable = partial(
827+
_spoc_estimate,
828+
reg=self.reg,
829+
cov_method_params=self.cov_method_params,
830+
info=self.info,
831+
rank=self.rank,
832+
)
833+
self.mod_ged_callable = _spoc_mod
834+
self.R_func = None
835+
775836
def fit(self, X, y):
776837
"""Estimate the SPoC decomposition on epochs.
777838
@@ -848,3 +909,57 @@ def fit_transform(self, X, y=None, **fit_params):
848909
"""
849910
# use parent TransformerMixin method but with custom docstring
850911
return super().fit_transform(X, y=y, **fit_params)
912+
913+
914+
@verbose
915+
def read_csp(fname, *, verbose=None):
916+
"""Load a saved :class:`mne.decoding.CSP` object from disk.
917+
918+
Parameters
919+
----------
920+
fname : path-like
921+
Path to a CSP file in HDF5 format, which should end with ``.h5`` or
922+
``.hdf5``.
923+
%(verbose)s
924+
925+
Returns
926+
-------
927+
csp : instance of :class:`~mne.decoding.CSP`
928+
The loaded CSP object with all fitted attributes restored.
929+
930+
See Also
931+
--------
932+
mne.decoding.CSP.save
933+
934+
Notes
935+
-----
936+
.. versionadded:: 1.12
937+
"""
938+
return _read_ged(fname, CSP, verbose=verbose)
939+
940+
941+
@verbose
942+
def read_spoc(fname, *, verbose=None):
943+
"""Load a saved :class:`mne.decoding.SPoC` object from disk.
944+
945+
Parameters
946+
----------
947+
fname : path-like
948+
Path to a SPoC file in HDF5 format, which should end with ``.h5`` or
949+
``.hdf5``.
950+
%(verbose)s
951+
952+
Returns
953+
-------
954+
spoc : instance of :class:`~mne.decoding.SPoC`
955+
The loaded SPoC object with all fitted attributes restored.
956+
957+
See Also
958+
--------
959+
mne.decoding.SPoC.save
960+
961+
Notes
962+
-----
963+
.. versionadded:: 1.12
964+
"""
965+
return _read_ged(fname, SPoC, verbose=verbose)

0 commit comments

Comments
 (0)