|
9 | 9 |
|
10 | 10 | from .._fiff.meas_info import Info |
11 | 11 | from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT |
12 | | -from ..utils import _check_option, _validate_type, fill_doc, legacy |
| 12 | +from ..utils import ( |
| 13 | + _check_option, |
| 14 | + _validate_type, |
| 15 | + fill_doc, |
| 16 | + legacy, |
| 17 | + verbose, |
| 18 | +) |
13 | 19 | from ._covs_ged import _csp_estimate, _spoc_estimate |
14 | 20 | from ._mod_ged import _csp_mod, _spoc_mod |
15 | | -from .base import _GEDTransformer |
| 21 | +from .base import _GEDTransformer, _read_ged |
16 | 22 | from .spatial_filter import get_spatial_filter_from_estimator |
17 | 23 |
|
18 | 24 |
|
@@ -160,13 +166,43 @@ def __init__( |
160 | 166 | R_func=sum, |
161 | 167 | ) |
162 | 168 |
|
| 169 | + _save_fname_type = "csp" |
| 170 | + |
163 | 171 | def __sklearn_tags__(self): |
164 | 172 | """Tag the transformer.""" |
165 | 173 | tags = super().__sklearn_tags__() |
166 | 174 | tags.target_tags.required = True |
167 | 175 | tags.target_tags.multi_output = True |
168 | 176 | return tags |
169 | 177 |
|
| 178 | + _required_state_keys = ( |
| 179 | + "component_order", |
| 180 | + "cov_est", |
| 181 | + "cov_method_params", |
| 182 | + "info", |
| 183 | + "log", |
| 184 | + "n_components", |
| 185 | + "norm_trace", |
| 186 | + "rank", |
| 187 | + "reg", |
| 188 | + "restr_type", |
| 189 | + "transform_into", |
| 190 | + ) |
| 191 | + |
| 192 | + def _restore_callables(self): |
| 193 | + """Restore CSP-specific callables after loading state.""" |
| 194 | + self.cov_callable = partial( |
| 195 | + _csp_estimate, |
| 196 | + reg=self.reg, |
| 197 | + cov_method_params=self.cov_method_params, |
| 198 | + cov_est=self.cov_est, |
| 199 | + info=self.info, |
| 200 | + rank=self.rank, |
| 201 | + norm_trace=self.norm_trace, |
| 202 | + ) |
| 203 | + self.mod_ged_callable = partial(_csp_mod, evecs_order=self.component_order) |
| 204 | + self.R_func = sum |
| 205 | + |
170 | 206 | def _validate_params(self, *, y): |
171 | 207 | _validate_type(self.n_components, int, "n_components") |
172 | 208 | if hasattr(self, "cov_est"): |
@@ -766,12 +802,37 @@ def __init__( |
766 | 802 | delattr(self, "cov_est") |
767 | 803 | delattr(self, "norm_trace") |
768 | 804 |
|
| 805 | + _save_fname_type = "spoc" |
| 806 | + |
769 | 807 | def __sklearn_tags__(self): |
770 | 808 | """Tag the transformer.""" |
771 | 809 | tags = super().__sklearn_tags__() |
772 | 810 | tags.target_tags.multi_output = False |
773 | 811 | return tags |
774 | 812 |
|
| 813 | + _required_state_keys = ( |
| 814 | + "cov_method_params", |
| 815 | + "info", |
| 816 | + "log", |
| 817 | + "n_components", |
| 818 | + "rank", |
| 819 | + "reg", |
| 820 | + "restr_type", |
| 821 | + "transform_into", |
| 822 | + ) |
| 823 | + |
| 824 | + def _restore_callables(self): |
| 825 | + """Restore SPoC-specific callables after loading state.""" |
| 826 | + self.cov_callable = partial( |
| 827 | + _spoc_estimate, |
| 828 | + reg=self.reg, |
| 829 | + cov_method_params=self.cov_method_params, |
| 830 | + info=self.info, |
| 831 | + rank=self.rank, |
| 832 | + ) |
| 833 | + self.mod_ged_callable = _spoc_mod |
| 834 | + self.R_func = None |
| 835 | + |
775 | 836 | def fit(self, X, y): |
776 | 837 | """Estimate the SPoC decomposition on epochs. |
777 | 838 |
|
@@ -848,3 +909,57 @@ def fit_transform(self, X, y=None, **fit_params): |
848 | 909 | """ |
849 | 910 | # use parent TransformerMixin method but with custom docstring |
850 | 911 | return super().fit_transform(X, y=y, **fit_params) |
| 912 | + |
| 913 | + |
| 914 | +@verbose |
| 915 | +def read_csp(fname, *, verbose=None): |
| 916 | + """Load a saved :class:`mne.decoding.CSP` object from disk. |
| 917 | +
|
| 918 | + Parameters |
| 919 | + ---------- |
| 920 | + fname : path-like |
| 921 | + Path to a CSP file in HDF5 format, which should end with ``.h5`` or |
| 922 | + ``.hdf5``. |
| 923 | + %(verbose)s |
| 924 | +
|
| 925 | + Returns |
| 926 | + ------- |
| 927 | + csp : instance of :class:`~mne.decoding.CSP` |
| 928 | + The loaded CSP object with all fitted attributes restored. |
| 929 | +
|
| 930 | + See Also |
| 931 | + -------- |
| 932 | + mne.decoding.CSP.save |
| 933 | +
|
| 934 | + Notes |
| 935 | + ----- |
| 936 | + .. versionadded:: 1.12 |
| 937 | + """ |
| 938 | + return _read_ged(fname, CSP, verbose=verbose) |
| 939 | + |
| 940 | + |
| 941 | +@verbose |
| 942 | +def read_spoc(fname, *, verbose=None): |
| 943 | + """Load a saved :class:`mne.decoding.SPoC` object from disk. |
| 944 | +
|
| 945 | + Parameters |
| 946 | + ---------- |
| 947 | + fname : path-like |
| 948 | + Path to a SPoC file in HDF5 format, which should end with ``.h5`` or |
| 949 | + ``.hdf5``. |
| 950 | + %(verbose)s |
| 951 | +
|
| 952 | + Returns |
| 953 | + ------- |
| 954 | + spoc : instance of :class:`~mne.decoding.SPoC` |
| 955 | + The loaded SPoC object with all fitted attributes restored. |
| 956 | +
|
| 957 | + See Also |
| 958 | + -------- |
| 959 | + mne.decoding.SPoC.save |
| 960 | +
|
| 961 | + Notes |
| 962 | + ----- |
| 963 | + .. versionadded:: 1.12 |
| 964 | + """ |
| 965 | + return _read_ged(fname, SPoC, verbose=verbose) |
0 commit comments