forked from mne-tools/mne-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_epoch_quality.py
More file actions
109 lines (87 loc) · 3.5 KB
/
plot_epoch_quality.py
File metadata and controls
109 lines (87 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
.. _ex-epoch-quality:

========================================
Exploring epoch quality before rejection
========================================

Before rejecting epochs with :meth:`mne.Epochs.drop_bad`, it can be useful
to get a sense of which epochs are the most likely artifacts. This example
shows how to compute simple per-epoch statistics — peak-to-peak amplitude,
variance, and kurtosis — and use them to rank epochs by an outlier score.

The approach is inspired by established methods from the EEG
artifact-detection literature, namely FASTER [1]_ and the higher-order
statistics approach of Delorme et al. [2]_, which flag bad trials using
statistics (such as amplitude range, variance, and kurtosis) that are z-scored
across epochs. Here, the same idea is applied to MEG gradiometer data.

References
----------
.. [1] Nolan, H., Whelan, R., & Reilly, R. B. (2010). FASTER: Fully Automated
       Statistical Thresholding for EEG artifact Rejection.
       Journal of Neuroscience Methods, 192(1), 152-162.
.. [2] Delorme, A., Sejnowski, T., & Makeig, S. (2007). Enhanced detection of
       artifacts in EEG data using higher-order statistics and independent
       component analysis. NeuroImage, 34(4), 1443-1449.
"""
# Authors: Aman Srivastava
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
# %%
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import kurtosis

import mne
from mne.datasets import sample
print(__doc__)

data_path = sample.data_path()

# %%
# Load the sample dataset and create epochs
#
# The two auditory conditions are epoched from -200 ms to 500 ms around each
# event, keeping only the gradiometer channels and using the pre-stimulus
# interval (from the start of the epoch up to time 0) as the baseline.
meg_path = data_path / "MEG" / "sample"
raw_fname = meg_path / "sample_audvis_filt-0-40_raw.fif"
raw = mne.io.read_raw_fif(raw_fname, preload=True)

# Events come from the stimulus channel; only the two auditory codes are kept.
events = mne.find_events(raw, stim_channel="STI 014")
event_id = {"auditory/left": 1, "auditory/right": 2}
tmin, tmax = -0.2, 0.5

picks = mne.pick_types(raw.info, meg="grad", eeg=False)
epochs = mne.Epochs(
    raw,
    events,
    event_id=event_id,
    tmin=tmin,
    tmax=tmax,
    baseline=(None, 0),
    picks=picks,
    preload=True,
)
# %%
# Compute per-epoch statistics
#
# We compute three features for each epoch. Each one is first computed per
# channel (over time) and then averaged across channels:
#
# - Peak-to-peak amplitude (sensitive to large jumps)
# - Variance (sensitive to sustained high-amplitude noise)
# - Kurtosis (sensitive to spike artifacts)
#
# Each feature is z-scored robustly across epochs using the median absolute
# deviation (MAD), and the absolute z-scores are averaged into a single outlier
# score per epoch. Finally, the scores are min-max scaled to [0, 1], so they
# rank epochs *relative to each other* within this recording rather than on an
# absolute scale.

# The array is only read below, so there is no need to copy it.
data = epochs.get_data(copy=False)  # (n_epochs, n_channels, n_times)

# Feature 1: peak-to-peak amplitude, averaged across channels
ptp = np.ptp(data, axis=-1).mean(axis=-1)
# Feature 2: variance, averaged across channels
var = data.var(axis=-1).mean(axis=-1)
# Feature 3: kurtosis, computed per channel like the other two features and then
# averaged. Pooling the samples of all channels together would mix channels of
# different amplitude, and such a scale mixture shows inflated kurtosis even
# when no spikes are present.
kurt = kurtosis(data, axis=-1).mean(axis=-1)

# Robust z-score using MAD. The 1.4826 factor makes the MAD a consistent
# estimate of the standard deviation for normally distributed data.
features = np.column_stack([ptp, var, kurt])
median = np.median(features, axis=0)
abs_dev = np.abs(features - median)
mad = 1.4826 * np.median(abs_dev, axis=0)
# Guard against a zero MAD *without* an additive epsilon. The features have very
# different magnitudes: peak-to-peak and variance of gradiometer data in SI
# units are typically far below 1e-10, so a fixed epsilon such as 1e-10 would
# swamp them and leave the score driven by kurtosis alone. Instead, a feature
# with no spread across epochs simply contributes a z-score of 0.
z = np.divide(abs_dev, mad, out=np.zeros_like(abs_dev), where=mad > 0)

# Average the features and min-max scale to [0, 1]. If every epoch has the
# same combined score, all of them get 0.
raw_score = z.mean(axis=-1)
score_range = np.ptp(raw_score)
if score_range > 0:
    scores = (raw_score - raw_score.min()) / score_range
else:
    scores = np.zeros_like(raw_score)
# %%
# Plot the scores ranked from cleanest to noisiest
#
# Because the scores are min-max scaled, the most outlying epoch always scores
# 1 (unless all epochs score identically). The threshold below is therefore
# relative to this recording, and at least one epoch will usually exceed it.
threshold = 0.8  # example cut-off; tune it for your own data

fig, ax = plt.subplots(layout="constrained")
sorted_idx = np.argsort(scores)
ax.bar(np.arange(len(scores)), scores[sorted_idx], color="steelblue")
ax.axhline(
    threshold,
    color="red",
    linestyle="--",
    label=f"Example threshold ({threshold})",
)
ax.set(
    xlabel="Epoch (sorted by score)",
    ylabel="Outlier score",
    title="Epoch outlier scores (0 = most typical, 1 = most outlying)",
)
ax.legend()

# %%
# Inspect the worst epochs
#
# Epochs scoring above the threshold are worth inspecting manually, e.g. with
# ``epochs[bad_epochs].plot()``, before deciding whether to remove them with
# :meth:`mne.Epochs.drop`. The indices are positions within ``epochs``.
bad_epochs = np.flatnonzero(scores > threshold)
print(f"Epochs worth inspecting: {bad_epochs}")
print(f"That's {len(bad_epochs)} out of {len(epochs)} total epochs")