Skip to content

Commit 825eea3

Browse files
Customize annotation colors (#13838)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent bd1dc9c commit 825eea3

8 files changed

Lines changed: 126 additions & 3 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add ``annotation_colors`` parameter to :meth:`mne.io.Raw.plot` and :meth:`mne.Epochs.plot` to allow users to specify custom colors for annotations by passing a dict mapping annotation description strings to colors (for example, ``annotation_colors=dict(bad_segment="orange")``), by `Clemens Brunner`_.

mne/epochs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,7 @@ def plot(
13211321
theme=None,
13221322
overview_mode=None,
13231323
splash=True,
1324+
annotation_colors=None,
13241325
):
13251326
return plot_epochs(
13261327
self,
@@ -1347,6 +1348,7 @@ def plot(
13471348
theme=theme,
13481349
overview_mode=overview_mode,
13491350
splash=splash,
1351+
annotation_colors=annotation_colors,
13501352
)
13511353

13521354
@copy_function_doc_to_method_doc(plot_topo_image_epochs)

mne/io/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,6 +1964,7 @@ def plot(
19641964
bad_color="lightgray",
19651965
event_color="cyan",
19661966
*,
1967+
annotation_colors=None,
19671968
annotation_regex=".*",
19681969
scalings=None,
19691970
remove_dc=True,
@@ -2004,6 +2005,7 @@ def plot(
20042005
color,
20052006
bad_color,
20062007
event_color,
2008+
annotation_colors=annotation_colors,
20072009
annotation_regex=annotation_regex,
20082010
scalings=scalings,
20092011
remove_dc=remove_dc,

mne/viz/_figure.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,24 @@ def _get_annotation_labels(self):
167167

168168
def _setup_annotation_colors(self):
169169
"""Set up colors for annotations; init some annotation vars."""
170+
from matplotlib.colors import to_hex
171+
170172
segment_colors = getattr(self.mne, "annotation_segment_colors", dict())
171173
labels = self._get_annotation_labels()
174+
user_colors = {
175+
k: to_hex(v)
176+
for k, v in (getattr(self.mne, "annotation_colors", None) or {}).items()
177+
}
172178
red = "#ff0000"
173179
colors = _get_color_list(remove=("#fa8174", "#d62728", "#ff0000"))
174180
color_cycle = cycle(colors)
175181
for key, color in segment_colors.items():
176-
if color != red and key in labels:
182+
if color != red and key in labels and key not in user_colors:
177183
next(color_cycle)
178184
for idx, key in enumerate(labels):
179-
if key.lower().startswith("bad") or key.lower().startswith("edge"):
185+
if key in user_colors:
186+
segment_colors[key] = user_colors[key]
187+
elif key.lower().startswith("bad") or key.lower().startswith("edge"):
180188
segment_colors[key] = red
181189
elif key in segment_colors:
182190
continue

mne/viz/epochs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
_handle_precompute,
3030
_make_combine_callable,
3131
_make_event_color_dict,
32+
_normalize_annotation_colors,
3233
_set_title_multiple_electrodes,
3334
_set_window_title,
3435
_setup_cmap,
@@ -763,6 +764,7 @@ def plot_epochs(
763764
theme=None,
764765
overview_mode=None,
765766
splash=True,
767+
annotation_colors=None,
766768
):
767769
"""Visualize epochs.
768770
@@ -865,6 +867,14 @@ def plot_epochs(
865867
%(splash)s
866868
867869
.. versionadded:: 1.6
870+
annotation_colors : dict | None
871+
A dictionary mapping annotation description strings to colors. Use this to
872+
override the default color assigned to specific annotation types (e.g.,
873+
``dict(bad_segment='orange')``). Colors can be any valid Matplotlib color
874+
specification. Keys that do not match any annotation description in the data
875+
will trigger a warning. If ``None`` (default), automatic colors are used.
876+
877+
.. versionadded:: 1.13
868878
869879
Returns
870880
-------
@@ -1014,6 +1024,13 @@ def plot_epochs(
10141024
raise TypeError(f"title must be None or a string, got a {type(title)}")
10151025

10161026
precompute = _handle_precompute(precompute)
1027+
1028+
# handle annotation_colors
1029+
if annotation_colors is not None:
1030+
annotation_colors = _normalize_annotation_colors(
1031+
annotation_colors, epochs.annotations
1032+
)
1033+
10171034
params = dict(
10181035
inst=epochs,
10191036
info=info,
@@ -1058,6 +1075,7 @@ def plot_epochs(
10581075
ch_color_dict=color,
10591076
epoch_color_bad=(1, 0, 0),
10601077
epoch_colors=epoch_colors,
1078+
annotation_colors=annotation_colors,
10611079
# display
10621080
butterfly=butterfly,
10631081
clipping=None,

mne/viz/raw.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
from .._fiff.pick import _picks_to_idx, pick_channels, pick_types
1212
from ..defaults import _handle_default
1313
from ..filter import create_filter
14-
from ..utils import _check_option, _get_stim_channel, _validate_type, legacy, verbose
14+
from ..utils import (
15+
_check_option,
16+
_get_stim_channel,
17+
_validate_type,
18+
legacy,
19+
verbose,
20+
)
1521
from ..utils.spectrum import _split_psd_kwargs
1622
from .utils import (
1723
_check_cov,
@@ -20,6 +26,7 @@
2026
_handle_decim,
2127
_handle_precompute,
2228
_make_event_color_dict,
29+
_normalize_annotation_colors,
2330
_shorten_path_from_middle,
2431
)
2532

@@ -38,6 +45,7 @@ def plot_raw(
3845
bad_color="lightgray",
3946
event_color="cyan",
4047
*,
48+
annotation_colors=None,
4149
annotation_regex=".*",
4250
scalings=None,
4351
remove_dc=True,
@@ -104,6 +112,14 @@ def plot_raw(
104112
Color to make bad channels.
105113
%(event_color)s
106114
Defaults to ``'cyan'``.
115+
annotation_colors : dict | None
116+
A dictionary mapping annotation description strings to colors. Use this to
117+
override the default color assigned to specific annotation types (e.g.,
118+
``dict(bad_segment='orange')``). Colors can be any valid Matplotlib color
119+
specification. Keys that do not match any annotation description in the data
120+
will trigger a warning. If ``None`` (default), automatic colors are used.
121+
122+
.. versionadded:: 1.13
107123
annotation_regex : str
108124
A regex pattern applied to each annotation's label.
109125
Matching labels remain visible, non-matching labels are hidden.
@@ -335,6 +351,12 @@ def plot_raw(
335351
if order.size == 0:
336352
raise RuntimeError("No channels found to plot")
337353

354+
# handle annotation_colors
355+
if annotation_colors is not None:
356+
annotation_colors = _normalize_annotation_colors(
357+
annotation_colors, raw.annotations
358+
)
359+
338360
# handle event colors
339361
event_color_dict = _make_event_color_dict(event_color, events, event_id)
340362

@@ -399,6 +421,7 @@ def plot_raw(
399421
# colors
400422
ch_color_bad=bad_color,
401423
ch_color_dict=color,
424+
annotation_colors=annotation_colors,
402425
# display
403426
butterfly=butterfly,
404427
clipping=clipping,

mne/viz/tests/test_raw.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,44 @@ def test_plot_annotations(raw, browser_backend):
865865
assert "A" in raw.annotations.description
866866

867867

868+
def test_annotation_colors(raw, browser_backend):
869+
"""Test that annotation_colors overrides default colors."""
870+
from matplotlib.colors import to_hex
871+
872+
with raw.info._unlock():
873+
raw.info["lowpass"] = 10.0
874+
875+
raw.set_annotations(
876+
Annotations(
877+
onset=[1, 3, 5],
878+
duration=[1, 1, 1],
879+
description=["BAD_test", "BAD_other", "stimulus"],
880+
)
881+
)
882+
883+
# User-provided colors override defaults (including bad* → red rule).
884+
# BAD_other has no override and should remain red.
885+
fig = raw.plot(
886+
annotation_colors={"BAD_test": "orange", "stimulus": "#00ff00"},
887+
)
888+
colors = fig.mne.annotation_segment_colors
889+
assert colors["BAD_test"] == to_hex("orange"), (
890+
"User color for BAD_test should override red default"
891+
)
892+
assert colors["stimulus"] == "#00ff00"
893+
assert colors["BAD_other"] == "#ff0000", (
894+
"BAD_other has no user override and should remain red"
895+
)
896+
897+
# Unknown label key triggers a warning
898+
with pytest.warns(RuntimeWarning, match="do not match"):
899+
fig = raw.plot(annotation_colors={"nonexistent_label": "blue"})
900+
901+
# Invalid color value raises ValueError
902+
with pytest.raises(ValueError, match="not a valid matplotlib color"):
903+
raw.plot(annotation_colors={"BAD_test": "not_a_color"}, show=False)
904+
905+
868906
@pytest.mark.parametrize("active_annot_idx", (0, 1, 2))
869907
def test_overlapping_annotation_deletion(raw, browser_backend, active_annot_idx):
870908
"""Test deletion of annotations via right-click."""

mne/viz/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2866,3 +2866,34 @@ def _get_plot_ch_type(inst, ch_type, allow_ref_meg=False):
28662866
f"No plottable channel types found. Allowed types are: {allowed_types}"
28672867
)
28682868
return ch_type
2869+
2870+
2871+
def _normalize_annotation_colors(annotation_colors, annotations):
2872+
"""Normalize annotation_colors and check that keys match annotation descriptions.
2873+
2874+
Parameters
2875+
----------
2876+
annotation_colors : dict[str, color]
2877+
The annotation colors to normalize (``color`` can be any valid Matplotlib color
2878+
specification).
2879+
annotations : mne.Annotations
2880+
The Annotations object to check against.
2881+
"""
2882+
from matplotlib.colors import to_hex
2883+
2884+
_validate_type(annotation_colors, dict, "annotation_colors")
2885+
normalized = {}
2886+
for k, v in annotation_colors.items():
2887+
try:
2888+
normalized[k] = to_hex(v)
2889+
except ValueError:
2890+
raise ValueError(
2891+
f"annotation_colors[{k!r}] is not a valid matplotlib color: {v!r}"
2892+
) from None
2893+
unknown = set(normalized) - set(annotations.description)
2894+
if unknown:
2895+
warn(
2896+
"The following annotation_colors keys do not match any annotation "
2897+
f"description in the data: {sorted(unknown)}"
2898+
)
2899+
return normalized

0 commit comments

Comments
 (0)