diff --git a/doc/changes/dev/13905.apichange.rst b/doc/changes/dev/13905.apichange.rst new file mode 100644 index 00000000000..aea11de14e1 --- /dev/null +++ b/doc/changes/dev/13905.apichange.rst @@ -0,0 +1 @@ +In :meth:`mne.Evoked.animate_topomap` the ``vmin`` and ``vmax`` keyword arguments have been deprecated in favor of ``vlim``, by `Eric Larson`_. diff --git a/doc/changes/dev/13905.newfeature.rst b/doc/changes/dev/13905.newfeature.rst new file mode 100644 index 00000000000..0b4695f15e3 --- /dev/null +++ b/doc/changes/dev/13905.newfeature.rst @@ -0,0 +1,7 @@ +Improved :class:`mne.Report` by: + +- Using ``"agg"`` backend for 2D plots to avoid windows unnecessarily popping up +- Improving :meth:`~mne.Report.add_stc` by using direct capture of brain images, allowing direct control of image size using ``stc_plot_kwargs["size"]`` +- Improving :meth:`~mne.Report.add_evokeds` evoked topomap plotting by using :meth:`mne.Evoked.animate_topomap` with ``butterfly=True`` rather than re-plotting each topomap in a loop + +Also improved :meth:`mne.Evoked.animate_topomap` by deduplicating code with :meth:`mne.Evoked.plot_topomap`. Changes by `Eric Larson`_. diff --git a/examples/visualization/evoked_topomap.py b/examples/visualization/evoked_topomap.py index 53b7a60dbba..cbafaeb875c 100644 --- a/examples/visualization/evoked_topomap.py +++ b/examples/visualization/evoked_topomap.py @@ -192,4 +192,4 @@ # sphinx_gallery_thumbnail_number = 9 times = np.arange(0.05, 0.151, 0.01) -fig, anim = evoked.animate_topomap(times=times, ch_type="mag", frame_rate=2, blit=False) +fig, anim = evoked.animate_topomap(times=times, ch_type="mag", frame_rate=2) diff --git a/mne/evoked.py b/mne/evoked.py index 2f953ff24a0..17048bf0193 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -791,21 +791,40 @@ def plot_joint( topomap_args=topomap_args, ) - @fill_doc + @verbose def animate_topomap( self, - ch_type=None, - times=None, - frame_rate=None, *, + times=None, + average=None, + ch_type=None, + scalings=None, + proj=False, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1.0, cmap=None, + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + time_unit="s", + time_format=None, + frame_rate=None, butterfly=False, blit=True, show=True, - time_unit="s", - sphere=None, - image_interp=_INTERPOLATION_DEFAULT, - extrapolate=_EXTRAPOLATE_DEFAULT, vmin=None, vmax=None, verbose=None, @@ -818,23 +837,44 @@ def animate_topomap( Parameters ---------- - ch_type : str | None - Channel type to plot. Accepted data types: 'mag', 'grad', 'eeg', - 'hbo', 'hbr', 'fnirs_cw_amplitude', - 'fnirs_fd_ac_amplitude', 'fnirs_fd_phase', and 'fnirs_od'. - If None, first available channel type from the above list is used. - Defaults to None. times : array of float | None - The time points to plot. If None, 10 evenly spaced samples are - calculated over the evoked time series. Defaults to None. + The time points to plot. If None (default), 10 evenly spaced samples are + calculated over the evoked time series. + %(average_plot_evoked_topomap)s + %(ch_type_topomap)s + %(scalings_topomap)s + %(proj_plot)s + %(sensors_topomap)s + %(show_names_topomap)s + %(mask_evoked_topomap)s + %(mask_params_topomap)s + %(contours_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(image_interp_topomap)s + %(extrapolate_topomap)s + %(border_topomap)s + %(res_topomap)s + %(size_topomap)s + %(cmap_topomap)s + %(vlim_plot_topomap_psd)s + %(cnorm)s + %(colorbar_topomap)s + %(cbar_fmt_topomap)s + %(units_topomap_evoked)s + axes : list of matplotlib.axes.Axes | None + The axes to use for plotting. Must have one axis for the topomap, + then one for the colorbar (if ``colorbar=True``), then one for the + butterfly axes (if ``butterfly=True``). + time_unit : str + The units for the time axis, can be "ms" or "s" (default). + time_format : str | None + String format for topomap values. Defaults (None) to "%%01d ms" if + ``time_unit='ms'``, "%%0.3f s" if ``time_unit='s'``, and + "%%g" otherwise. Can be an empty string to omit the time label. frame_rate : int | None Frame rate for the animation in Hz. If None, frame rate = sfreq / 10. Defaults to None. - cmap : matplotlib colormap | None - Colormap to use. If None, 'Reds' is used for all positive data, - otherwise defaults to 'RdBu_r'. - - .. versionadded:: 1.12.0 butterfly : bool Whether to plot the data as butterfly plot under the topomap. Defaults to False. @@ -845,19 +885,10 @@ def animate_topomap( Defaults to True. show : bool Whether to show the animation. Defaults to True. - time_unit : str - The units for the time axis, can be "ms" (default in 0.16) - or "s" (will become the default in 0.17). - - .. versionadded:: 0.16 - %(sphere_topomap_auto)s - %(image_interp_topomap)s - %(extrapolate_topomap)s - - .. versionadded:: 0.22 - %(vmin_vmax_topomap)s - - .. versionadded:: 1.1.0 + vmin : float | None + Deprecated, use ``vlim=(vmin, vmax)`` instead. + vmax : float | None + Deprecated, use ``vlim=(vmin, vmax)`` instead. %(verbose)s Returns @@ -869,24 +900,46 @@ def animate_topomap( Notes ----- + .. versionchanged:: 1.13.0 + The ``vmin`` and ``vmax`` parameters were deprecated in favor of a single + ``vlim`` parameter, and parameters were added and reordered to follow + :meth:`~mne.Evoked.plot_topomap`. .. versionadded:: 0.12.0 """ return _topomap_animation( - self, - ch_type=ch_type, + evoked=self, times=times, - frame_rate=frame_rate, - butterfly=butterfly, - blit=blit, - show=show, - time_unit=time_unit, + average=average, + ch_type=ch_type, + scalings=scalings, + proj=proj, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_unit=time_unit, + time_format=time_format, + frame_rate=frame_rate, + butterfly=butterfly, + blit=blit, vmin=vmin, vmax=vmax, - verbose=verbose, - cmap=cmap, + show=show, ) def as_type(self, ch_type="grad", mode="fast"): diff --git a/mne/report/report.py b/mne/report/report.py index 50b6c9fc30b..92d628315b0 100644 --- a/mne/report/report.py +++ b/mne/report/report.py @@ -7,6 +7,7 @@ import base64 import copy import dataclasses +import functools import os import os.path as op import re @@ -20,7 +21,9 @@ from pathlib import Path from shutil import copyfile +import matplotlib import numpy as np +from matplotlib.animation import AbstractMovieWriter from .. import __version__ as MNE_VERSION from .._fiff.meas_info import Info, read_info @@ -366,6 +369,40 @@ def _check_tags(tags) -> tuple[str]: # PLOTTING FUNCTIONS +class _NdArrayCapture(AbstractMovieWriter): + def __init__(self, frames: list): + super().__init__(fps=1, metadata={}, bitrate=0) + self.frames = frames + + def grab_frame(self, **savefig_kwargs): + img = _fig_to_img( + fig=self.fig, image_format="ndarray", pad_inches=0, **savefig_kwargs + ) + self.frames.append(img) + + def save(self, filename, *args, **kwargs): + pass + + def finish(self): + pass + + def setup(self, fig, outfile, dpi=None): + self.fig = fig + + +def _use_agg(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + old_backend = matplotlib.get_backend() + matplotlib.use("agg") + try: + func(*args, **kwargs) + finally: + matplotlib.use(old_backend) + + return wrapper + + def _constrain_fig_resolution(fig, *, max_width, max_res): """Limit the resolution (DPI) of a figure. @@ -449,29 +486,23 @@ def _fig_to_img( logger.debug( f"Saving figure with dimension {fig.get_size_inches()} inches with {dpi} dpi" ) - - # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html - pil_kwargs = dict() - if image_format == "webp": - pil_kwargs.update(lossless=True, method=6) - elif image_format == "png": - pil_kwargs.update(optimize=True, compress_level=9) - if pil_kwargs: - # matplotlib modifies the passed dict, which is a bug - mpl_kwargs["pil_kwargs"] = pil_kwargs.copy() - - mpl_format = image_format - if image_format == "ndarray": - mpl_format = "png" + mpl_format = "svg" if image_format == "svg" else "png" fig.savefig(output, format=mpl_format, dpi=dpi, **mpl_kwargs) if own_figure: plt.close(fig) - # Remove alpha + # Remove alpha channel entirely (for space and to avoid rendering issues) if image_format not in ("svg", "ndarray"): from PIL import Image + # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html + pil_kwargs = dict() + if image_format == "webp": + # Here quality means speed/size tradeoff (either way the result is lossless) + pil_kwargs.update(lossless=False, quality=80) + elif image_format == "png": + pil_kwargs.update(optimize=True, compress_level=9) output.seek(0) orig = Image.open(output) if orig.mode == "RGBA": @@ -1361,6 +1392,11 @@ def add_evokeds( %(topomap_kwargs)s %(n_jobs)s + .. versionchanged:: 1.13 + This parameter is currently unused, as parallelization of evoked topomap + plotting is no longer needed. This parameter will remain for + compatibility in case future options allow parallelization. + Notes ----- .. versionadded:: 0.24.0 @@ -1406,7 +1442,6 @@ def add_evokeds( tags=tags, section=title, topomap_kwargs=topomap_kwargs, - n_jobs=n_jobs, replace=replace, ) @@ -1519,6 +1554,10 @@ def add_stc( Notes ----- + .. versionchanged:: 1.13 + The size of the obtained brain images (in pixels) now comes from the + ``size`` argument of ``stc_plot_kwargs``, which has a default + ``size=(450, 450)``. The ``report.img_max_res`` is now ignored. .. versionadded:: 0.24.0 """ tags = _check_tags(tags) @@ -3572,6 +3611,7 @@ def _add_raw( replace=replace, ) + @_use_agg def _add_projs( self, *, @@ -3823,41 +3863,50 @@ def _add_evoked_joint( own_figure=True, ) - def _plot_one_evoked_topomap_timepoint( - self, *, evoked, time, ch_types, vmin, vmax, topomap_kwargs + def _plot_evoked_topomap_timepoints( + self, *, evoked, ch_types, vmin, vmax, topomap_kwargs, times ): import matplotlib.pyplot as plt - fig, ax = plt.subplots( - 1, - len(ch_types) * 2, - gridspec_kw={"width_ratios": [8, 0.5] * len(ch_types)}, - figsize=(2.5 * len(ch_types), 2), - layout="constrained", - ) - ch_type_ax_map = dict( - zip( - ch_types, - [(ax[i], ax[i + 1]) for i in range(0, 2 * len(ch_types) - 1, 2)], - ) - ) - + frames = dict() + # In principle we could parallelize this across channel types, but + # in practice it doesn't speed things up, and adds issues (e.g., with mpl + # backend switching and global state somehow) for ch_type in ch_types: - evoked.plot_topomap( - times=[time], + fig, axes = plt.subplots( + 2, + 2, + gridspec_kw={"width_ratios": [8, 0.5], "height_ratios": [4, 1]}, + figsize=(2.8, 2.8), + layout="constrained", + ) + fig.delaxes(axes[1, 1]) + axes = axes.ravel()[:3] + axes[0].set_title(ch_type) + frames[ch_type] = list() + this_writer = _NdArrayCapture(frames[ch_type]) + _, ch_anim = evoked.animate_topomap( + times=times, ch_type=ch_type, vlim=(vmin[ch_type], vmax[ch_type]), - axes=ch_type_ax_map[ch_type], + axes=axes, show=False, + time_format="", # we impose our own in HTML + butterfly=True, **topomap_kwargs, ) - ch_type_ax_map[ch_type][0].set_title(ch_type) - - return self._fig_to_img( - fig=fig, - image_format="ndarray", - pad_inches=0, - ) + ch_anim.pause() + ch_anim.save("", writer=this_writer) + plt.close(fig) + del ( + fig, + axes, + ) + imgs = [ + np.concatenate([frames[ch_type][ti] for ch_type in ch_types], axis=1) + for ti in range(len(times)) + ] + return imgs def _add_evoked_topomap_slider( self, @@ -3869,7 +3918,6 @@ def _add_evoked_topomap_slider( section, tags, topomap_kwargs, - n_jobs, replace, ): if n_time_points is None: @@ -3916,22 +3964,14 @@ def _add_evoked_topomap_slider( return # No need to warn here, we did that above else: topomap_kwargs = self._validate_topomap_kwargs(topomap_kwargs) - parallel, p_fun, n_jobs = parallel_func( - func=self._plot_one_evoked_topomap_timepoint, - n_jobs=n_jobs, - max_jobs=len(times), - ) with use_log_level(_verbose_safe_false(level="error")): - fig_arrays = parallel( - p_fun( - evoked=evoked, - time=time, - ch_types=ch_types, - vmin=vmin, - vmax=vmax, - topomap_kwargs=topomap_kwargs, - ) - for time in times + fig_arrays = self._plot_evoked_topomap_timepoints( + evoked=evoked, + times=times, + ch_types=ch_types, + vmin=vmin, + vmax=vmax, + topomap_kwargs=topomap_kwargs, ) captions = [f"Time point: {round(t, 3):0.3f} s" for t in times] @@ -4010,6 +4050,7 @@ def _add_evoked_whitened( own_figure=True, ) + @_use_agg def _add_evoked( self, *, @@ -4021,7 +4062,6 @@ def _add_evoked( section, tags, topomap_kwargs, - n_jobs, replace, ): # Summary table @@ -4055,7 +4095,6 @@ def _add_evoked( section=section, tags=tags, topomap_kwargs=topomap_kwargs, - n_jobs=n_jobs, replace=replace, ) @@ -4095,6 +4134,7 @@ def _add_evoked( logger.debug("Evoked: done") + @_use_agg def _add_events( self, *, @@ -4132,6 +4172,7 @@ def _add_events( own_figure=True, ) + @_use_agg def _add_epochs_psd(self, *, epochs, psd, image_format, tags, section, replace): epoch_duration = epochs.tmax - epochs.tmin @@ -4222,6 +4263,7 @@ def _add_epochs_metadata(self, *, epochs, section, tags, replace): replace=replace, ) + @_use_agg def _add_epochs( self, *, @@ -4352,6 +4394,7 @@ def _add_epochs( replace=replace, ) + @_use_agg def _add_cov(self, *, cov, info, image_format, section, tags, replace): """Render covariance matrix & SVD.""" if not isinstance(cov, Covariance): @@ -4515,16 +4558,7 @@ def _add_stc( if backend_is_3d: brain.set_time(t) - fig, ax = plt.subplots(figsize=(4.5, 4.5), layout="constrained") - ax.imshow(brain.screenshot(time_viewer=True, mode="rgb")) - ax.axis("off") - _constrain_fig_resolution( - fig, - max_width=stc_plot_kwargs["size"][0], - max_res=self.img_max_res, - ) - figs.append(fig) - plt.close(fig) + figs.append(brain.screenshot(time_viewer=True, mode="rgb")) else: fig_lh = plt.figure(layout="constrained") fig_rh = plt.figure(layout="constrained") @@ -4582,8 +4616,10 @@ def _add_stc( own_figure=False, # prevent rescaling ) for fig in figs: - plt.close(fig) + if not isinstance(fig, np.ndarray): + plt.close(fig) + @_use_agg def _add_bem( self, *, diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index d364d196f18..44e6b6730a6 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -169,25 +169,48 @@ def test_plot_projs_topomap_joint(meg, vlim, raw): assert len(fig.axes) == 4 # 2 mag, 2 grad -def test_plot_topomap_animation(capsys): +def test_plot_topomap_animation(capsys, tmp_path): """Test topomap plotting.""" - # evoked evoked = read_evokeds(evoked_fname, "Left Auditory", baseline=(None, 0)) - - # Test animation - fig, anim = evoked.animate_topomap( - ch_type="grad", - times=[0, 0.1], - cmap="viridis", - butterfly=False, - time_unit="s", - verbose="debug", - ) - anim._func(1) # _animate has to be tested separately on 'Agg' backend. + with pytest.warns(FutureWarning, match=".* vmin .* deprecated.*"): + fig, anim = evoked.animate_topomap( + times=[0, 0.1], + cmap="viridis", + vmin=0, + vmax=10, + verbose="debug", + ) out, _ = capsys.readouterr() - assert "extrapolation mode local to 0" in out + assert "extrapolation mode local to mean" in out assert fig.axes[0].images[0].get_cmap().name == "viridis" + # saving + PIL = pytest.importorskip("PIL") + gif_path = tmp_path / "test.gif" + anim.save(gif_path, writer="pillow") + assert gif_path.exists() + with PIL.Image.open(gif_path) as img: + assert img.format == "GIF" + assert img.n_frames == 2 + for frame in PIL.ImageSequence.Iterator(img): + assert frame.format == "GIF" + data = np.array(frame) + assert data.any() # not all empty + + # failure modes + evoked.pick("mag") + with pytest.raises(ValueError, match="No channels of type"): + evoked.animate_topomap(ch_type="eeg") + fig, axes = plt.subplots(1, 4) + with pytest.raises(ValueError, match="it must have length 2"): + evoked.animate_topomap(axes=axes) + with pytest.raises(ValueError, match="it must have length 3"): + evoked.animate_topomap(axes=axes, butterfly=True) + with pytest.raises(TypeError, match="axes must be an instance"): + evoked.animate_topomap(axes="test") + with pytest.raises(TypeError, match=r"axes\[0\] must be an instance"): + evoked.animate_topomap(axes=["test", "test"]) + def test_plot_topomap_animation_csd(capsys): """Test topomap plotting of CSD data.""" @@ -201,17 +224,12 @@ def test_plot_topomap_animation_csd(capsys): ) anim._func(1) # _animate has to be tested separately on 'Agg' backend. out, _ = capsys.readouterr() - assert "extrapolation mode head to 0" in out + assert "extrapolation mode head to mean" in out -@pytest.mark.filterwarnings("ignore:.*No contour levels.*:UserWarning") -def test_plot_topomap_animation_nirs(fnirs_evoked, capsys): +def test_plot_topomap_animation_nirs(fnirs_evoked): """Test topomap plotting for nirs data.""" - fig, anim = fnirs_evoked.animate_topomap(ch_type="hbo", verbose="debug") - anim._func(1) # _animate has to be tested separately on 'Agg' backend. - out, _ = capsys.readouterr() - assert "extrapolation mode head to 0" in out - assert len(fig.axes) == 2 + fnirs_evoked.animate_topomap(ch_type="hbo", verbose="debug") def test_plot_evoked_topomap_errors(evoked, monkeypatch): diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index fcbd213ce78..bea3cb00303 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -10,6 +10,8 @@ from functools import partial from numbers import Integral +import matplotlib.artist +import matplotlib.patches import numpy as np from scipy.interpolate import ( CloughTocher2DInterpolator, @@ -46,7 +48,6 @@ _is_numeric, _time_mask, _validate_type, - check_version, fill_doc, legacy, logger, @@ -85,12 +86,6 @@ ) -# 3.8+ uses a single Collection artist rather than .collections -# https://github.com/matplotlib/matplotlib/pull/25247 -def _cont_collections(cont): - return (cont,) if check_version("matplotlib", "3.8") else tuple(cont.collections) - - def _adjust_meg_sphere(sphere, info, ch_type): sphere = _check_sphere(sphere, info) assert ch_type is not None @@ -350,39 +345,36 @@ def _plot_update_evoked_topomap(params, bools): interp = params["interp"] new_contours = list() - use_contours = params["contours_"] - if not len(use_contours): - use_contours = [None] * len(params["axes"]) - assert len(use_contours) == len(params["images"]) + assert len(params["contours_"]) == len(params["images"]) assert len(params["axes"]) == len(params["images"]) assert len(data.T) == len(params["images"]) - for cont, ax, im, d in zip(use_contours, params["axes"], params["images"], data.T): + Xi, Yi = interp.Xi, interp.Yi + for cont, ax, im, d in zip( + params["contours_"], params["axes"], params["images"], data.T + ): Zi = interp.set_values(d)() im.set_data(Zi) - if cont is None: - continue - # must be removed and re-added - cont_collections = _cont_collections(cont) - for col in cont_collections: - col.remove() - col = cont_collections[0] - lw = col.get_linewidth() - visible = col.get_visible() - patch_ = col.get_clip_path() - color = col.get_edgecolors() - cont = ax.contour( - interp.Xi, interp.Yi, Zi, params["contours"], colors=color, linewidths=lw - ) - cont_collections = _cont_collections(cont) - for col in cont_collections: - col.set_visible(visible) - col.set_clip_path(patch_) - new_contours.append(cont) - params["contours_"] = new_contours - + new_contours.append(_update_contours(cont, ax, Xi, Yi, Zi, params["contours"])) + params["contours_"][:] = new_contours params["fig"].canvas.draw() +def _update_contours(cont, ax, Xi, Yi, Zi, contours): + if cont is None: + return cont + lw = cont.get_linewidth() + visible = cont.get_visible() + patch_ = cont.get_clip_path() + color = cont.get_edgecolors() + zorder = _TOPOMAP_ZORDER["contours"] + if cont in ax.collections: + cont.remove() + cont = ax.contour(Xi, Yi, Zi, contours, colors=color, linewidths=lw, zorder=zorder) + cont.set_visible(visible) + cont.set_clip_path(patch_) + return cont + + def _add_colorbar( ax, im, @@ -759,17 +751,20 @@ def _draw_outlines(ax, outlines): from matplotlib import rcParams outlines_ = {k: v for k, v in outlines.items() if k not in ["patch"]} + drawn_outlines = list() for key, (x_coord, y_coord) in outlines_.items(): if "mask" in key or key in ("clip_radius", "clip_origin"): continue - ax.plot( + (line,) = ax.plot( x_coord, y_coord, color=rcParams["axes.edgecolor"], linewidth=1, clip_on=False, + zorder=_TOPOMAP_ZORDER["head_outlines"], ) - return outlines_ + drawn_outlines.append(line) + return drawn_outlines def _get_extra_points(pos, extrapolate, origin, radii): @@ -979,17 +974,20 @@ def __call__(self, *args): def _topomap_plot_sensors(pos_x, pos_y, sensors, ax): """Plot sensors.""" + zorder = _TOPOMAP_ZORDER["sensors"] if sensors is True: - ax.scatter( + drawn_sensors = ax.scatter( pos_x, pos_y, s=0.25, marker="o", edgecolor=["k"] * len(pos_x), facecolor="none", + zorder=zorder, ) else: - ax.plot(pos_x, pos_y, sensors) + drawn_sensors = ax.plot(pos_x, pos_y, sensors, zorder=zorder)[0] + return drawn_sensors def _get_pos_outlines(info, picks, sphere, to_sphere=True): @@ -1187,6 +1185,7 @@ def _voronoi_topomap(data, pos, outlines, ax, cmap, norm, extent, res): aspect="equal", extent=extent, norm=norm, + zorder=_TOPOMAP_ZORDER["imshow"], ) rx, ry = outlines["clip_radius"] cx, cy = outlines.get("clip_origin", (0.0, 0.0)) @@ -1218,13 +1217,17 @@ def _voronoi_topomap(data, pos, outlines, ax, cmap, norm, extent, res): x *= rx / np.linalg.norm(vor.vertices[i]) y *= ry / np.linalg.norm(vor.vertices[i]) polygon.append((x, y)) - ax.fill(*zip(*polygon), color=cmap(norm(data[point_idx]))) + ax.fill( + *zip(*polygon), + color=cmap(norm(data[point_idx])), + zorder=_TOPOMAP_ZORDER["voronoi"], + ) return im -def _get_patch(outlines, extrapolate, interp, ax): - from matplotlib import patches - +def _make_head_patch(outlines, extrapolate, interp, ax): + # TODO: Disentangle adding the patch with creating it? Confusing flow here + # for "patch in outlines" and "_use_default_outlines" clip_radius = outlines["clip_radius"] clip_origin = outlines.get("clip_origin", (0.0, 0.0)) _use_default_outlines = any(k.startswith("head") for k in outlines) @@ -1238,11 +1241,11 @@ def _get_patch(outlines, extrapolate, interp, ax): ax.set_clip_path(patch_) if _use_default_outlines: if extrapolate == "local": - patch_ = patches.Polygon( + patch_ = matplotlib.patches.Polygon( interp.mask_pts, clip_on=True, transform=ax.transData ) else: - patch_ = patches.Ellipse( + patch_ = matplotlib.patches.Ellipse( clip_origin, 2 * clip_radius[0], 2 * clip_radius[1], @@ -1252,6 +1255,15 @@ def _get_patch(outlines, extrapolate, interp, ax): return patch_ +_TOPOMAP_ZORDER = dict( # keep these in order that we want to draw them, too + imshow=1.0, + voronoi=1.5, + contours=2.0, + head_outlines=2.5, + sensors=3.0, +) + + def _plot_topomap( data, pos, @@ -1371,6 +1383,8 @@ def _plot_topomap( _prepare_topomap(pos, axes) mask_params = _handle_default("mask_params", mask_params) + if "zorder" not in mask_params: + mask_params["zorder"] = _TOPOMAP_ZORDER["sensors"] # find mask limits and setup interpolation extent, Xi, Yi, interp = _setup_interp( @@ -1380,7 +1394,7 @@ def _plot_topomap( Zi = interp.set_locations(Xi, Yi)() # plot outline - patch_ = _get_patch(outlines, extrapolate, interp, axes) + head_patch = _make_head_patch(outlines, extrapolate, interp, axes) # get colormap normalization if cnorm is None: @@ -1407,6 +1421,7 @@ def _plot_topomap( extent=extent, interpolation="bilinear", norm=cnorm, + zorder=_TOPOMAP_ZORDER["imshow"], ) # gh-1432 had a workaround for no contours here, but we'll remove it @@ -1421,14 +1436,19 @@ def _plot_topomap( with warnings.catch_warnings(record=True): warnings.simplefilter("ignore") cont = axes.contour( - Xi, Yi, Zi, contours, colors="k", linewidths=linewidth / 2.0 + Xi, + Yi, + Zi, + contours, + colors="k", + linewidths=linewidth / 2.0, + zorder=_TOPOMAP_ZORDER["contours"], ) - if patch_ is not None: - im.set_clip_path(patch_) + if head_patch is not None: + im.set_clip_path(head_patch) if cont is not None: - for col in _cont_collections(cont): - col.set_clip_path(patch_) + cont.set_clip_path(head_patch) pos_x, pos_y = pos.T mask = mask.astype(bool, copy=False) if mask is not None else None @@ -2226,6 +2246,79 @@ def plot_evoked_topomap( * :class:`~mne.viz.ui_events.TimeChange` whenever a new time is selected. """ + fig, _params = _plot_evoked_topomap( + evoked=evoked, + times=times, + average=average, + ch_type=ch_type, + scalings=scalings, + proj=proj, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_unit=time_unit, + time_format=time_format, + nrows=nrows, + ncols=ncols, + interactive_colorbar=True, + single_time_point=False, + ) + plt_show(show, block=False) + if axes is not None: + fig.canvas.draw() + return fig + + +def _plot_evoked_topomap( + *, + evoked, + times, + average, + ch_type, + scalings, + proj, + sensors, + show_names, + mask, + mask_params, + contours, + outlines, + sphere, + image_interp, + extrapolate, + border, + res, + size, + cmap, + vlim, + cnorm, + colorbar, + cbar_fmt, + units, + axes, + time_unit, + time_format, + nrows, + ncols, + interactive_colorbar, + single_time_point, +): import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec from matplotlib.widgets import Slider @@ -2281,12 +2374,12 @@ def plot_evoked_topomap( evoked.apply_proj() elif proj == "reconstruct": evoked._reconstruct_proj() - data = evoked.data # remove compensation matrices (safe: only plotting & already made copy) with evoked.info._unlock(): evoked.info["comps"] = [] evoked = evoked._pick_drop_channels(picks, verbose=False) + data = evoked.data # determine which times to plot if isinstance(axes, plt.Axes): axes = [axes] @@ -2299,7 +2392,7 @@ def plot_evoked_topomap( f"Times should be between {evoked.times[0]:0.3} and {evoked.times[-1]:0.3}." ) # create axes - want_axes = n_times + int(colorbar) + want_axes = (1 if single_time_point else n_times) + int(colorbar) if interactive: height_ratios = [5, 1] nrows = 2 @@ -2339,10 +2432,11 @@ def plot_evoked_topomap( "an array-like object of the previous" ) + all_data = data.copy() averaged_times = [] if average is None: average = np.array([None] * n_times) - data = data[np.ix_(picks, time_idx)] + data = data[:, time_idx] else: if _is_numeric(average): average = np.array([average] * n_times) @@ -2374,18 +2468,19 @@ def plot_evoked_topomap( raise ValueError(msg) if this_average is None: - data_[:, average_idx] = data[picks][:, this_time_idx] + data_[:, average_idx] = data[:, this_time_idx] averaged_times.append([this_time]) else: tmin_ = this_time - this_average / 2 tmax_ = this_time + this_average / 2 time_mask = (tmin_ < evoked.times) & (evoked.times < tmax_) - data_[:, average_idx] = data[picks][:, time_mask].mean(-1) + data_[:, average_idx] = data[:, time_mask].mean(-1) averaged_times.append(evoked.times[time_mask]) data = data_ # apply scalings and merge channels data *= scaling + all_data *= scaling if merge_channels: # check modality if any(ch["coil_type"] in _opm_coils for ch in evoked.info["chs"]): @@ -2394,7 +2489,10 @@ def plot_evoked_topomap( modality = "fnirs" else: modality = "other" - # merge data + # merge data (need to copy the names on the first call, modified inplace) + all_data, _ = _merge_ch_data( + all_data, ch_type, list(ch_names), modality=modality + ) data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality) # if ch_type in _fnirs_types: if modality != "other": @@ -2440,10 +2538,12 @@ def plot_evoked_topomap( border=border, ch_type=ch_type, ) - images, contours_ = [], [] + images, drawn_contours, interps = [], [], [] # loop over times for average_idx, (time, this_average) in enumerate(zip(times, average)): - tp, cn, interp = _plot_topomap( + if single_time_point and average_idx > 0: + break + im, cn, interp = _plot_topomap( data[:, average_idx], pos, axes=axes[average_idx], @@ -2452,10 +2552,10 @@ def plot_evoked_topomap( vmax=_vlim[1], **kwargs, ) - - images.append(tp) - if cn is not None: - contours_.append(cn) + images.append(im) + interps.append(interp) + drawn_contours.append(cn) + del im, cn, interp if time_format != "": if this_average is None: axes_title = time_format % (time * scaling_time) @@ -2526,10 +2626,10 @@ def _slider_changed(val): cbar = fig.colorbar(images[-1], ax=axes, cax=cax, format=cbar_fmt, shrink=0.6) if unit is not None: cbar.ax.set_title(unit) - if cn is not None: + if drawn_contours[0] is not None: cbar.set_ticks(contours) cbar.ax.tick_params(labelsize=7) - if cmap[1]: + if cmap[1] and interactive_colorbar: for im in images: im.axes.CB = DraggableColorbar( cbar, im, kind="evoked_topomap", ch_type=ch_type @@ -2543,7 +2643,7 @@ def _slider_changed(val): projs=evoked.info["projs"], picks=picks, images=images, - contours_=contours_, + contours_=drawn_contours, pos=pos, time_idx=time_idx, res=res, @@ -2552,7 +2652,7 @@ def _slider_changed(val): scale=scaling, axes=axes[: len(axes) - bool(interactive)], contours=contours, - interp=interp, + interp=interps[0], # TODO: Maybe not correct for multiple axes! extrapolate=extrapolate, ) _draw_proj_checkbox(None, params) @@ -2562,10 +2662,22 @@ def _slider_changed(val): fig.mne = BrowserParams(proj_checkboxes=params["proj_checks"]) - plt_show(show, block=False) - if axes_given: - fig.canvas.draw() - return fig + # Additional things that might be needed by callers (e.g., animation) + params = dict( + data=data, + all_data=all_data, + all_times=evoked.times, + ch_type=ch_type, + used_times=times, + interps=interps, + images=images, + contours=contours, + drawn_contours=drawn_contours, + time_format=time_format, + scaling_time=scaling_time, + ) + + return fig, params def _resize_cbar(cax, n_fig_axes, size=1): @@ -3164,302 +3276,213 @@ def _check_extrapolate(extrapolate, ch_type): return extrapolate -@verbose -def _init_anim( - ax, - ax_line, - ax_cbar, - params, - merge_channels, - sphere, - ch_type, - cmap, - image_interp, - extrapolate, - verbose, -): - """Initialize animated topomap.""" - logger.info("Initializing animation...") - data = params["data"] - items = list() - vmin = params["vmin"] if "vmin" in params else None - vmax = params["vmax"] if "vmax" in params else None - if params["butterfly"]: - all_times = params["all_times"] - for idx in range(len(data)): - ax_line.plot(all_times, data[idx], color="k", lw=1) - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) - ax_line.set( - yticks=np.around(np.linspace(vmin, vmax, 5), -1), xlim=all_times[[0, -1]] - ) - params["line"] = ax_line.axvline(all_times[0], color="r") - items.append(params["line"]) - if merge_channels: - from mne.channels.layout import _merge_ch_data - - data, _ = _merge_ch_data(data, "grad", []) - norm = True if np.min(data) > 0 else False - if cmap is None: - cmap = "Reds" if norm else "RdBu_r" - - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm) - - outlines = _make_head_outlines(sphere, params["pos"], "head", params["clip_origin"]) - - _hide_frame(ax) - extent, Xi, Yi, interp = _setup_interp( - pos=params["pos"], - res=64, - image_interp=image_interp, - extrapolate=extrapolate, - outlines=outlines, - border=0, - ) - - patch_ = _get_patch(outlines, extrapolate, interp, ax) - - params["Zis"] = list() - for frame in params["frames"]: - params["Zis"].append(interp.set_values(data[:, frame])(Xi, Yi)) - Zi = params["Zis"][0] - zi_min = np.nanmin(params["Zis"]) - zi_max = np.nanmax(params["Zis"]) - cont_lims = np.linspace(zi_min, zi_max, 7, endpoint=False)[1:] - params.update( - { - "vmin": vmin, - "vmax": vmax, - "Xi": Xi, - "Yi": Yi, - "Zi": Zi, - "extent": extent, - "cmap": cmap, - "cont_lims": cont_lims, - } - ) - # plot map and contour - im = ax.imshow( - Zi, - cmap=cmap, - vmin=vmin, - vmax=vmax, - origin="lower", - aspect="equal", - extent=extent, - interpolation="bilinear", - ) - ax.autoscale(enable=True, tight=True) - ax.figure.colorbar(im, cax=ax_cbar) - cont = ax.contour(Xi, Yi, Zi, levels=cont_lims, colors="k", linewidths=1) - - im.set_clip_path(patch_) - text = ax.text(0.55, 0.95, "", transform=ax.transAxes, va="center", ha="right") - params["text"] = text - items.append(im) - items.append(text) - cont_collections = _cont_collections(cont) - for col in cont_collections: - col.set_clip_path(patch_) - - outlines_ = _draw_outlines(ax, outlines) - - params.update({"patch": patch_, "outlines": outlines_}) - return tuple(items) + cont_collections - - -def _animate(frame, ax, ax_line, params): - """Update animated topomap.""" - if params["pause"]: - frame = params["frame"] - time_idx = params["frames"][frame] - - if params["time_unit"] == "ms": - title = f"{params['times'][frame] * 1e3:6.0f} ms" - else: - title = f"{params['times'][frame]:6.3f} s" - if params["blit"]: - text = params["text"] - else: - ax.cla() # Clear old contours. - text = ax.text(0.45, 1.15, "", transform=ax.transAxes) - for k, (x, y) in params["outlines"].items(): - if "mask" in k: - continue - ax.plot(x, y, color="k", linewidth=1, clip_on=False) - - _hide_frame(ax) - text.set_text(title) - - vmin = params["vmin"] - vmax = params["vmax"] - Xi = params["Xi"] - Yi = params["Yi"] - Zi = params["Zis"][frame] - extent = params["extent"] - cmap = params["cmap"] - patch = params["patch"] - - im = ax.imshow( - Zi, - cmap=cmap, - vmin=vmin, - vmax=vmax, - origin="lower", - aspect="equal", - extent=extent, - interpolation="bilinear", - ) - cont_lims = params["cont_lims"] - with warnings.catch_warnings(record=True): - warnings.simplefilter("ignore") - cont = ax.contour(Xi, Yi, Zi, levels=cont_lims, colors="k", linewidths=1) - - im.set_clip_path(patch) - cont_collections = _cont_collections(cont) - for col in cont_collections: - col.set_clip_path(patch) - - items = [im, text] - if params["butterfly"]: - all_times = params["all_times"] - line = params["line"] - line.remove() - ylim = ax_line.get_ylim() - params["line"] = ax_line.axvline(all_times[time_idx], color="r") - ax_line.set_ylim(ylim) - items.append(params["line"]) - params["frame"] = frame - return tuple(items) + cont_collections - - -def _pause_anim(event, params): - """Pause or continue the animation on mouse click.""" - params["pause"] = not params["pause"] - - -def _key_press(event, params): - """Handle key presses for the animation.""" - if event.key == "left": - params["pause"] = True - params["frame"] = max(params["frame"] - 1, 0) - elif event.key == "right": - params["pause"] = True - params["frame"] = min(params["frame"] + 1, len(params["frames"]) - 1) +def _validate_artists(items): + items = tuple(items) + for ii, item in enumerate(items): + _validate_type(item, matplotlib.artist.Artist, f"items[{ii}]={item!r}") + return items def _topomap_animation( + *, evoked, - ch_type, - cmap, times, frame_rate, butterfly, blit, + axes, show, - time_unit, + vmin, + vmax, + # pass-through kwargs + average, + ch_type, + scalings, + proj, + sensors, + show_names, + mask, + mask_params, + contours, + outlines, sphere, image_interp, extrapolate, - *, - vmin, - vmax, - verbose=None, + border, + res, + size, + cmap, + vlim, + cnorm, + colorbar, + cbar_fmt, + units, + time_unit, + time_format, ): - """Make animation of evoked data as topomap timeseries. - - See mne.evoked.Evoked.animate_topomap. - """ + """Make animation of evoked data as topomap timeseries.""" + import matplotlib.pyplot as plt from matplotlib import animation - from matplotlib import pyplot as plt - if ch_type is None: - ch_type = _get_plot_ch_type(evoked, ch_type) - - time_unit, _ = _check_time_unit(time_unit, evoked.times) - if times is None: - times = np.linspace(evoked.times[0], evoked.times[-1], 10) - times = np.array(times) - - if times.ndim != 1: - raise ValueError(f"times must be 1D, got {times.ndim} dimensions") - if max(times) > evoked.times[-1] or min(times) < evoked.times[0]: - raise ValueError("All times must be inside the evoked time series.") - frames = [np.abs(evoked.times - time).argmin() for time in times] - - picks, pos, merge_channels, _, ch_type, sphere, clip_origin = _prepare_topomap_plot( - evoked, ch_type, sphere=sphere - ) - data = evoked.data[picks, :] - data *= _handle_default("scalings")[ch_type] - - norm = np.min(data) >= 0 - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm) + if frame_rate is None: + frame_rate = evoked.info["sfreq"] / 10.0 - fig = plt.figure(figsize=(6, 5), layout="constrained") - shape = (8, 12) - colspan = shape[1] - 1 - rowspan = shape[0] - bool(butterfly) - ax = plt.subplot2grid(shape, (0, 0), rowspan=rowspan, colspan=colspan) - if butterfly: - ax_line = plt.subplot2grid(shape, (rowspan, 0), colspan=colspan) + ax_line = None + if axes is None: + fig = plt.figure(figsize=(6, 5), layout="constrained") + shape = (8, 12) + colspan = shape[1] - bool(colorbar) + rowspan = shape[0] - bool(butterfly) + ax = plt.subplot2grid(shape, (0, 0), rowspan=rowspan, colspan=colspan, fig=fig) + if butterfly: + ax_line = plt.subplot2grid(shape, (rowspan, 0), colspan=colspan, fig=fig) + if colorbar: + ax_cbar = plt.subplot2grid(shape, (0, colspan), rowspan=rowspan, fig=fig) else: - ax_line = None - if isinstance(frames, Integral): - frames = np.linspace(0, len(evoked.times) - 1, frames).astype(int) - ax_cbar = plt.subplot2grid(shape, (0, colspan), rowspan=rowspan) - ax_cbar.set_title(_handle_default("units")[ch_type], fontsize=10) - extrapolate = _check_extrapolate(extrapolate, ch_type) + _validate_type(axes, "array-like", "axes") + axes = list(axes) + want_len = 1 + bool(colorbar) + bool(butterfly) + if len(axes) != want_len: + raise ValueError( + f"If axes is provided, it must have length {want_len} when " + f"{butterfly=} and {colorbar=}, got {len(axes)}" + ) + for ai, a in enumerate(axes): + _validate_type(a, plt.Axes, f"axes[{ai}]") + ax = axes[0] + if colorbar: + ax_cbar = axes[1] + if butterfly: + ax_line = axes[2] + fig = ax.figure + axes = [ax, ax_cbar] if colorbar else [ax] - params = dict( - data=data, - pos=pos, - all_times=evoked.times, - frame=0, - frames=frames, - butterfly=butterfly, - blit=blit, - pause=False, + if times is None: + times = np.linspace(evoked.times[0], evoked.times[-1], 10) + if vmin is not None or vmax is not None: + # Once this dep is done: remove vmin and vmax, and remove vlim from explicit + # pass below and above as a kwarg (so it just gets absorbed by + # _plot_evoked_topomap_kwargs) + vlim = (vmin, vmax) + warn( + f"vmax and vmin are deprecated, use vlim instead; using {vlim=}", + FutureWarning, + ) + del vmin, vmax + fig, topomap_params = _plot_evoked_topomap( + evoked=evoked, + # we handle these separately times=times, - time_unit=time_unit, - clip_origin=clip_origin, - vmin=vmin, - vmax=vmax, - ) - init_func = partial( - _init_anim, - ax=ax, - ax_cbar=ax_cbar, - ax_line=ax_line, - params=params, - merge_channels=merge_channels, - sphere=sphere, + axes=axes, + interactive_colorbar=False, + single_time_point=True, + nrows=1, + ncols=1, + # pass-through arguments + average=average, ch_type=ch_type, - cmap=cmap, + scalings=scalings, + proj=proj, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, image_interp=image_interp, extrapolate=extrapolate, - verbose=verbose, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + time_unit=time_unit, + time_format=time_format, ) - animate_func = partial(_animate, ax=ax, ax_line=ax_line, params=params) - pause_func = partial(_pause_anim, params=params) - fig.canvas.mpl_connect("button_press_event", pause_func) - key_press_func = partial(_key_press, params=params) - fig.canvas.mpl_connect("key_press_event", key_press_func) - if frame_rate is None: - frame_rate = evoked.info["sfreq"] / 10.0 + del evoked, axes + data = topomap_params["data"] + all_data = topomap_params["all_data"] + all_times = topomap_params["all_times"] + used_times = topomap_params["used_times"] + assert data.ndim == 2 + assert data.shape[1] == len(used_times) + contours = topomap_params["contours"] + im = topomap_params["images"][0] + interp = topomap_params["interps"][0] + cont = topomap_params["drawn_contours"][0] + ax.set_title(topomap_params["ch_type"]) + scaling_time = topomap_params["scaling_time"] + time_format = topomap_params["time_format"] + del topomap_params + + text = None + if time_format: + text = ax.text(0.5, 0.95, "", transform=ax.transAxes, va="center", ha="center") + Xi, Yi = interp.Xi, interp.Yi + Zis = [interp.set_values(d)() for d in data.T] + del interp + butterfly_vline = None + if butterfly: + ax_line.plot(all_times, all_data.T, color="k", lw=0.5, alpha=0.5) + ax_line.set_xlim(all_times[0], all_times[-1]) + butterfly_vline = ax_line.axvline(used_times[0], color="r") + + params = dict(frame=0, frames=list(range(len(used_times))), pause=False, cont=cont) + del cont + + def animate(frame): + params["frame"] = frame + Zi = Zis[frame] + im.set_data(Zi) + if time_format: + text.set_text(time_format % (used_times[frame] * scaling_time)) + params["cont"] = _update_contours(params["cont"], ax, Xi, Yi, Zi, contours) + items = [im] + if butterfly: + butterfly_vline.set_xdata([used_times[frame]]) + return _validate_artists(items) + interval = 1000 / frame_rate # interval is in ms anim = animation.FuncAnimation( fig, - animate_func, - init_func=init_func, - frames=len(frames), + animate, + frames=params["frames"], interval=interval, blit=blit, + cache_frame_data=False, ) + + def pause_anim(event=None, *, pause=None): + if pause is None: + pause = not params["pause"] # we need to flip the state + if pause: # we need to pause + params["pause"] = True + anim.pause() + else: + params["pause"] = False + anim.resume() + + def key_press(event): + if event.key not in ("left", "right"): + return + pause_anim(pause=True) # ensure paused + if event.key == "left": + params["frame"] = max(params["frame"] - 1, 0) + elif event.key == "right": + params["frame"] = min(params["frame"] + 1, len(params["frames"]) - 1) + animate(params["frame"]) + fig.canvas.draw_idle() + + fig.canvas.mpl_connect("button_press_event", pause_anim) + fig.canvas.mpl_connect("key_press_event", key_press) + fig.mne_animation = anim # to make sure anim is not garbage collected plt_show(show, block=False) - if "line" in params: - # Finally remove the vertical line so it does not appear in saved fig. - params["line"].remove() return fig, anim diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 763b0e8ea1b..1367a2cac25 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -50,6 +50,11 @@ _.requires_fit _.regressor_tags +# report +_.grab_frame +_.finish +_.setup + deep # Backward compat or rarely used