forked from mne-tools/mne-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_epoch_quality.py
More file actions
114 lines (97 loc) · 4.17 KB
/
plot_epoch_quality.py
File metadata and controls
114 lines (97 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
.. _ex-epoch-quality:
========================================
Exploring epoch quality before rejection
========================================
This example shows an approach for identifying epochs containing potential artifacts and
rejecting these bad epochs. We compute per-epoch outlier scores from peak-to-peak
amplitude, variance, and kurtosis — inspired by FASTER :footcite:`NolanEtAl2010` and
:footcite:t:`DelormeEtAl2007` — and use them to rank epochs from cleanest to noisiest to
inform rejection decisions.
"""
# Authors: Aman Srivastava
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
# %%
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import kurtosis
import mne
from mne.datasets import eegbci
print(__doc__)  # echo the module docstring when the example is executed
# %%
# Load the EEGBCI dataset and create epochs
# -----------------------------------------
raw_fname = eegbci.load_data(subjects=3, runs=(3,))[0]  # first returned file
raw = mne.io.read_raw(raw_fname, preload=True)
eegbci.standardize(raw)
montage = mne.channels.make_standard_montage("standard_1005")
raw.set_montage(montage)
# Convert the annotations stored with the recording into an events array, together
# with the mapping from annotation description to integer event code.
events, event_id = mne.events_from_annotations(raw)
# Pass ``event_id`` along so that the epochs keep the annotation names rather than
# only the bare integer codes (the mapping was previously computed but never used).
epochs = mne.Epochs(
    raw,
    events,
    event_id=event_id,
    tmin=-0.2,
    tmax=0.5,
    baseline=(None, 0),
    preload=True,
)
# %%
# Compute per-epoch outlier scores
# --------------------------------
# Peak-to-peak amplitude, variance, and kurtosis are computed per epoch. Each feature is
# z-scored robustly using the median absolute deviation across epochs, and the z-scores
# are averaged into a single outlier score normalised to the range [0, 1]. Scores close
# to 1 indicate the likely presence of artifacts in the epoch.
data = epochs.get_data()  # (n_epochs, n_channels, n_times)
# Peak-to-peak amplitude and variance are computed per channel, then averaged over
# the channels of each epoch.
ptp = np.ptp(data, axis=-1).mean(axis=-1)
var = data.var(axis=-1).mean(axis=-1)
# Kurtosis is computed once per epoch over all channels and samples pooled together.
kurt = kurtosis(data.reshape(len(data), -1), axis=1)
features = np.column_stack([ptp, var, kurt])  # (n_epochs, 3)
# Robust z-score of every feature across epochs: absolute distance from the median,
# expressed in units of the median absolute deviation (MAD).
median = np.median(features, axis=0)
mad = np.median(np.abs(features - median), axis=0)
# Only guard against an exactly-zero MAD. A fixed additive epsilon is not safe here
# because the features live on very different scales: the data are in SI units
# (volts for EEG), so the variance (V^2) and its MAD can be of the same order as a
# constant such as 1e-10, which would shrink that feature's z-scores relative to the
# other two.
mad = np.where(mad > 0, mad, 1.0)
z = np.abs(features - median) / mad
raw_score = z.mean(axis=-1)
# Min-max normalisation to [0, 1]. The z-scores are dimensionless, so a tiny epsilon
# is sufficient to avoid a division by zero when all epochs score the same.
scores = (raw_score - raw_score.min()) / (raw_score.max() - raw_score.min() + 1e-10)
# %%
# Determining outlier epochs
# --------------------------
# Below, epochs are ranked from cleanest to noisiest. We need to find an appropriate
# threshold to flag those epochs likely containing artifacts. In the plot, we show two
# example thresholds: a more lenient threshold of 0.8 and a stricter threshold of 0.6.
fig, ax = plt.subplots(layout="constrained")
# One bar per epoch, ordered from the lowest (cleanest) to the highest (noisiest)
# score.
sorted_idx = np.argsort(scores)
ax.bar(np.arange(len(scores)), scores[sorted_idx], color="steelblue")
# Example thresholds, defined once so that the drawn lines and the printed report
# cannot drift apart: threshold -> (line colour, legend label).
thresholds = {
    0.8: ("red", "More lenient threshold (0.8)"),
    0.6: ("orange", "Stricter threshold (0.6)"),
}
for threshold, (color, label) in thresholds.items():
    ax.axhline(threshold, color=color, linestyle="--", label=label)
ax.set(
    xlabel="Epoch (sorted by score)",
    ylabel="Outlier score",
    title="Epoch quality scores (0 = clean, 1 = likely artifact)",
)
ax.legend()
for threshold in thresholds:
    # Use ``>=`` so the reported count matches the epochs that are later selected
    # for plotting and dropping, which also use ``>=``.
    n_flagged = np.count_nonzero(scores >= threshold)
    print(
        f"Threshold {threshold}: {n_flagged} epochs flagged "
        f"out of {len(epochs)} total"
    )
# %%
# Epochs flagged by the thresholds can be inspected using the :meth:`~mne.Epochs.plot`
# method. First, we show those epochs with the worst scores (≥ 0.8), containing a number
# of amplitude spikes.
worst_idx = np.flatnonzero(scores >= 0.8)  # indices of epochs scoring at least 0.8
epochs[worst_idx].plot(title="Scores ≥ 0.8", scalings={"eeg": 70e-6})
# %%
# In contrast, the threshold of 0.6 also captures epochs with less severe artifact
# activity; excluding these from the analysis may be overly conservative.
moderate_mask = (scores >= 0.6) & (scores < 0.8)
# Select only the epochs whose score lies in the half-open band [0.6, 0.8).
moderate_idx = np.flatnonzero(moderate_mask)
epochs[moderate_idx].plot(title="0.6 ≤ scores < 0.8", scalings={"eeg": 70e-6})
# %%
# Identify and handle suspicious epochs
# --------------------------------------
# Epochs scoring above the threshold can be inspected visually using
# :meth:`mne.Epochs.plot`, or dropped directly using
# :meth:`mne.Epochs.drop`. The threshold to use to flag epochs as outliers varies
# depending on the dataset and analysis goals, and inspecting the flagged epochs is a
# crucial step in identifying the optimal threshold.
drop_idx = np.flatnonzero(scores >= 0.8)
epochs.drop(drop_idx)  # removes the flagged epochs from ``epochs`` in place
n_remaining = len(epochs)
print(f"Epochs remaining after dropping scores ≥ 0.8: {n_remaining}")
# %%
# References
# ----------
# .. footbibliography::