diff --git a/.notes/implementation-notes.html b/.notes/implementation-notes.html index 0a23504f..d1bb62e1 100644 --- a/.notes/implementation-notes.html +++ b/.notes/implementation-notes.html @@ -6,6 +6,138 @@

EEGPrep Core Parity Implementation Notes

+

Issue #213 Statistics and ICLabel Closeout Notes

+

Issue #213 closes the final Fable 5 architecture closeout phase for + statistics and ICLabel/viewprops ownership. The branch keeps public call + signatures, return values, command strings, GUI labels, and Matplotlib + dashboard layout behavior stable while moving implementations into focused + owner modules.

+

Issue #213 Design Decisions

+ +

Issue #213 Closeout Mapping for Remaining #197/#205 Items

+ + + + + + + + + + + + + + + + + + + +
Remaining item | Disposition | Rationale
pop_load_frombids.py raw-reader and montage + ownership | Fixed in PR #216 / issue #212 | BIDS raw loading + and montage inference moved into EEG-BIDS helper modules while + pop_load_frombids stayed orchestration/history glue.
functions/statistics/_core.py mega-module | Fixed + in issue #213 | Implementations now live in same-name statistics + modules, with _shared.py holding only shared helpers.
plugins/ICLabel/pop_prop_extended.py mixed + ownership | Fixed in issue #213 | Pop/history/dialog glue, + numerics/data assembly, and Matplotlib browser rendering now have + separate owner modules.
functions/guifunc/qt.py stateless renderer namespace | Fixed + in PR #216 / issue #212 | Qt renderer helpers were split into + module-level helpers while preserving call-site compatibility and dialog + behavior.
pop_clust finite-outlier robust k-means dispatch | Fixed + in PR #217 / issue #210 | STUDY clustering behavior was consolidated + in the STUDY/time-frequency/statistics phase with dedicated tests.
pop_chanplot, measure-field maps, cached measure + axes, range masks, default plot targets, _trial_rows, factor + matching, and std_clustplot history construction | Fixed + in PR #217 / issue #210 | STUDY measure/cache helper ownership was + consolidated as one phase because those paths share STUDY data contracts.
Time-frequency numeric-vector parsing, bootstrap/FDR helpers, + threshold-vector helpers, and empirical p-value convention audit | Fixed + or explicitly documented in PR #214 / issue #208 and PR #217 / issue + #210 | Shared parsing moved to lower-level helpers in Phase 1; + time-frequency and statistics behavior consolidation landed in Phase 3 + where numerical behavior could be tested together.
Divergent is_on(), empty-value, Python literal, + chanloc serialization, topographic channel selection, boundary-event + detection, and ConsoleEegh history mutation copies | Fixed + in PR #214 / issue #208 | Shared low-level contracts were centralized + before later phases consumed them.
Rejection-family browser plumbing, epoched rejection scaffolds, + component activation access, ICA finalization, and clean_rawdata + channel-removal masks | Fixed in PR #215 / issue #209 | These + scientific rejection/ICA/cleaning helpers were consolidated in the phase + that could test visual rejection state and numerical side effects + together.
CLI transform/pipeline duplication, JSON detection, stale + per-module harness contracts, extension catalog split, entry-point + helpers, active-record predicate, and bundled-plugin metadata | Fixed + in PR #218 / issue #211 | CLI and extension ownership was + consolidated in one agent-facing architecture phase.
FIR helper ownership and FIR GUI band-edge/shape duplication | Fixed + in PR #216 / issue #212 | plugins/firfilt now owns FIR + design helpers, and clean_rawdata imports downward from firfilt.
GUI extension pop-result STUDY return handling from #197/#205 | Non-goal + for #213 | The earlier tracker identified this as behavior-changing + because GUI extension pop_* functions returning STUDY state + would expand observable session behavior. No #213 work touched extension + GUI result semantics; this remains outside a mechanical closeout split + unless a future behavior issue requests it.
All other #197 findings not named above | Superseded by epic + #207 phase split | #197 and #205 were closed as superseded by the + replacement architecture closeout epic. PRs #214, #215, #217, #218, #216, + and this #213 branch are the final accounting set for those remaining + findings.
+

Issue #213 Verification Notes

+ +

Issue #212 BIDS, Qt, and FIR Ownership Notes

+

Issue #212 closes the Phase 5 architecture ownership work for BIDS + import helpers, Qt dialog rendering helpers, and FIR design helpers while + preserving standalone EEGPrep runtime behavior.

+

Issue #212 Design Decisions

+ +

Issue #212 Verification Notes

+

Issue #164 EEGLAB-Style Sphinx Documentation Notes

Issue #164 turns the final-epic documentation into a coherent standalone EEGPrep manual modeled after EEGLAB's learning path but written for Python, @@ -566,5 +698,28 @@

Tradeoffs

dialog, while the console action boundary keeps command echo and progress output ordered for mixed GUI-plus-console workflows. +

Issue #211 CLI And Extension Architecture Notes

+

Design Decisions

+ +

Verification Notes

+ diff --git a/docs/source/api/extensions.rst b/docs/source/api/extensions.rst index 9ab811b5..62e3ee11 100644 --- a/docs/source/api/extensions.rst +++ b/docs/source/api/extensions.rst @@ -62,11 +62,12 @@ external extension contributions follow one status and lazy-loading model. Catalog and Governance ====================== -Catalog metadata validation lives in ``eegprep.extension_catalog`` and is also -available as the ``eegprep-validate-extension-catalog`` console script. Static -validation checks JSON schema version, required metadata, naming, URLs, license, -maintainer contact, docs, conflicts, curation status, and compatibility fields -without requiring the extension package to be installed. +Extension Manager catalog loading lives in ``eegprep.extension_catalog``. Catalog +submission validation lives in ``eegprep.extension_catalog_validation`` and is +also available as the ``eegprep-validate-extension-catalog`` console script. +Static validation checks JSON schema version, required metadata, naming, URLs, +license, maintainer contact, docs, conflicts, curation status, and compatibility +fields without requiring the extension package to be installed. Stricter validation can also check installed package versions, required dependencies, the ``eegprep.extensions`` entry point, import failures, and @@ -175,6 +176,11 @@ API Reference :members: :undoc-members: +.. automodule:: eegprep.extension_catalog_validation + :no-index: + :members: + :undoc-members: + .. automodule:: eegprep.extension_testing :no-index: :members: diff --git a/docs/source/user_guide/agent_cli.rst b/docs/source/user_guide/agent_cli.rst index f97fb9df..be8a0488 100644 --- a/docs/source/user_guide/agent_cli.rst +++ b/docs/source/user_guide/agent_cli.rst @@ -110,6 +110,13 @@ planned, reviewed, and rerun. 
eegprep pipeline run preprocess.yaml --json eegprep batch run sub-01.set sub-02.set --pipeline preprocess.yaml --output-dir derivatives/eegprep --json +Pipeline transform steps use the same defaults as the matching direct CLI +commands. In particular, ``clean`` defaults to ASR burst correction with +``burst_criterion: 20`` and leaves flatline, channel, line-noise, window, and +high-pass cleaning criteria off unless the YAML config sets them explicitly. +Set ``highpass`` as a two-value transition band, for example +``highpass: [0.25, 0.75]``. + QC And Reports ============== diff --git a/docs/source/user_guide/study_workflows.rst b/docs/source/user_guide/study_workflows.rst index e41b1fe5..7d183f9c 100644 --- a/docs/source/user_guide/study_workflows.rst +++ b/docs/source/user_guide/study_workflows.rst @@ -107,6 +107,9 @@ cached arrays. Cached measure fields follow EEGLAB names such as ``erpdata``, ``specdata``, ``erspdata``, and ``itcdata``. The selected ``design`` is recorded in each measure group's metadata. EEGPrep stores dataset-level averages in the current standalone cache rather than EEGLAB sidecar measure files. +``pop_chanplot`` reads cached channel and component measures through the same +``std_readdata``/``std_erpplot``/``std_erspplot`` cache contract used by scripts, +so GUI and console plots slice axes and cached channel groups consistently. Use ``std_checkfiles``, ``std_checkdatasession``, ``std_uniformfiles``, and ``std_uniformsetinds`` to audit loaded dataset consistency and cached measure @@ -129,10 +132,12 @@ Select datasets or trials from STUDY metadata: ["target"], ) -These helpers return EEGLAB-facing 1-based dataset and trial indices. Use -``std_substudy`` or ``std_rmdat`` when a workflow needs to remove datasets; -EEGPrep remaps STUDY references and invalidates cached measure arrays after -membership changes. +These helpers return EEGLAB-facing 1-based dataset and trial indices. 
Trial +metadata may be stored as row dictionaries or as EEGLAB-loaded columnar +``{"factor": [values...]}`` dictionaries; STUDY selectors normalize both forms +before matching factor levels and numerical ranges. Use ``std_substudy`` or +``std_rmdat`` when a workflow needs to remove datasets; EEGPrep remaps STUDY +references and invalidates cached measure arrays after membership changes. ``std_findsameica`` groups matching ICA decompositions within each subject. This preserves the subject boundary used by STUDY designs instead of merging @@ -153,6 +158,12 @@ Precluster and cluster ICA components: STUDY, com = pop_clust(STUDY, ALLEEG, clus_num=4, random_state=0, return_com=True) STUDY, com, fig = pop_clustedit(STUDY, ALLEEG, action="plot", return_com=True) +Finite ``outliers`` values in ``pop_clust`` use the EEGLAB-compatible +``robust_kmeans`` path and record ``["robust_kmeans", clus_num]`` in each +created cluster's algorithm provenance. Infinite ``outliers`` keeps the plain +k-means path. In both cases command history remains pasteable in +``eegprep-console``. + DIPFIT Source Localization ========================== @@ -223,6 +234,14 @@ EEGPrep-owned ``pacdata``, ``pactimes``, and ``pacfreqs`` caches on plots their magnitude using the same cache-reading contract as the other STUDY measure plots. +Empirical p-value conventions are explicit. ``pac`` applies the common +``(exceedances + 1) / (permutations + 1)`` finite-sample convention, +``bootstat`` exposes exact permutation proportions for bootstrap helper tests, +and statistics helpers such as ``stat_surrogate_pvals`` follow EEGLAB's +surrogate-tail convention with FDR correction available through the statistics +module. These definitions are intentionally not collapsed when they answer +different inferential questions. 
+ The feasible in-package LIMO-compatible layer is design preparation: ``std_limodesign`` builds categorical and continuous matrices from ``pop_listfactors`` output and trial metadata, including interaction and split diff --git a/pyproject.toml b/pyproject.toml index 074ac85c..b707724e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ all = [ eegprep = "eegprep.cli.main:main" eegprep-gui = "eegprep.functions.adminfunc.eeglab:main" eegprep-console = "eegprep.functions.adminfunc.console:main" -eegprep-validate-extension-catalog = "eegprep.extension_catalog:main" +eegprep-validate-extension-catalog = "eegprep.extension_catalog_validation:main" [dependency-groups] dev = [ diff --git a/src/eegprep/__init__.py b/src/eegprep/__init__.py index dafd819d..52ad21d1 100644 --- a/src/eegprep/__init__.py +++ b/src/eegprep/__init__.py @@ -23,9 +23,9 @@ "EEGPrepSession": ("eegprep.functions.guifunc.session", "EEGPrepSession"), "EEGobj": ("eegprep.functions.eegobj.eegobj", "EEGobj"), "CATALOG_SCHEMA_VERSION": ("eegprep.extension_catalog", "CATALOG_SCHEMA_VERSION"), - "CatalogValidationIssue": ("eegprep.extension_catalog", "CatalogValidationIssue"), - "CatalogValidationOptions": ("eegprep.extension_catalog", "CatalogValidationOptions"), - "CatalogValidationReport": ("eegprep.extension_catalog", "CatalogValidationReport"), + "CatalogValidationIssue": ("eegprep.extension_catalog_validation", "CatalogValidationIssue"), + "CatalogValidationOptions": ("eegprep.extension_catalog_validation", "CatalogValidationOptions"), + "CatalogValidationReport": ("eegprep.extension_catalog_validation", "CatalogValidationReport"), "EXTENSION_COMPATIBILITY_POLICY": ("eegprep.extensions", "EXTENSION_COMPATIBILITY_POLICY"), "EXTENSION_API_VERSION": ("eegprep.extensions", "EXTENSION_API_VERSION"), "EXTENSION_CURATION_POLICY_URL": ("eegprep.extensions", "EXTENSION_CURATION_POLICY_URL"), @@ -141,7 +141,7 @@ "lat2point": ("eegprep.functions.redefine_functions", "lat2point"), "listdlg2": 
("eegprep.functions.guifunc.listdlg2", "listdlg2"), "minphaserceps": ("eegprep.plugins.firfilt.minphaserceps", "minphaserceps"), - "load_catalog_entries": ("eegprep.extension_catalog", "load_catalog_entries"), + "load_catalog_entries": ("eegprep.extension_catalog_validation", "load_catalog_entries"), "load_extension_catalog": ("eegprep.extension_catalog", "load_extension_catalog"), "loadset": ("eegprep.functions.popfunc.pop_loadset", "loadset"), "mne2eeg": ("eegprep.functions.redefine_functions", "mne2eeg"), @@ -392,8 +392,8 @@ "timewarp": ("eegprep.functions.timefreqfunc.timewarp", "timewarp"), "topoplot": ("eegprep.functions.sigprocfunc.topoplot", "topoplot"), "validate_extension_spec": ("eegprep.extensions", "validate_extension_spec"), - "validate_catalog_entries": ("eegprep.extension_catalog", "validate_catalog_entries"), - "validate_catalog_file": ("eegprep.extension_catalog", "validate_catalog_file"), + "validate_catalog_entries": ("eegprep.extension_catalog_validation", "validate_catalog_entries"), + "validate_catalog_file": ("eegprep.extension_catalog_validation", "validate_catalog_file"), "writelocs": ("eegprep.functions.sigprocfunc.writelocs", "writelocs"), } diff --git a/src/eegprep/cli/commands/bids.py b/src/eegprep/cli/commands/bids.py index 32a1444f..78a49fad 100644 --- a/src/eegprep/cli/commands/bids.py +++ b/src/eegprep/cli/commands/bids.py @@ -4,12 +4,11 @@ import argparse import importlib.util -import json import sys from pathlib import Path from typing import Any -from eegprep.cli.core import build_manifest, json_safe, now_iso, output_path, sha256_file, write_manifest_file +from eegprep.cli.core import build_manifest, now_iso, output_path, sha256_file, write_manifest_file from eegprep.cli.dataset import dataset_summary, load_dataset from eegprep.functions.popfunc.pop_saveset import pop_saveset from eegprep.plugins.EEG_BIDS.bids_list_eeg_files import bids_list_eeg_files @@ -243,14 +242,10 @@ def export_dataset( def main(argv: list[str] | None = 
None) -> int: - """Standalone module harness for tests and local debugging.""" - parser = argparse.ArgumentParser(prog="eegprep bids") - subparsers = parser.add_subparsers(dest="command", required=True) - register(subparsers) - args = parser.parse_args(["bids", *(sys.argv[1:] if argv is None else argv)]) - result = args.handler(args) - print(json.dumps(json_safe(result), sort_keys=True)) - return 0 if result.get("status") in {"ok", "warning"} else int(result.get("exit_code", 1)) + """Route module execution through the canonical top-level CLI dispatcher.""" + from eegprep.cli.main import main as cli_main + + return cli_main(["bids", *(sys.argv[1:] if argv is None else argv)]) def _file_type(path: Path) -> str: diff --git a/src/eegprep/cli/commands/migrate.py b/src/eegprep/cli/commands/migrate.py index 3cbd6e34..aa2877e5 100644 --- a/src/eegprep/cli/commands/migrate.py +++ b/src/eegprep/cli/commands/migrate.py @@ -3,7 +3,6 @@ from __future__ import annotations import argparse -import json import re import sys from pathlib import Path @@ -12,7 +11,7 @@ import numpy as np import yaml -from eegprep.cli.core import EEGPrepCLIError, json_safe, output_path +from eegprep.cli.core import EEGPrepCLIError, output_path from eegprep.cli.dataset import load_dataset @@ -218,14 +217,10 @@ def convert_script( def main(argv: list[str] | None = None) -> int: - """Standalone module harness for tests and local debugging.""" - parser = argparse.ArgumentParser(prog="eegprep migrate") - subparsers = parser.add_subparsers(dest="command", required=True) - register(subparsers) - args = parser.parse_args(["migrate", *(sys.argv[1:] if argv is None else argv)]) - result = args.handler(args) - print(json.dumps(json_safe(result), sort_keys=True)) - return 0 if result.get("status") in {"ok", "warning"} else 1 + """Route module execution through the canonical top-level CLI dispatcher.""" + from eegprep.cli.main import main as cli_main + + return cli_main(["migrate", *(sys.argv[1:] if argv is None 
else argv)]) def _history_lines(history_text: Any) -> list[str]: diff --git a/src/eegprep/cli/commands/pipeline.py b/src/eegprep/cli/commands/pipeline.py index 7b22548b..8d999ca2 100644 --- a/src/eegprep/cli/commands/pipeline.py +++ b/src/eegprep/cli/commands/pipeline.py @@ -12,13 +12,13 @@ import yaml +from eegprep.cli.commands import transforms as transform_commands from eegprep.cli.commands.qc import compute_qc_metrics from eegprep.cli.core import ( EEGPrepCLIError as CommandError, build_manifest, command_error as error_result, command_ok as success_result, - emit_command_result as emit_result, file_sha256, utc_now, write_json_file, @@ -26,13 +26,7 @@ ) from eegprep.cli.dataset import load_dataset as load_eeg_dataset from eegprep.cli.reporting import write_report_html -from eegprep.functions.popfunc.pop_epoch import pop_epoch -from eegprep.functions.popfunc.pop_reref import pop_reref -from eegprep.functions.popfunc.pop_resample import pop_resample -from eegprep.functions.popfunc.pop_runica import pop_runica from eegprep.functions.popfunc.pop_saveset import pop_saveset -from eegprep.plugins.clean_rawdata.pop_clean_rawdata import pop_clean_rawdata -from eegprep.plugins.firfilt.pop_eegfiltnew import pop_eegfiltnew PIPELINE_SCHEMA_VERSION = "eegprep.pipeline.v1" @@ -165,7 +159,7 @@ def run_pipeline_config( def register(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> argparse.ArgumentParser: """Register ``pipeline`` commands with an argparse dispatcher.""" parser = subparsers.add_parser("pipeline", help="Validate, plan, or run YAML EEGPrep pipelines.") - pipeline_subparsers = parser.add_subparsers(dest="pipeline_action", required=True, parser_class=type(parser)) + pipeline_subparsers = parser.add_subparsers(dest="pipeline_action", required=True) validate_parser = pipeline_subparsers.add_parser("validate", help="Validate a pipeline YAML file.") validate_parser.add_argument("config", help="Pipeline YAML config") @@ -209,13 +203,10 @@ def 
handle_registered(args: argparse.Namespace) -> dict[str, Any]: def main(argv: list[str] | None = None) -> int: - """Standalone module entry point for local command testing.""" - parser = argparse.ArgumentParser(prog="eegprep pipeline") - subparsers = parser.add_subparsers(dest="pipeline_action", required=True) - register(subparsers) - args = parser.parse_args(["pipeline", *(sys.argv[1:] if argv is None else argv)]) - result = handle_registered(args) - return emit_result(result, json_output=True) + """Route module execution through the canonical top-level CLI dispatcher.""" + from eegprep.cli.main import main as cli_main + + return cli_main(["pipeline", *(sys.argv[1:] if argv is None else argv)]) def _load_normalized_config(config_path: str | Path) -> tuple[dict[str, Any], list[dict[str, Any]]]: @@ -397,6 +388,8 @@ def _validate_step_parameters( for key in ("burst_criterion", "channel_criterion", "line_noise_criterion", "window_criterion"): if key in parameters and parameters[key] != "off": _positive_number(parameters[key], f"{path}.{key}", errors) + if "highpass" in parameters and parameters["highpass"] not in (None, "", "off"): + _clean_highpass(parameters["highpass"], f"{path}.highpass", errors) elif name == "epoch": if not (parameters.get("event_type") or parameters.get("event_types")): _add_error(errors, f"{path}.event_type", "epoch requires event_type or event_types.") @@ -476,23 +469,10 @@ def _run_step( ) -> tuple[dict[str, Any], str, list[dict[str, Any]], dict[str, Any] | None]: name = step["name"] parameters = step["parameters"] - if name == "filter": - EEG, history = _apply_filter(EEG, parameters) - return EEG, history, [], latest_qc - if name == "rereference": - EEG, history = _apply_rereference(EEG, parameters) - return EEG, history, [], latest_qc - if name == "resample": - EEG, history = _apply_resample(EEG, parameters) - return EEG, history, [], latest_qc - if name == "clean": - EEG, history = _apply_clean(EEG, parameters) - return EEG, history, [], 
latest_qc - if name == "epoch": - EEG, history = _apply_epoch(EEG, parameters) - return EEG, history, [], latest_qc - if name == "ica": - EEG, history = _apply_ica(EEG, parameters) + if name in MUTATING_STEP_NAMES: + result = transform_commands.apply_transform(EEG, name, parameters) + EEG = result.eeg + history = result.history return EEG, history, [], latest_qc if name == "qc": qc_payload = compute_qc_metrics(EEG, dataset_path=config["input"]["path"]) @@ -512,85 +492,6 @@ def _run_step( raise CommandError("COMMAND_NOT_IMPLEMENTED", f"Pipeline step is not implemented: {name}") -def _apply_filter(EEG: dict[str, Any], parameters: dict[str, Any]) -> tuple[dict[str, Any], str]: - histories = [] - highpass = parameters.get("highpass") - lowpass = parameters.get("lowpass") - if highpass is not None or lowpass is not None: - EEG, history = pop_eegfiltnew( - EEG, - locutoff=highpass, - hicutoff=lowpass, - plotfreqz=False, - gui=False, - return_com=True, - ) - histories.append(history) - notch = parameters.get("notch") - if notch is not None: - width = float(parameters.get("notch_width") or 2.0) - lower_edge = float(notch) - width / 2 - if lower_edge <= 0: - raise CommandError( - "CONFIG_SCHEMA_ERROR", - "notch minus half notch_width must be positive.", - path="steps[].notch", - suggestion="Increase notch or decrease notch_width so the notch stop band stays above 0 Hz.", - ) - EEG, history = pop_eegfiltnew( - EEG, - locutoff=lower_edge, - hicutoff=float(notch) + width / 2, - revfilt=True, - plotfreqz=False, - gui=False, - return_com=True, - ) - histories.append(history) - return EEG, "\n".join(item for item in histories if item) - - -def _apply_rereference(EEG: dict[str, Any], parameters: dict[str, Any]) -> tuple[dict[str, Any], str]: - method = str(parameters.get("method") or "average").lower() - ref = [] if method == "average" else _channel_indices(parameters.get("channels", parameters.get("ref"))) - return pop_reref(EEG, ref, gui=False, return_com=True) - - -def 
_apply_resample(EEG: dict[str, Any], parameters: dict[str, Any]) -> tuple[dict[str, Any], str]: - return pop_resample(EEG, float(parameters["freq"]), gui=False, return_com=True) - - -def _apply_clean(EEG: dict[str, Any], parameters: dict[str, Any]) -> tuple[dict[str, Any], str]: - clean_kwargs: dict[str, Any] = {} - key_map = { - "burst_criterion": "BurstCriterion", - "channel_criterion": "ChannelCriterion", - "line_noise_criterion": "LineNoiseCriterion", - "window_criterion": "WindowCriterion", - "flatline_criterion": "FlatlineCriterion", - "highpass": "Highpass", - } - for source, target in key_map.items(): - if source in parameters: - clean_kwargs[target] = parameters[source] - return pop_clean_rawdata(EEG, gui=False, return_com=True, **clean_kwargs) - - -def _apply_epoch(EEG: dict[str, Any], parameters: dict[str, Any]) -> tuple[dict[str, Any], str]: - event_types = parameters.get("event_types", parameters.get("event_type")) - limits = [float(parameters["tmin"]), float(parameters["tmax"])] - return pop_epoch(EEG, event_types, limits, gui=False, return_com=True) - - -def _apply_ica(EEG: dict[str, Any], parameters: dict[str, Any]) -> tuple[dict[str, Any], str]: - method = str(parameters.get("method") or "runica").lower() - options = {} - for key in ("seed", "maxsteps", "extended", "lrate"): - if key in parameters: - options[key] = parameters[key] - return pop_runica(EEG, icatype=method, options=options or None, gui=False, return_com=True) - - def _save_dataset(EEG: dict[str, Any], output_path: str | Path, *, history: str) -> list[dict[str, Any]]: if history: existing_history = str(EEG.get("history") or "") @@ -669,33 +570,26 @@ def _resolve_path(value: str | Path, base_dir: Path) -> Path: def _channel_indices(value: Any) -> list[Any]: - values = value if isinstance(value, list | tuple) else [value] - refs = [] - for item in values: - if isinstance(item, int | float) and float(item).is_integer(): - if int(item) < 1: - raise CommandError( - "CONFIG_SCHEMA_ERROR", - 
"Pipeline channel indices are EEGLAB-facing 1-based values.", - path="steps[].channels", - suggestion="Use channel index 1 for the first channel.", - ) - refs.append(int(item) - 1) - continue - text = str(item) - if text.isdecimal(): - number = int(text) - if number < 1: - raise CommandError( - "CONFIG_SCHEMA_ERROR", - "Pipeline channel indices are EEGLAB-facing 1-based values.", - path="steps[].channels", - suggestion="Use channel index 1 for the first channel.", - ) - refs.append(number - 1) - else: - refs.append(item) - return refs + return transform_commands.channel_tokens( + value if isinstance(value, list | tuple) else [value], + numeric_base=1, + output_base=0, + path="steps[].channels", + ) + + +def _clean_highpass(value: Any, path: str, errors: list[dict[str, Any]]) -> None: + if isinstance(value, str): + values: Any = [item for item in value.replace(",", " ").split() if item] + else: + values = value + if not isinstance(values, list | tuple) or len(values) != 2: + _add_error(errors, path, "clean highpass must contain [low high] transition band values.") + return + low = _positive_number(values[0], f"{path}[0]", errors) + high = _positive_number(values[1], f"{path}[1]", errors) + if low is not None and high is not None and low >= high: + _add_error(errors, path, "clean highpass values must be ordered as [low, high].") def _number(value: Any, path: str, errors: list[dict[str, Any]], *, required: bool = False) -> float | None: diff --git a/src/eegprep/cli/commands/qc.py b/src/eegprep/cli/commands/qc.py index 056041e4..8130bb36 100644 --- a/src/eegprep/cli/commands/qc.py +++ b/src/eegprep/cli/commands/qc.py @@ -6,7 +6,7 @@ import math import sys from pathlib import Path -from typing import Any +from typing import Any, NoReturn import numpy as np @@ -15,7 +15,6 @@ build_manifest, command_error as error_result, command_ok as success_result, - emit_command_result as emit_result, file_sha256, json_safe, utc_now, @@ -28,6 +27,24 @@ QC_SCHEMA_VERSION = 
"eegprep.qc.v1" +class _QCArgumentParser(argparse.ArgumentParser): + """Nested qc parser that can defer JSON-mode errors to the top-level CLI.""" + + def __init__(self, *args: Any, json_requested: bool = False, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.json_requested = json_requested + + def error(self, message: str) -> NoReturn: + if self.json_requested: + raise CommandError( + "CONFIG_SCHEMA_ERROR", + message, + suggestion="Run the command with --help or inspect schema command qc.", + exit_code=2, + ) + super().error(message) + + def compute_qc_metrics(EEG: dict[str, Any], *, dataset_path: str | Path | None = None) -> dict[str, Any]: """Compute deterministic QC metrics and recommendation codes for an EEG dict.""" data = np.asarray(EEG.get("data", np.array([]))) @@ -249,30 +266,32 @@ def handle_registered(args: argparse.Namespace) -> dict[str, Any]: def dispatch(argv: list[str] | tuple[str, ...] | None = None) -> dict[str, Any]: """Dispatch qc command arguments without assuming a global CLI foundation.""" argv = list(argv or []) + json_requested = "--json" in argv if argv and argv[0] == "report": - parser = _report_parser() + parser = _report_parser(json_requested=json_requested) args = parser.parse_args(argv[1:]) return qc_report_dataset(args.input, args.html, manifest_path=args.manifest, overwrite=args.overwrite) - parser = _qc_parser() + parser = _qc_parser(json_requested=json_requested) args = parser.parse_args(argv) return qc_dataset(args.input) def main(argv: list[str] | None = None) -> int: - """Standalone module entry point for local command testing.""" - result = dispatch(sys.argv[1:] if argv is None else argv) - return emit_result(result, json_output=True) + """Route module execution through the canonical top-level CLI dispatcher.""" + from eegprep.cli.main import main as cli_main + + return cli_main(["qc", *(sys.argv[1:] if argv is None else argv)]) -def _qc_parser() -> argparse.ArgumentParser: - parser = 
argparse.ArgumentParser(prog="eegprep qc") +def _qc_parser(*, json_requested: bool = False) -> argparse.ArgumentParser: + parser = _QCArgumentParser(prog="eegprep qc", json_requested=json_requested) parser.add_argument("input", help="Input EEGLAB .set dataset") parser.add_argument("--json", action="store_true", help="Emit structured JSON") return parser -def _report_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(prog="eegprep qc report") +def _report_parser(*, json_requested: bool = False) -> argparse.ArgumentParser: + parser = _QCArgumentParser(prog="eegprep qc report", json_requested=json_requested) parser.add_argument("input", help="Input EEGLAB .set dataset") parser.add_argument("--html", required=True, help="Output HTML report path") parser.add_argument("--manifest", help="Optional manifest JSON path") diff --git a/src/eegprep/cli/commands/report.py b/src/eegprep/cli/commands/report.py index 4b8acf10..be9fdbaf 100644 --- a/src/eegprep/cli/commands/report.py +++ b/src/eegprep/cli/commands/report.py @@ -13,7 +13,6 @@ build_manifest, command_error as error_result, command_ok as success_result, - emit_command_result as emit_result, file_sha256, utc_now, write_manifest_file, @@ -95,16 +94,10 @@ def handle_registered(args: argparse.Namespace) -> dict[str, Any]: def main(argv: list[str] | None = None) -> int: - """Standalone module entry point for local command testing.""" - parser = argparse.ArgumentParser(prog="eegprep report") - parser.add_argument("input", help="Input EEGLAB .set dataset") - parser.add_argument("--output", required=True, help="Output HTML report path") - parser.add_argument("--manifest", help="Optional manifest JSON path") - parser.add_argument("--overwrite", action="store_true", help="Overwrite existing output files") - parser.add_argument("--json", action="store_true", help="Emit structured JSON") - args = parser.parse_args(sys.argv[1:] if argv is None else argv) - result = report_dataset(args.input, args.output, 
manifest_path=args.manifest, overwrite=args.overwrite) - return emit_result(result, json_output=True) + """Route module execution through the canonical top-level CLI dispatcher.""" + from eegprep.cli.main import main as cli_main + + return cli_main(["report", *(sys.argv[1:] if argv is None else argv)]) if __name__ == "__main__": diff --git a/src/eegprep/cli/commands/transforms.py b/src/eegprep/cli/commands/transforms.py index fc3f1813..368a6ae4 100644 --- a/src/eegprep/cli/commands/transforms.py +++ b/src/eegprep/cli/commands/transforms.py @@ -8,6 +8,7 @@ from __future__ import annotations import argparse +from collections.abc import Mapping from contextlib import redirect_stdout from dataclasses import dataclass import logging @@ -20,9 +21,7 @@ from eegprep.cli.core import ( EEGPrepCLIError, build_manifest, - command_error, file_sha256, - print_result, utc_now, write_manifest_file, ) @@ -96,33 +95,10 @@ def run_transform_command(args: argparse.Namespace) -> dict[str, Any]: def main(argv: list[str] | None = None) -> int: - """Standalone module entry point for local transform testing.""" + """Route module execution through the canonical top-level CLI dispatcher.""" + from eegprep.cli.main import main as cli_main - parser = build_parser() - args = parser.parse_args(argv) - try: - result = run_transform_command(args) - except CliTransformError as exc: - payload = command_error(getattr(args, "transform_command", "transform"), exc) - if getattr(args, "json", False): - print_result(payload, as_json=True) - else: - print(f"{exc.code}: {exc.message}", file=sys.stderr) - return exc.exit_code - except Exception as exc: - error = CliTransformError("TRANSFORM_FAILED", str(exc)) - payload = command_error(getattr(args, "transform_command", "transform"), error) - if getattr(args, "json", False): - print_result(payload, as_json=True) - else: - print(f"{error.code}: {error.message}", file=sys.stderr) - return 1 - - if getattr(args, "json", False): - print_result(result, 
as_json=True) - else: - print(result["output"]["path"]) - return 0 + return cli_main(sys.argv[1:] if argv is None else argv) def _run_transform_command(args: argparse.Namespace) -> dict[str, Any]: @@ -137,7 +113,7 @@ def _run_transform_command(args: argparse.Namespace) -> dict[str, Any]: input_files = _dataset_file_records(input_path, eeg) logger.info("Running %s", args.transform_command) - result = _run_loaded_transform(eeg, args) + result = apply_loaded_transform(eeg, args) if result.history: _append_history(result.eeg, result.history) @@ -171,7 +147,13 @@ def _run_transform_command(args: argparse.Namespace) -> dict[str, Any]: } -def _run_loaded_transform(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResult: +def apply_transform(eeg: dict[str, Any], command: str, parameters: Mapping[str, Any] | None = None) -> TransformResult: + """Apply a transform to an already-loaded EEG dict using CLI-equivalent defaults.""" + return apply_loaded_transform(eeg, transform_args(command, parameters or {})) + + +def apply_loaded_transform(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResult: + """Apply a parsed transform command to an already-loaded EEG dict.""" command = args.transform_command if command == "resample": return _resample(eeg, args) @@ -188,6 +170,81 @@ def _run_loaded_transform(eeg: dict[str, Any], args: argparse.Namespace) -> Tran raise CliTransformError("COMMAND_NOT_IMPLEMENTED", f"Transform command is not implemented: {command}") +def transform_args(command: str, parameters: Mapping[str, Any]) -> argparse.Namespace: + """Return an argparse namespace matching the direct transform subcommand defaults.""" + normalized = _normalize_transform_command(command) + params = dict(parameters) + common: dict[str, Any] = { + "transform_command": normalized, + "input": "", + "output": None, + "manifest": None, + "overwrite": False, + "json": False, + "quiet": False, + "verbose": False, + "no_progress": False, + } + if normalized == "resample": + 
return argparse.Namespace(**common, freq=params.get("freq"), engine=params.get("engine") or "poly") + if normalized == "rereference": + return argparse.Namespace( + **common, + method=str(params.get("method") or "average").lower(), + channels=_list_or_none(params.get("channels", params.get("ref"))), + exclude=_list_or_none(params.get("exclude")), + keep_ref=bool(params.get("keep_ref", params.get("keepref", False))), + huber=params.get("huber"), + refica=params.get("refica") or "on", + ) + if normalized == "filter": + return argparse.Namespace( + **common, + highpass=params.get("highpass"), + lowpass=params.get("lowpass"), + notch=params.get("notch"), + notch_width=params.get("notch_width", 2.0), + order=params.get("order"), + minphase=bool(params.get("minphase", False)), + usefftfilt=bool(params.get("usefftfilt", False)), + ) + if normalized == "clean": + return argparse.Namespace( + **common, + method=str(params.get("method") or "asr").lower(), + burst_criterion=params.get("burst_criterion", 20.0), + burst_rejection=bool(params.get("burst_rejection", False)), + distance=str(params.get("distance") or "euclidean").lower(), + flatline_criterion=params.get("flatline_criterion"), + channel_criterion=params.get("channel_criterion"), + line_noise_criterion=params.get("line_noise_criterion"), + window_criterion=params.get("window_criterion"), + highpass=params.get("highpass"), + ) + if normalized == "epoch": + return argparse.Namespace( + **common, + event_type=_list_or_empty(params.get("event_type", params.get("event_types"))), + tmin=params.get("tmin"), + tmax=params.get("tmax"), + new_name=params.get("new_name", params.get("newname")), + ) + if normalized == "ica": + return argparse.Namespace( + **common, + method=str(params.get("method") or "runica").lower(), + seed=params.get("seed"), + deterministic=bool(params.get("deterministic", True)), + maxsteps=params.get("maxsteps"), + pca=params.get("pca"), + extended=params.get("extended"), + 
channels=_list_or_none(params.get("channels")), + option=_ica_option_list(params), + reorder=bool(params.get("reorder", True)), + ) + raise CliTransformError("COMMAND_NOT_IMPLEMENTED", f"Transform command is not implemented: {command}") + + def _resample(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResult: output, history = pop_resample(eeg, args.freq, engine=args.engine, gui=False, return_com=True) return TransformResult( @@ -207,11 +264,11 @@ def _rereference(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResu "rereference --method channels requires --channels.", suggestion="Pass channel labels or EEGLAB-facing 1-based channel indices.", ) - ref = _channel_tokens(args.channels, numeric_base=1, output_base=0) + ref = channel_tokens(args.channels, numeric_base=1, output_base=0) kwargs: dict[str, Any] = {"refica": args.refica} if args.exclude: - kwargs["exclude"] = _channel_tokens(args.exclude, numeric_base=1, output_base=0) + kwargs["exclude"] = channel_tokens(args.exclude, numeric_base=1, output_base=0) if args.keep_ref: kwargs["keepref"] = "on" if args.huber is not None: @@ -261,7 +318,7 @@ def _filter(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResult: lower_edge = float(args.notch) - notch_width / 2.0 upper_edge = float(args.notch) + notch_width / 2.0 if lower_edge <= 0: - raise CliTransformError("CONFIG_SCHEMA_ERROR", "--notch minus half --notch-width must be positive.") + raise CliTransformError("CONFIG_SCHEMA_ERROR", "notch minus half notch_width must be positive.") output, history = pop_eegfiltnew( output, locutoff=lower_edge, @@ -339,7 +396,7 @@ def _epoch(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResult: def _ica(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResult: options = _ica_options(args) - chanind = _channel_tokens(args.channels, numeric_base=1, output_base=1) if args.channels else None + chanind = channel_tokens(args.channels, numeric_base=1, output_base=1) if args.channels 
else None method = "runamica15" if args.method == "amica" else args.method warnings = [] if args.seed is not None and method != "runica": @@ -374,12 +431,13 @@ def _ica(eeg: dict[str, Any], args: argparse.Namespace) -> TransformResult: def _clean_options(args: argparse.Namespace) -> dict[str, Any]: distance = "Riemannian" if args.distance == "riemannian" else "Euclidean" + highpass = _clean_highpass(args.highpass) if args.method == "asr": return { "FlatlineCriterion": _optional_float_or_off(args.flatline_criterion, default="off"), "ChannelCriterion": _optional_float_or_off(args.channel_criterion, default="off"), "LineNoiseCriterion": _optional_float_or_off(args.line_noise_criterion, default="off"), - "Highpass": list(args.highpass) if args.highpass else "off", + "Highpass": highpass, "BurstCriterion": float(args.burst_criterion), "BurstRejection": bool(args.burst_rejection), "WindowCriterion": _optional_float_or_off(args.window_criterion, default="off"), @@ -393,8 +451,8 @@ def _clean_options(args: argparse.Namespace) -> dict[str, Any]: options["ChannelCriterion"] = float(args.channel_criterion) if args.line_noise_criterion is not None: options["LineNoiseCriterion"] = float(args.line_noise_criterion) - if args.highpass: - options["Highpass"] = list(args.highpass) + if highpass != "off": + options["Highpass"] = highpass if args.burst_criterion is not None: options["BurstCriterion"] = float(args.burst_criterion) if args.window_criterion is not None: @@ -402,6 +460,37 @@ def _clean_options(args: argparse.Namespace) -> dict[str, Any]: return options +def _clean_highpass(value: Any) -> list[float] | str: + if value in (None, "", "off"): + return "off" + if isinstance(value, str): + values: Any = [item for item in value.replace(",", " ").split() if item] + else: + values = value + if not isinstance(values, list | tuple) or len(values) != 2: + raise CliTransformError( + "CONFIG_SCHEMA_ERROR", + "clean highpass must contain [low high] transition band values.", + 
path="highpass", + suggestion="Use --highpass LOW HIGH or highpass: [LOW, HIGH] in pipeline config.", + ) + try: + low, high = (float(values[0]), float(values[1])) + except (TypeError, ValueError) as exc: + raise CliTransformError( + "CONFIG_SCHEMA_ERROR", + "clean highpass values must be numeric.", + path="highpass", + ) from exc + if low <= 0 or high <= 0 or low >= high: + raise CliTransformError( + "CONFIG_SCHEMA_ERROR", + "clean highpass must be positive and ordered as [low, high].", + path="highpass", + ) + return [low, high] + + def _ica_options(args: argparse.Namespace) -> dict[str, Any]: options: dict[str, Any] = {} if args.seed is not None: @@ -426,6 +515,41 @@ def _ica_options(args: argparse.Namespace) -> dict[str, Any]: return options +def _normalize_transform_command(command: str) -> str: + normalized = str(command).strip().lower() + aliases = {"reref": "rereference", "re_reference": "rereference"} + return aliases.get(normalized, normalized) + + +def _list_or_none(value: Any) -> list[Any] | None: + if value is None: + return None + if isinstance(value, list): + return value + if isinstance(value, tuple): + return list(value) + return [value] + + +def _list_or_empty(value: Any) -> list[Any]: + return _list_or_none(value) or [] + + +def _ica_option_list(parameters: Mapping[str, Any]) -> list[str]: + options: list[str] = [] + raw_options = parameters.get("options", parameters.get("option", [])) + if isinstance(raw_options, Mapping): + options.extend(f"{key}={value}" for key, value in raw_options.items()) + elif isinstance(raw_options, str): + options.append(raw_options) + else: + options.extend(str(item) for item in raw_options or []) + for key in ("lrate",): + if key in parameters: + options.append(f"{key}={parameters[key]}") + return options + + def _ica_is_deterministic(method: str, options: dict[str, Any], args: argparse.Namespace) -> bool: if method != "runica": return args.seed is not None @@ -650,11 +774,25 @@ def _append_history(eeg: dict[str, 
Any], command: str) -> None: eeg["history"] = f"{existing}\n{command}\n" if existing else f"{command}\n" -def _channel_tokens(values: list[str], *, numeric_base: int, output_base: int) -> list[Any]: +def channel_tokens( + values: list[Any] | tuple[Any, ...], + *, + numeric_base: int, + output_base: int, + path: str | None = None, +) -> list[Any]: + """Return parsed channel labels or converted EEGLAB-facing indices.""" tokens: list[Any] = [] for value in values: scalar = _parse_scalar(value) if isinstance(scalar, int): + if scalar < numeric_base: + raise CliTransformError( + "CONFIG_SCHEMA_ERROR", + "Channel indices are EEGLAB-facing 1-based values.", + path=path, + suggestion="Use channel index 1 for the first channel.", + ) tokens.append(scalar - numeric_base + output_base) else: tokens.append(scalar) diff --git a/src/eegprep/cli/main.py b/src/eegprep/cli/main.py index f8cf8e12..71e7847a 100644 --- a/src/eegprep/cli/main.py +++ b/src/eegprep/cli/main.py @@ -39,7 +39,13 @@ class EEGPrepArgumentParser(argparse.ArgumentParser): """ArgumentParser that can emit structured errors for JSON-mode agent calls.""" - json_requested = False + def __init__(self, *args: Any, json_requested: bool = False, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.json_requested = json_requested + + def add_subparsers(self, *args: Any, **kwargs: Any) -> argparse._SubParsersAction: + kwargs.setdefault("parser_class", self._child_parser_factory()) + return super().add_subparsers(*args, **kwargs) def error(self, message: str) -> NoReturn: if self.json_requested: @@ -51,21 +57,29 @@ def error(self, message: str) -> NoReturn: ) super().error(message) + def _child_parser_factory(self) -> Any: + json_requested = self.json_requested + + def factory(*args: Any, **kwargs: Any) -> "EEGPrepArgumentParser": + return type(self)(*args, json_requested=json_requested, **kwargs) + + return factory + def main(argv: list[str] | None = None) -> int: """Run the EEGPrep CLI.""" - parser = 
build_parser() + raw_argv = sys.argv[1:] if argv is None else list(argv) + requested_json = _argv_requests_json(raw_argv) + parser = build_parser(json_requested=requested_json) + args: argparse.Namespace | None = None try: - requested_json = "--json" in (sys.argv[1:] if argv is None else argv) - EEGPrepArgumentParser.json_requested = requested_json - parser.json_requested = requested_json - args = parser.parse_args(argv) + args = parser.parse_args(raw_argv) if not hasattr(args, "handler"): parser.print_help() return 0 result = args.handler(args) except EEGPrepCLIError as exc: - print_result(exc.to_response(), as_json=bool(getattr(parser, "json_requested", False))) + print_result(exc.to_response(), as_json=requested_json) return exc.exit_code except Exception as exc: code = getattr(exc, "code", None) @@ -81,7 +95,7 @@ def main(argv: list[str] | None = None) -> int: payload["path"] = str(getattr(exc, "path")) if getattr(exc, "suggestion", None) is not None: payload["suggestion"] = getattr(exc, "suggestion") - print_result(payload, as_json=_json_requested(args)) + print_result(payload, as_json=_json_requested(args, requested_json=requested_json)) return int(getattr(exc, "exit_code", 1) or 1) payload = { "status": "error", @@ -90,15 +104,15 @@ def main(argv: list[str] | None = None) -> int: "message": str(exc), "suggestion": "Rerun with --verbose or file an issue if this is reproducible.", } - print_result(payload, as_json=_json_requested(args)) - if bool(getattr(args, "verbose", False)): + print_result(payload, as_json=_json_requested(args, requested_json=requested_json)) + if args is not None and bool(getattr(args, "verbose", False)): raise return 1 - if isinstance(result, dict) and "_raw_text" in result and not _json_requested(args): + if isinstance(result, dict) and "_raw_text" in result and not _json_requested(args, requested_json=requested_json): print(result["_raw_text"], end="" if result["_raw_text"].endswith("\n") else "\n") else: - print_result(result, 
as_json=_json_requested(args)) + print_result(result, as_json=_json_requested(args, requested_json=requested_json)) if isinstance(result, dict) and result.get("status") == "error": return int(result.get("exit_code", 1) or 1) if isinstance(result, dict): @@ -106,16 +120,17 @@ def main(argv: list[str] | None = None) -> int: return 0 -def build_parser() -> EEGPrepArgumentParser: +def build_parser(*, json_requested: bool = False) -> EEGPrepArgumentParser: """Build the root parser.""" parser = EEGPrepArgumentParser( prog="eegprep", description="EEGLAB-compatible, Python-native EEG preprocessing CLI.", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=AGENT_EPILOG, + json_requested=json_requested, ) parser.add_argument("--version", action="store_true", help="Show EEGPrep version and exit.") - subparsers = parser.add_subparsers(dest="command", parser_class=EEGPrepArgumentParser) + subparsers = parser.add_subparsers(dest="command") _register_discovery(subparsers) _register_inspect(subparsers) @@ -138,7 +153,7 @@ def _register_discovery(subparsers: argparse._SubParsersAction) -> None: capabilities.set_defaults(handler=lambda _args: discovery.capabilities()) schema = subparsers.add_parser("schema", help="Print command or pipeline schemas.") - schema_sub = schema.add_subparsers(dest="schema_kind", required=True, parser_class=EEGPrepArgumentParser) + schema_sub = schema.add_subparsers(dest="schema_kind", required=True) pipeline = schema_sub.add_parser("pipeline", help="Print pipeline config schema.") pipeline.add_argument("--json", action="store_true") pipeline.set_defaults(handler=lambda _args: discovery.pipeline_schema()) @@ -153,7 +168,7 @@ def _register_discovery(subparsers: argparse._SubParsersAction) -> None: examples.set_defaults(handler=lambda args: discovery.examples(args.name)) skills = subparsers.add_parser("skills", help="List and retrieve bundled agent skill content.") - skills_sub = skills.add_subparsers(dest="skills_command", required=True, 
parser_class=EEGPrepArgumentParser) + skills_sub = skills.add_subparsers(dest="skills_command", required=True) skills_list = skills_sub.add_parser("list", help="List bundled CLI skills.") skills_list.add_argument("--json", action="store_true") skills_list.set_defaults(handler=lambda _args: discovery.skills_list()) @@ -225,14 +240,14 @@ def _handle_skill_path(args: argparse.Namespace) -> dict[str, Any]: return {"status": "ok", "schema_version": "eegprep.skills.path.v1", "name": args.name, "path": path} -def _json_requested(args: argparse.Namespace) -> bool: - if bool(getattr(args, "json", False)): +def _argv_requests_json(argv: list[str]) -> bool: + return "--json" in argv + + +def _json_requested(args: argparse.Namespace | None, *, requested_json: bool) -> bool: + if requested_json: return True - # Commands that consume their flags via argparse.REMAINDER (e.g. ``qc``) never bind a - # top-level ``--json`` on ``args``. The root parser already records whether ``--json`` - # appeared anywhere in argv, so consult that command-agnostic flag instead of - # introspecting any single subcommand's argument attribute. - return bool(EEGPrepArgumentParser.json_requested) + return bool(getattr(args, "json", False)) if __name__ == "__main__": diff --git a/src/eegprep/extension_catalog.py b/src/eegprep/extension_catalog.py index 770aa816..6f06ca35 100644 --- a/src/eegprep/extension_catalog.py +++ b/src/eegprep/extension_catalog.py @@ -3,9 +3,7 @@ This module loads the metadata-only catalog that the Extension Manager dialog and console inventory display, and builds copyable (never executed) install/update commands. The submission-curation CI validator lives in -``eegprep.extension_catalog_validation``; its public names are re-exported here so -the ``eegprep-validate-extension-catalog`` entry point and existing imports keep -working unchanged. +``eegprep.extension_catalog_validation``. 
""" from __future__ import annotations @@ -326,44 +324,18 @@ def _catalog_normalize_name(name: str) -> str: return str(name).strip().lower() -# The submission-curation CI validator lives in extension_catalog_validation. Its -# public names are re-exported here so the eegprep-validate-extension-catalog entry -# point and existing imports of these symbols from eegprep.extension_catalog keep -# working. The validator imports the shared catalog constants and _is_web_url from -# this module lazily, so this re-import stays one-directional at module-load time. -from eegprep.extension_catalog_validation import ( # noqa: E402 - CATALOG_CURATION_STATUSES, - CATALOG_REQUIRED_FIELDS, - CatalogValidationIssue, - CatalogValidationOptions, - CatalogValidationReport, - load_catalog_entries, - main, - validate_catalog_entries, - validate_catalog_file, -) - __all__ = [ - "CATALOG_CURATION_STATUSES", "CATALOG_KIND_CURATION", "CATALOG_KIND_MANAGER", - "CATALOG_REQUIRED_FIELDS", "CATALOG_ENV_VAR", "CATALOG_RESOURCE", "CATALOG_SCHEMA_VERSION", "INSTALL_TRUST_WARNING", "CatalogSourceType", - "CatalogValidationIssue", - "CatalogValidationOptions", - "CatalogValidationReport", "ExtensionCatalog", "ExtensionCatalogEntry", "build_safe_install_commands", "build_safe_update_commands", - "load_catalog_entries", "load_extension_catalog", "parse_extension_catalog", - "main", - "validate_catalog_entries", - "validate_catalog_file", ] diff --git a/src/eegprep/extension_catalog_validation.py b/src/eegprep/extension_catalog_validation.py index 331343bc..92250988 100644 --- a/src/eegprep/extension_catalog_validation.py +++ b/src/eegprep/extension_catalog_validation.py @@ -26,12 +26,12 @@ ExtensionRegistry, ExtensionSpec, ExtensionStatus, - _entry_point_package_name, - _major_version, - _select_entry_points, check_extension_compatibility, + extension_api_major_version, + extension_entry_point_package_name, extension_version_satisfies, extension_version_spec_is_valid, + select_extension_entry_points, ) 
CATALOG_CURATION_STATUSES = ("submitted", "curated", "private", "internal") @@ -362,7 +362,7 @@ def _validate_version_policy( errors: list[CatalogValidationIssue], ) -> None: api_version = str(entry.get("api_version") or "") - if _major_version(api_version) != _major_version(EXTENSION_API_VERSION): + if extension_api_major_version(api_version) != extension_api_major_version(EXTENSION_API_VERSION): errors.append( CatalogValidationIssue( f"Extension API version {api_version!r} is not supported by this EEGPrep extension API", @@ -497,7 +497,7 @@ def _validate_installed_entries( options: CatalogValidationOptions, errors: list[CatalogValidationIssue], ) -> None: - selected_entry_points = _select_entry_points(options.entry_points_provider, EXTENSION_ENTRY_POINT_GROUP) + selected_entry_points = select_extension_entry_points(options.entry_points_provider, EXTENSION_ENTRY_POINT_GROUP) for entry in entries: entry_id = _entry_id(entry) package_name = str(entry.get("package_name") or "") @@ -622,7 +622,7 @@ def _matching_entry_point(entry_points: tuple[Any, ...], package_name: str, entr for entry_point in entry_points: if getattr(entry_point, "name", None) != entry_point_name: continue - entry_point_package = _entry_point_package_name(entry_point) + entry_point_package = extension_entry_point_package_name(entry_point) if entry_point_package and _normalize(entry_point_package) != normalized_package: continue return entry_point diff --git a/src/eegprep/extensions.py b/src/eegprep/extensions.py index 6ad3f681..cc91311a 100644 --- a/src/eegprep/extensions.py +++ b/src/eegprep/extensions.py @@ -56,6 +56,29 @@ class ExtensionStatus(str, Enum): UNKNOWN = "unknown" +EXTENSION_ACTIVE_STATUSES = frozenset( + { + ExtensionStatus.BUNDLED.value, + ExtensionStatus.INSTALLED.value, + ExtensionStatus.CURATED.value, + "ok", + } +) +EXTENSION_INSTALLED_STATUSES = frozenset( + { + ExtensionStatus.BUNDLED.value, + ExtensionStatus.INSTALLED.value, + ExtensionStatus.DISABLED.value, + 
ExtensionStatus.INCOMPATIBLE.value, + ExtensionStatus.FAILED_IMPORT.value, + ExtensionStatus.INVALID_SPEC.value, + ExtensionStatus.MISSING_DEPENDENCY.value, + ExtensionStatus.UNKNOWN.value, + "ok", + } +) + + class ExtensionSourceType(str, Enum): """Source category declared by an extension.""" @@ -244,16 +267,7 @@ class ExtensionRecord: @property def is_active(self) -> bool: """Return whether this record can contribute runtime behavior.""" - return ( - self.enabled - and self.spec is not None - and self.status - in { - ExtensionStatus.BUNDLED, - ExtensionStatus.INSTALLED, - ExtensionStatus.CURATED, - } - ) + return self.enabled and self.spec is not None and extension_status_is_active(self.status) EntryPointsProvider = Callable[..., Any] @@ -511,13 +525,13 @@ def _bundled_records(self) -> list[ExtensionRecord]: def _entry_point_records(self) -> list[ExtensionRecord]: records = [] - for entry_point in _select_entry_points(self._entry_points_provider, self.entry_point_group): + for entry_point in select_extension_entry_points(self._entry_points_provider, self.entry_point_group): records.append(self._record_from_entry_point(entry_point)) return records def _record_from_entry_point(self, entry_point: Any) -> ExtensionRecord: entry_point_name = str(getattr(entry_point, "name", "") or "") - package_name = _entry_point_package_name(entry_point) + package_name = extension_entry_point_package_name(entry_point) try: loaded = entry_point.load() except Exception as exc: @@ -889,7 +903,8 @@ def _dependency_errors(dependencies: tuple[Any, ...], version_provider: VersionP return invalid, missing -def _select_entry_points(provider: EntryPointsProvider, group: str) -> tuple[Any, ...]: +def select_extension_entry_points(provider: EntryPointsProvider, group: str) -> tuple[Any, ...]: + """Return entry points from providers with modern or legacy selection APIs.""" try: selected = provider(group=group) except TypeError: @@ -901,6 +916,18 @@ def _select_entry_points(provider: 
EntryPointsProvider, group: str) -> tuple[Any return tuple(selected or ()) +def extension_status_is_active(status: ExtensionStatus | str) -> bool: + """Return whether an extension status can contribute runtime behavior.""" + value = status.value if isinstance(status, ExtensionStatus) else str(status) + return value in EXTENSION_ACTIVE_STATUSES + + +def extension_status_is_installed(status: ExtensionStatus | str) -> bool: + """Return whether an extension status represents an installed package/port.""" + value = status.value if isinstance(status, ExtensionStatus) else str(status) + return value in EXTENSION_INSTALLED_STATUSES + + def _status_for_spec( source_type: ExtensionSourceType, enabled: bool, @@ -960,7 +987,8 @@ def _coerce_source_type(source_type: ExtensionSourceType | str) -> ExtensionSour return ExtensionSourceType.UNKNOWN -def _entry_point_package_name(entry_point: Any) -> str: +def extension_entry_point_package_name(entry_point: Any) -> str: + """Return the distribution name associated with an extension entry point.""" dist = getattr(entry_point, "dist", None) dist_metadata = getattr(dist, "metadata", None) if dist_metadata is None: @@ -978,10 +1006,11 @@ def _current_eegprep_version() -> str: def _api_version_supported(api_version: str) -> bool: - return _major_version(api_version) == _major_version(EXTENSION_API_VERSION) + return extension_api_major_version(api_version) == extension_api_major_version(EXTENSION_API_VERSION) -def _major_version(version: str) -> int: +def extension_api_major_version(version: str) -> int: + """Return the integer major component for an extension API/version string.""" text = str(version).strip() if not text: return -1 @@ -1072,9 +1101,11 @@ def _as_tuple(value: Any) -> tuple[Any, ...]: __all__ = [ "EXTENSION_API_VERSION", + "EXTENSION_ACTIVE_STATUSES", "EXTENSION_COMPATIBILITY_POLICY", "EXTENSION_CURATION_POLICY_URL", "EXTENSION_ENTRY_POINT_GROUP", + "EXTENSION_INSTALLED_STATUSES", "EXTENSION_NAMING_PREFIX", 
"EXTENSION_TRUST_MESSAGE", "ExtensionAction", @@ -1093,7 +1124,12 @@ def _as_tuple(value: Any) -> tuple[Any, ...]: "check_extension_compatibility", "compare_extension_versions", "discover_extensions", + "extension_api_major_version", + "extension_entry_point_package_name", + "extension_status_is_active", + "extension_status_is_installed", "extension_version_satisfies", "extension_version_spec_is_valid", + "select_extension_entry_points", "validate_extension_spec", ] diff --git a/src/eegprep/functions/adminfunc/console.py b/src/eegprep/functions/adminfunc/console.py index 4efb9101..2f4be816 100644 --- a/src/eegprep/functions/adminfunc/console.py +++ b/src/eegprep/functions/adminfunc/console.py @@ -265,8 +265,20 @@ def __call__(self, command: Any = None, *args: Any) -> str: text = "" if not args else str(args[0]) return eegh_find(self.bridge.session.ALLCOM, text) if isinstance(command, (int, float)) and not isinstance(command, bool): - history_command = eegh(command, self.bridge.session.ALLCOM) - if history_command and int(command) > 0: + value = int(command) + if float(command) != value: + raise ValueError("eegh numeric command must be an integer") + session = self.bridge.session + if value == 0: + session.clear_history() + self.bridge.pull_from_session() + return "" + if value < 0: + session.remove_history(abs(value)) + self.bridge.pull_from_session() + return "" + history_command = session.history_command_at(value) + if history_command: self.bridge.execute_history_command(history_command) return history_command session = self.bridge.session @@ -274,7 +286,7 @@ def __call__(self, command: Any = None, *args: Any) -> str: if normalized: session.add_history(command) else: - session.LASTCOM = "" + session.clear_last_command() if args and isinstance(args[0], dict): eegh(normalized, args[0]) self.bridge.pull_from_session() diff --git a/src/eegprep/functions/adminfunc/plugin_menu.py b/src/eegprep/functions/adminfunc/plugin_menu.py index a6ca9c5b..d9712fff 100644 --- 
a/src/eegprep/functions/adminfunc/plugin_menu.py +++ b/src/eegprep/functions/adminfunc/plugin_menu.py @@ -14,6 +14,8 @@ ExtensionSourceType, ExtensionStatus, compare_extension_versions, + extension_status_is_active, + extension_status_is_installed, ) from eegprep.extension_catalog import ( INSTALL_TRUST_WARNING, @@ -29,23 +31,6 @@ "This manager never downloads, unzips, installs, updates, or removes extension code." ) -_ACTIVE_STATUSES = { - ExtensionStatus.BUNDLED.value, - ExtensionStatus.INSTALLED.value, - ExtensionStatus.CURATED.value, - "ok", -} -_INSTALLED_STATUSES = { - ExtensionStatus.BUNDLED.value, - ExtensionStatus.INSTALLED.value, - ExtensionStatus.DISABLED.value, - ExtensionStatus.INCOMPATIBLE.value, - ExtensionStatus.FAILED_IMPORT.value, - ExtensionStatus.INVALID_SPEC.value, - ExtensionStatus.MISSING_DEPENDENCY.value, - ExtensionStatus.UNKNOWN.value, - "ok", -} _STATUS_LABELS = { ExtensionStatus.BUNDLED.value: "Bundled", ExtensionStatus.INSTALLED.value: "Installed", @@ -70,11 +55,7 @@ ExtensionStatus.MISSING_DEPENDENCY.value: "#ffe9cc", ExtensionStatus.UNKNOWN.value: "#eeeeee", } - -# EEGLAB-style top-level menu labels for the bundled plugin ports. Names, -# versions, descriptions, capabilities, and pop functions are derived from the -# extension registry so they cannot drift from the live discovery path. 
-_BUNDLED_PLUGIN_MENUS: dict[str, str] = { +_BUNDLED_PLUGIN_MENU_SUMMARIES = { "clean_rawdata": "Tools > Reject data using Clean Rawdata and ASR", "ICLabel": "Tools > Classify components using ICLabel", "firfilt": "Tools > Filter the data", @@ -100,7 +81,7 @@ def bundled_plugins() -> tuple[dict[str, Any], ...]: "status": "ok", "installed": True, "source": record.source_type.value, - "menu": _BUNDLED_PLUGIN_MENUS[record.name], + "menu": _menu_text(record), "description": spec.description if spec is not None else "", "tags": tuple(spec.capabilities if spec is not None else ()), } @@ -374,8 +355,8 @@ def _plugin_from_record( update_commands = build_safe_update_commands(catalog_entry) if catalog_entry is not None else {} conflicts = _catalog_conflicts(record, catalog_entry) status = record.status.value - installed = status in _INSTALLED_STATUSES - active = installed and status in _ACTIVE_STATUSES + installed = extension_status_is_installed(status) + active = record.is_active menu = _menu_text(record) plugin = { "plugin": record.name, @@ -472,7 +453,7 @@ def _normalize_plugin(plugin: dict[str, Any], *, catalog_info: dict[str, Any] | normalized.setdefault("funcname", "") normalized.setdefault("status", "ok" if normalized.get("installed", True) else "unavailable") normalized.setdefault("state_label", _STATUS_LABELS.get(str(normalized["status"]), str(normalized["status"]))) - normalized.setdefault("installed", str(normalized.get("status")) in _INSTALLED_STATUSES) + normalized.setdefault("installed", extension_status_is_installed(str(normalized.get("status")))) normalized["active"] = _is_active(normalized) normalized.setdefault("enabled", normalized["active"]) normalized.setdefault("curated", False) @@ -861,6 +842,8 @@ def _install_guidance( def _menu_text(record: ExtensionRecord) -> str: + if record.source_type == ExtensionSourceType.BUNDLED and record.name in _BUNDLED_PLUGIN_MENU_SUMMARIES: + return _BUNDLED_PLUGIN_MENU_SUMMARIES[record.name] if record.spec is None or not 
record.spec.menus: return "" menu = record.spec.menus[0] @@ -891,7 +874,7 @@ def _source_detail(record: ExtensionRecord, catalog_entry: ExtensionCatalogEntry def _is_active(plugin: dict[str, Any]) -> bool: status = str(plugin.get("status", "")) - return bool(plugin.get("installed", status in _INSTALLED_STATUSES)) and status in _ACTIVE_STATUSES + return bool(plugin.get("installed", extension_status_is_installed(status))) and extension_status_is_active(status) def _catalog_version_is_newer(installed_version: str, catalog_version: str) -> bool: diff --git a/src/eegprep/functions/guifunc/qt.py b/src/eegprep/functions/guifunc/qt.py index cecf1d49..7dbdf9cc 100644 --- a/src/eegprep/functions/guifunc/qt.py +++ b/src/eegprep/functions/guifunc/qt.py @@ -48,6 +48,55 @@ def _require_qt() -> tuple[Any, Any]: class QtDialogRenderer: """Render :class:`DialogSpec` using PySide6 widgets.""" + # Implementations live at module scope; these aliases preserve existing + # QtDialogRenderer._helper callers. + _QDialog: Any + _apply_eeglab_style: Any + _row_weights: Any + _row_stretch: Any + _spacer_row_height: Any + _add_buttons: Any + _apply_spec_size: Any + _apply_font_hints: Any + _apply_widget_size_policy: Any + _accept_if_valid: Any + _validation_message: Any + _validate_pop_reref_dialog: Any + _validate_pop_interp_dialog: Any + _callback_channels: Any + _validate_channel_text: Any + _parse_channel_text: Any + _parse_numeric_text: Any + _widget_number: Any + _widget_vector: Any + _combo_choice: Any + _is_int_text: Any + _widget_checked: Any + _widget_text: Any + _plot_tf_cycle_calc: Any + _estimate_fir_kaiser_beta: Any + _estimate_firws_order: Any + _estimate_firpm_order: Any + _plot_fir_response: Any + _sync_numeric: Any + _select_event_types: Any + _select_channels: Any + _select_file: Any + _open_eegplot: Any + _open_rejection_browser: Any + _set_headplot_setup_mode: Any + _set_headplot_mesh_choice: Any + _run_headplot_manual_coreg: Any + _choice_or_text: Any + 
_show_callback_message: Any + _edit_text: Any + _select_interp_channels: Any + _store_interp_selection: Any + _set_reref_mode: Any + _set_enabled: Any + _show_help: Any + _read_widget: Any + def run( self, spec: DialogSpec, @@ -59,15 +108,6 @@ def run( return None return {tag: self._read_widget(widget) for tag, widget in widgets.items()} - @staticmethod - def _QDialog() -> Any: - if QDialog is None: - raise RuntimeError( - "PySide6 is required for EEGPrep GUI dialogs. Install it with " - "`pip install -e .[gui]` or `pip install eegprep[gui]`." - ) - return QDialog - def build_dialog( self, spec: DialogSpec, @@ -143,195 +183,6 @@ def build_dialog( self._apply_spec_size(dialog, spec) return app, dialog, widgets - @staticmethod - def _apply_eeglab_style(dialog: Any, spec: DialogSpec) -> None: - base_stylesheet = """ - QDialog { - background: #a8c2ff; - color: #000066; - font-size: 16px; - } - QLabel, QCheckBox, QPushButton, QLineEdit, QTextEdit, QComboBox, QListWidget { - font-size: 16px; - } - QLabel, QCheckBox { - color: #000066; - background: transparent; - } - QLabel:disabled, QCheckBox:disabled { - color: #7c86a8; - } - QLineEdit { - background: white; - border: 1px solid #7f7f7f; - min-height: 18px; - max-height: 18px; - margin-left: 1px; - padding: 0 3px; - color: #000066; - } - QLineEdit:disabled { - background: #dce6ff; - color: #7c86a8; - } - QTextEdit { - background: white; - border: 1px solid #7f7f7f; - color: #000066; - } - QComboBox { - background: white; - border: 1px solid #7f7f7f; - min-height: 20px; - max-height: 20px; - color: #000066; - } - QComboBox:disabled { - background: #dce6ff; - color: #7c86a8; - } - QListWidget { - background: white; - border: 1px solid #7f7f7f; - min-height: 74px; - max-height: 74px; - color: #000066; - } - QPushButton { - background: #eeeeee; - border: 1px solid #7f7f7f; - min-height: 18px; - max-height: 18px; - padding: 0 10px; - color: #000066; - } - QPushButton:disabled { - color: #b0b0b0; - } - 
QPushButton#double_dip_help { - min-width: 150px; - max-width: 150px; - } - QPushButton#events_button { - min-width: 130px; - max-width: 130px; - } - QPushButton#scroll { - min-width: 159px; - max-width: 159px; - } - QPushButton#refbr, QPushButton#exclude_button, QPushButton#refloc_button { - min-width: 33px; - max-width: 33px; - padding: 0; - } - QPushButton#interp_nondatchan, - QPushButton#interp_removedchans, - QPushButton#interp_datchan, - QPushButton#interp_selectchan, - QPushButton#interp_uselist { - min-width: 434px; - max-width: 434px; - padding: 0; - } - QDialog#pop_interp QPushButton { - padding: 0; - } - QDialog#pop_interp QLineEdit, - QDialog#pop_interp QComboBox { - min-width: 198px; - max-width: 198px; - } - QDialog#pop_reref QLineEdit { - min-width: 163px; - max-width: 163px; - } - QDialog#pop_runica QListWidget { - min-height: 102px; - max-height: 102px; - } - QCheckBox { - spacing: 4px; - } - QCheckBox::indicator { - width: 13px; - height: 13px; - } - QCheckBox::indicator:unchecked { - background: white; - border: 1px solid #7f7f7f; - } - """ - dialog.setStyleSheet(base_stylesheet + (spec.extra_stylesheet or "")) - - @staticmethod - def _row_weights(row_geometry: Any) -> list[float]: - if isinstance(row_geometry, (list, tuple)): - return [max(0.01, float(value)) for value in row_geometry] - return [1.0] * max(1, int(row_geometry)) - - @staticmethod - def _row_stretch(spec: DialogSpec, row_index: int) -> int: - if spec.geomvert is None: - return 0 - value = spec.geomvert[min(row_index, len(spec.geomvert) - 1)] - return max(1, round(float(value) * 100)) - - @staticmethod - def _spacer_row_height(spec: DialogSpec, row_index: int) -> int: - if spec.geomvert is not None: - value = spec.geomvert[min(row_index, len(spec.geomvert) - 1)] - return max(8, round(float(value) * 55)) - return max(8, spec.row_spacing * 7) - - @staticmethod - def _add_buttons( - QtWidgets: Any, - layout: Any, - spec: DialogSpec, - dialog: Any, - widgets: dict[str, Any], - ) -> 
None: - if spec.geomvert is None and spec.size is not None: - layout.addStretch(1) - button_container = QtWidgets.QWidget() - button_container.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) - button_layout = QtWidgets.QHBoxLayout(button_container) - button_layout.setContentsMargins(0, 18, 0, 0) - button_layout.setSpacing(16) - if spec.help_text and spec.show_help_button: - help_button = QtWidgets.QPushButton(spec.help_label) - help_button.setObjectName("help") - help_button.setFixedWidth(spec.button_size[0] if spec.button_size is not None else 80) - help_button.clicked.connect(lambda: QtDialogRenderer._show_help(QtWidgets, dialog, spec)) - button_layout.addWidget(help_button) - button_layout.addStretch(1) - cancel_button = QtWidgets.QPushButton(spec.cancel_label) - ok_button = QtWidgets.QPushButton(spec.ok_label) - cancel_button.setObjectName("cancel") - ok_button.setObjectName("ok") - if spec.button_size is not None: - cancel_button.setFixedSize(*spec.button_size) - ok_button.setFixedSize(*spec.button_size) - else: - cancel_button.setFixedWidth(80) - ok_button.setFixedWidth(80) - cancel_button.clicked.connect(dialog.reject) - ok_button.clicked.connect(lambda: QtDialogRenderer._accept_if_valid(dialog, spec, widgets)) - if spec.cancel_first: - button_layout.insertWidget(0, cancel_button) - button_layout.addWidget(ok_button) - else: - button_layout.addWidget(cancel_button) - button_layout.addWidget(ok_button) - layout.addWidget(button_container) - - @staticmethod - def _apply_spec_size(dialog: Any, spec: DialogSpec) -> None: - if spec.size is None: - return - dialog.resize(*spec.size) - def _build_widget(self, QtWidgets: Any, control: ControlSpec, initial_values: Mapping[str, Any]) -> Any: style = control.style.lower() value = initial_values.get(control.tag, control.value) @@ -388,30 +239,6 @@ def _build_widget(self, QtWidgets: Any, control: ControlSpec, initial_values: Ma widget.setEnabled(control.enabled) return widget - @staticmethod - 
def _apply_font_hints(widget: Any, control: ControlSpec) -> None: - if control.font_weight is None: - return - font = widget.font() - font.setBold(control.font_weight.lower() == "bold") - widget.setFont(font) - - @staticmethod - def _apply_widget_size_policy(QtWidgets: Any, widget: Any, style: str) -> None: - policy = QtWidgets.QSizePolicy - if style in {"edit", "popupmenu", "listbox"}: - widget.setSizePolicy(policy.Expanding, policy.Fixed) - return - if style == "pushbutton": - widget.setSizePolicy(policy.Fixed, policy.Fixed) - return - if style == "textarea": - widget.setSizePolicy(policy.Expanding, policy.Expanding) - return - if style in {"text", "checkbox"}: - widget.setMinimumWidth(0) - widget.setSizePolicy(policy.Expanding, policy.Fixed) - def _connect_callback(self, callback: CallbackSpec | None, widgets: dict[str, Any]) -> None: if callback is None: return @@ -517,203 +344,6 @@ def _connect_callback(self, callback: CallbackSpec | None, widgets: dict[str, An if source is not None: source.clicked.connect(lambda: self._plot_fir_response(source, widgets, params)) - @staticmethod - def _accept_if_valid(dialog: Any, spec: DialogSpec, widgets: dict[str, Any]) -> None: - message = QtDialogRenderer._validation_message(spec, widgets) - if message: - _qt_core, qt_widgets = _require_qt() - qt_widgets.QMessageBox.warning(dialog, "Warning", message) - return - dialog.accept() - - @staticmethod - def _validation_message(spec: DialogSpec, widgets: dict[str, Any]) -> str | None: - if spec.function_name == "pop_reref": - return QtDialogRenderer._validate_pop_reref_dialog(spec, widgets) - if spec.function_name == "pop_interp": - return QtDialogRenderer._validate_pop_interp_dialog(spec, widgets) - if spec.function_name == "pop_resample": - text = QtDialogRenderer._widget_text(widgets.get("freq")).strip() - if not text: - return "New sampling rate is required" - try: - value = float(text) - except ValueError: - return "New sampling rate must be numeric" - if value <= 0: - 
return "New sampling rate must be positive" - if spec.function_name == "pop_epoch": - limits_text = QtDialogRenderer._widget_text(widgets.get("limits")).strip() - if not limits_text: - return "Epoch limits are required" - try: - limits = QtDialogRenderer._parse_numeric_text(limits_text) - except ValueError: - return "Epoch limits must be numeric" - if len(limits) != 2: - return "Epoch limits must contain 2 values" - if limits[0] >= limits[1]: - return "Epoch start must be lower than epoch end" - valuelim_text = QtDialogRenderer._widget_text(widgets.get("valuelim")).strip() - if valuelim_text: - try: - valuelim = QtDialogRenderer._parse_numeric_text(valuelim_text) - except ValueError: - return "Out-of-bounds EEG limits must be numeric" - if len(valuelim) not in {1, 2}: - return "Out-of-bounds EEG limits must contain 1 or 2 values" - if spec.function_name == "pop_runica" and "dataset" in widgets: - if not QtDialogRenderer._read_widget(widgets["dataset"]): - return "Select at least one dataset" - if spec.function_name == "pop_headplot": - if QtDialogRenderer._widget_checked(widgets.get("loadcb")): - if not QtDialogRenderer._widget_text(widgets.get("load")).strip(): - return "Select a spline file to load" - else: - if not QtDialogRenderer._widget_text(widgets.get("setup_file")).strip(): - return "Enter an output spline file name" - transform_text = QtDialogRenderer._widget_text(widgets.get("transform")).strip() - if not transform_text: - return "Enter a Talairach transformation matrix" - try: - transform = QtDialogRenderer._parse_numeric_text(transform_text) - except ValueError: - return "Talairach transformation matrix must contain numeric values" - if len(transform) not in {6, 9}: - return "Talairach transformation matrix must contain 6 or 9 values" - return None - - @staticmethod - def _validate_pop_reref_dialog(spec: DialogSpec, widgets: dict[str, Any]) -> str | None: - if QtDialogRenderer._widget_checked(widgets.get("huberef")): - huber_text = 
QtDialogRenderer._widget_text(widgets.get("huberval")).strip() - if huber_text: - try: - float(huber_text) - except ValueError: - return f"could not convert string to float: '{huber_text}'" - - channel_labels = QtDialogRenderer._callback_channels(spec, "refbr") - if QtDialogRenderer._widget_checked(widgets.get("rerefstr")): - ref_text = QtDialogRenderer._widget_text(widgets.get("reref")).strip() - if not ref_text: - return "Aborting: you must enter one or more reference channels" - message = QtDialogRenderer._validate_channel_text(ref_text, channel_labels, "Channel") - if message: - return message - - exclude_text = QtDialogRenderer._widget_text(widgets.get("exclude")).strip() - if exclude_text: - message = QtDialogRenderer._validate_channel_text(exclude_text, channel_labels, "Channel") - if message: - return message - - refloc_text = QtDialogRenderer._widget_text(widgets.get("refloc")).strip() - if refloc_text: - refloc_labels = QtDialogRenderer._callback_channels(spec, "refloc_button") - return QtDialogRenderer._validate_channel_text(refloc_text, refloc_labels, "Reference location") - return None - - @staticmethod - def _validate_pop_interp_dialog(spec: DialogSpec, widgets: dict[str, Any]) -> str | None: - if "chanlist" in widgets: - selection = QtDialogRenderer._read_widget(widgets["chanlist"]) - if not isinstance(selection, dict) or not selection.get("chans"): - return "Select one or more channels to interpolate" - for control in spec.controls: - if control.callback is None or control.callback.name != "validate_numeric_range": - continue - widget = widgets.get(control.tag or "") - text = QtDialogRenderer._widget_text(widget).strip() - if not text: - continue - params = control.callback.params - try: - values = QtDialogRenderer._parse_numeric_text(text) - except ValueError: - return "Time/point range must contain numeric values" - if len(values) != int(params.get("columns", 2)): - return "Time/point range must contain 2 columns exactly" - if min(values) < 
float(params["lower"]): - return "Time/point range exceed lower data limits" - if math.floor(max(values)) > float(params["upper"]): - return "Time/point range exceed upper data limits" - return None - - @staticmethod - def _callback_channels(spec: DialogSpec, tag: str) -> tuple[str, ...]: - for control in spec.controls: - if control.tag == tag and control.callback is not None: - return tuple(str(value) for value in control.callback.params.get("channels", ())) - return () - - @staticmethod - def _validate_channel_text(text: str, labels: tuple[str, ...], label: str) -> str | None: - values = QtDialogRenderer._parse_channel_text(text) - lower_labels = [value.lower() for value in labels] - for value in values: - if QtDialogRenderer._is_int_text(value): - index = int(value) - if index < 0 or index >= len(labels): - return f"{label} index out of range" - continue - if value.lower() not in lower_labels: - return f"{label} '{value}' not found" - return None - - @staticmethod - def _parse_channel_text(text: str) -> list[str]: - text = text.strip() - if text.startswith("[") and text.endswith("]"): - text = text[1:-1] - if text.startswith("{") and text.endswith("}"): - text = text[1:-1] - tokens = re.findall(r"'([^']*)'|\"([^\"]*)\"|([^,\s]+)", text) - return [next(part for part in token if part) for token in tokens] - - @staticmethod - def _parse_numeric_text(text: str) -> list[float]: - cleaned = text.strip().strip("[]") - if not cleaned: - return [] - return [float(value) for value in re.split(r"[\s,]+", cleaned) if value] - - @staticmethod - def _widget_number(widget: Any) -> float | None: - try: - return numeric_or_none(QtDialogRenderer._widget_text(widget)) - except ValueError: - return None - - @staticmethod - def _widget_vector(widget: Any) -> list[float] | None: - try: - return vector_or_none(QtDialogRenderer._widget_text(widget)) - except ValueError: - return None - - @staticmethod - def _combo_choice(widget: Any, values: tuple[str, ...], default: str) -> str: - if 
widget is not None and hasattr(widget, "currentIndex"): - index = int(widget.currentIndex()) - if 0 <= index < len(values): - return values[index] - return default - - @staticmethod - def _is_int_text(value: str) -> bool: - return bool(re.fullmatch(r"[+-]?\d+", value.strip())) - - @staticmethod - def _widget_checked(widget: Any) -> bool: - return bool(widget is not None and hasattr(widget, "isChecked") and widget.isChecked()) - - @staticmethod - def _widget_text(widget: Any) -> str: - if widget is None or not hasattr(widget, "text"): - return "" - return str(widget.text()) - def _run_tf_cycle_calc(self, button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: _qt_core, qt_widgets = _require_qt() fft_widget = widgets.get(str(params.get("fft", ""))) @@ -752,515 +382,932 @@ def _run_tf_cycle_calc(self, button: Any, widgets: dict[str, Any], params: Mappi freqs_widget.setText(_format_numeric_vector(result.widths_table[:, 0])) cycles_widget.setText(_format_numeric_vector(result.cycles)) - @staticmethod - def _plot_tf_cycle_calc(button: Any, widgets: dict[str, Any]) -> None: - _qt_core, qt_widgets = _require_qt() - width_index = 0 - width_widget = widgets.get("widthpop") - if width_widget is not None and hasattr(width_widget, "currentIndex"): - width_index = int(width_widget.currentIndex()) - try: - tf_cycle_calc( - freqs=QtDialogRenderer._widget_text(widgets.get("freqedit")), - width=QtDialogRenderer._widget_text(widgets.get("widthedit")), - width_unit=WIDTH_UNITS[width_index], - log_spaced=QtDialogRenderer._widget_checked(widgets.get("spacingcheck")), - plot=True, - ) - except (IndexError, ValueError) as exc: - qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) - @staticmethod - def _estimate_fir_kaiser_beta(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: +def _QDialog() -> Any: + if QDialog is None: + raise RuntimeError( + "PySide6 is required for EEGPrep GUI dialogs. 
Install it with " + "`pip install -e .[gui]` or `pip install eegprep[gui]`." + ) + return QDialog + + +def _apply_eeglab_style(dialog: Any, spec: DialogSpec) -> None: + base_stylesheet = """ + QDialog { + background: #a8c2ff; + color: #000066; + font-size: 16px; + } + QLabel, QCheckBox, QPushButton, QLineEdit, QTextEdit, QComboBox, QListWidget { + font-size: 16px; + } + QLabel, QCheckBox { + color: #000066; + background: transparent; + } + QLabel:disabled, QCheckBox:disabled { + color: #7c86a8; + } + QLineEdit { + background: white; + border: 1px solid #7f7f7f; + min-height: 18px; + max-height: 18px; + margin-left: 1px; + padding: 0 3px; + color: #000066; + } + QLineEdit:disabled { + background: #dce6ff; + color: #7c86a8; + } + QTextEdit { + background: white; + border: 1px solid #7f7f7f; + color: #000066; + } + QComboBox { + background: white; + border: 1px solid #7f7f7f; + min-height: 20px; + max-height: 20px; + color: #000066; + } + QComboBox:disabled { + background: #dce6ff; + color: #7c86a8; + } + QListWidget { + background: white; + border: 1px solid #7f7f7f; + min-height: 74px; + max-height: 74px; + color: #000066; + } + QPushButton { + background: #eeeeee; + border: 1px solid #7f7f7f; + min-height: 18px; + max-height: 18px; + padding: 0 10px; + color: #000066; + } + QPushButton:disabled { + color: #b0b0b0; + } + QPushButton#double_dip_help { + min-width: 150px; + max-width: 150px; + } + QPushButton#events_button { + min-width: 130px; + max-width: 130px; + } + QPushButton#scroll { + min-width: 159px; + max-width: 159px; + } + QPushButton#refbr, QPushButton#exclude_button, QPushButton#refloc_button { + min-width: 33px; + max-width: 33px; + padding: 0; + } + QPushButton#interp_nondatchan, + QPushButton#interp_removedchans, + QPushButton#interp_datchan, + QPushButton#interp_selectchan, + QPushButton#interp_uselist { + min-width: 434px; + max-width: 434px; + padding: 0; + } + QDialog#pop_interp QPushButton { + padding: 0; + } + QDialog#pop_interp QLineEdit, + 
QDialog#pop_interp QComboBox { + min-width: 198px; + max-width: 198px; + } + QDialog#pop_reref QLineEdit { + min-width: 163px; + max-width: 163px; + } + QDialog#pop_runica QListWidget { + min-height: 102px; + max-height: 102px; + } + QCheckBox { + spacing: 4px; + } + QCheckBox::indicator { + width: 13px; + height: 13px; + } + QCheckBox::indicator:unchecked { + background: white; + border: 1px solid #7f7f7f; + } + """ + dialog.setStyleSheet(base_stylesheet + (spec.extra_stylesheet or "")) + + +def _row_weights(row_geometry: Any) -> list[float]: + if isinstance(row_geometry, (list, tuple)): + return [max(0.01, float(value)) for value in row_geometry] + return [1.0] * max(1, int(row_geometry)) + + +def _row_stretch(spec: DialogSpec, row_index: int) -> int: + if spec.geomvert is None: + return 0 + value = spec.geomvert[min(row_index, len(spec.geomvert) - 1)] + return max(1, round(float(value) * 100)) + + +def _spacer_row_height(spec: DialogSpec, row_index: int) -> int: + if spec.geomvert is not None: + value = spec.geomvert[min(row_index, len(spec.geomvert) - 1)] + return max(8, round(float(value) * 55)) + return max(8, spec.row_spacing * 7) + + +def _add_buttons( + QtWidgets: Any, + layout: Any, + spec: DialogSpec, + dialog: Any, + widgets: dict[str, Any], +) -> None: + if spec.geomvert is None and spec.size is not None: + layout.addStretch(1) + button_container = QtWidgets.QWidget() + button_container.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) + button_layout = QtWidgets.QHBoxLayout(button_container) + button_layout.setContentsMargins(0, 18, 0, 0) + button_layout.setSpacing(16) + if spec.help_text and spec.show_help_button: + help_button = QtWidgets.QPushButton(spec.help_label) + help_button.setObjectName("help") + help_button.setFixedWidth(spec.button_size[0] if spec.button_size is not None else 80) + help_button.clicked.connect(lambda: _show_help(QtWidgets, dialog, spec)) + button_layout.addWidget(help_button) + 
button_layout.addStretch(1) + cancel_button = QtWidgets.QPushButton(spec.cancel_label) + ok_button = QtWidgets.QPushButton(spec.ok_label) + cancel_button.setObjectName("cancel") + ok_button.setObjectName("ok") + if spec.button_size is not None: + cancel_button.setFixedSize(*spec.button_size) + ok_button.setFixedSize(*spec.button_size) + else: + cancel_button.setFixedWidth(80) + ok_button.setFixedWidth(80) + cancel_button.clicked.connect(dialog.reject) + ok_button.clicked.connect(lambda: _accept_if_valid(dialog, spec, widgets)) + if spec.cancel_first: + button_layout.insertWidget(0, cancel_button) + button_layout.addWidget(ok_button) + else: + button_layout.addWidget(cancel_button) + button_layout.addWidget(ok_button) + layout.addWidget(button_container) + + +def _apply_spec_size(dialog: Any, spec: DialogSpec) -> None: + if spec.size is None: + return + dialog.resize(*spec.size) + + +def _apply_font_hints(widget: Any, control: ControlSpec) -> None: + if control.font_weight is None: + return + font = widget.font() + font.setBold(control.font_weight.lower() == "bold") + widget.setFont(font) + + +def _apply_widget_size_policy(QtWidgets: Any, widget: Any, style: str) -> None: + policy = QtWidgets.QSizePolicy + if style in {"edit", "popupmenu", "listbox"}: + widget.setSizePolicy(policy.Expanding, policy.Fixed) + return + if style == "pushbutton": + widget.setSizePolicy(policy.Fixed, policy.Fixed) + return + if style == "textarea": + widget.setSizePolicy(policy.Expanding, policy.Expanding) + return + if style in {"text", "checkbox"}: + widget.setMinimumWidth(0) + widget.setSizePolicy(policy.Expanding, policy.Fixed) + + +def _accept_if_valid(dialog: Any, spec: DialogSpec, widgets: dict[str, Any]) -> None: + message = _validation_message(spec, widgets) + if message: _qt_core, qt_widgets = _require_qt() - target = widgets.get(params.get("target", "")) - if target is None or not hasattr(target, "setText"): - return - dev_widget = widgets.get(params.get("dev", "")) - dev = 
QtDialogRenderer._widget_number(dev_widget) - if dev is None: - dev, accepted = qt_widgets.QInputDialog.getDouble( - button, - "Estimate Kaiser window beta", - "Max passband deviation/ripple:", - 0.001, - 1e-12, - 1.0, - 6, - ) - if not accepted: - return + qt_widgets.QMessageBox.warning(dialog, "Warning", message) + return + dialog.accept() + + +def _validation_message(spec: DialogSpec, widgets: dict[str, Any]) -> str | None: + if spec.function_name == "pop_reref": + return _validate_pop_reref_dialog(spec, widgets) + if spec.function_name == "pop_interp": + return _validate_pop_interp_dialog(spec, widgets) + if spec.function_name == "pop_resample": + text = _widget_text(widgets.get("freq")).strip() + if not text: + return "New sampling rate is required" try: - beta = kaiserbeta(dev) - except ValueError as exc: - qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) - return - target.setText(f"{beta:g}") - if dev_widget is not None and hasattr(dev_widget, "setText"): - dev_widget.setText(f"{dev:g}") - - @staticmethod - def _estimate_firws_order(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: - _qt_core, qt_widgets = _require_qt() - target = widgets.get(params.get("target", "")) - if target is None or not hasattr(target, "setText"): - return - srate = QtDialogRenderer._widget_number(widgets.get(params.get("srate", ""))) - if srate is None: - srate = float(params.get("srate_value", 2)) - wtype = QtDialogRenderer._combo_choice(widgets.get(params.get("wtype", "")), WINDOW_TYPES, "hamming") - dev_widget = widgets.get(params.get("dev", "")) - dev = QtDialogRenderer._widget_number(dev_widget) - df_widget = widgets.get(params.get("df", "")) - df = QtDialogRenderer._widget_number(df_widget) - if df is None: - df, accepted = qt_widgets.QInputDialog.getDouble( - button, - "Estimate filter order", - "Transition bandwidth (Hz):", - max(1.0, srate / 100), - 1e-12, - srate / 2, - 6, - ) - if not accepted: - return + value = float(text) + except 
ValueError: + return "New sampling rate must be numeric" + if value <= 0: + return "New sampling rate must be positive" + if spec.function_name == "pop_epoch": + limits_text = _widget_text(widgets.get("limits")).strip() + if not limits_text: + return "Epoch limits are required" try: - order, out_dev = firwsord(wtype, srate, df, dev) - except ValueError as exc: - qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) - return - target.setText(str(int(order))) - if dev_widget is not None and hasattr(dev_widget, "setText") and out_dev is not None: - dev_widget.setText(f"{float(out_dev):g}") + limits = _parse_numeric_text(limits_text) + except ValueError: + return "Epoch limits must be numeric" + if len(limits) != 2: + return "Epoch limits must contain 2 values" + if limits[0] >= limits[1]: + return "Epoch start must be lower than epoch end" + valuelim_text = _widget_text(widgets.get("valuelim")).strip() + if valuelim_text: + try: + valuelim = _parse_numeric_text(valuelim_text) + except ValueError: + return "Out-of-bounds EEG limits must be numeric" + if len(valuelim) not in {1, 2}: + return "Out-of-bounds EEG limits must contain 1 or 2 values" + if spec.function_name == "pop_runica" and "dataset" in widgets: + if not _read_widget(widgets["dataset"]): + return "Select at least one dataset" + if spec.function_name == "pop_headplot": + if _widget_checked(widgets.get("loadcb")): + if not _widget_text(widgets.get("load")).strip(): + return "Select a spline file to load" + else: + if not _widget_text(widgets.get("setup_file")).strip(): + return "Enter an output spline file name" + transform_text = _widget_text(widgets.get("transform")).strip() + if not transform_text: + return "Enter a Talairach transformation matrix" + try: + transform = _parse_numeric_text(transform_text) + except ValueError: + return "Talairach transformation matrix must contain numeric values" + if len(transform) not in {6, 9}: + return "Talairach transformation matrix must contain 6 or 9 values" + 
return None - @staticmethod - def _estimate_firpm_order(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: - _qt_core, qt_widgets = _require_qt() - fcutoff = QtDialogRenderer._widget_vector(widgets.get(params.get("fcutoff", ""))) - ftrans = QtDialogRenderer._widget_number(widgets.get(params.get("ftrans", ""))) - if not fcutoff or ftrans is None: - qt_widgets.QMessageBox.warning(button, "Warning", "Cutoff frequencies and transition width are required") - return - srate = QtDialogRenderer._widget_number(widgets.get(params.get("srate", ""))) - if srate is None: - srate = float(params.get("srate_value", 2)) - ftype = QtDialogRenderer._combo_choice(widgets.get(params.get("ftype", "")), FILTER_TYPES, "bandpass") - try: - edges, amplitudes = _firpm_order_shape(fcutoff, ftrans, ftype, srate) - order, wtpass, wtstop = pop_firpmord(edges, amplitudes, _firpm_default_devs(amplitudes), srate) - except ValueError as exc: - qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) - return - for tag, value in ( - (params.get("forder"), order), - (params.get("wtpass"), wtpass), - (params.get("wtstop"), wtstop), - ): - widget = widgets.get(tag or "") - if widget is not None and hasattr(widget, "setText"): - widget.setText(f"{float(value):g}") - - @staticmethod - def _plot_fir_response(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: - _qt_core, qt_widgets = _require_qt() - design = str(params.get("design", "firws")) - srate = QtDialogRenderer._widget_number(widgets.get(params.get("srate", ""))) - if srate is None: - srate = float(params.get("srate_value", 2)) - try: - if design == "firpm": - b = design_firpm( - srate, - fcutoff=QtDialogRenderer._widget_vector(widgets.get("fcutoff")), - ftrans=QtDialogRenderer._widget_number(widgets.get("ftrans")), - ftype=QtDialogRenderer._combo_choice(widgets.get("ftype"), FILTER_TYPES, "bandpass"), - forder=int(QtDialogRenderer._widget_number(widgets.get("forder")) or 0), - 
wtpass=QtDialogRenderer._widget_number(widgets.get("wtpass")), - wtstop=QtDialogRenderer._widget_number(widgets.get("wtstop")), - ) - elif design == "firma": - b = design_firma(forder=int(QtDialogRenderer._widget_number(widgets.get("forder")) or 0)) - else: - b = design_firws( - srate, - fcutoff=QtDialogRenderer._widget_vector(widgets.get("fcutoff")), - forder=int(QtDialogRenderer._widget_number(widgets.get("forder")) or 0), - ftype=QtDialogRenderer._combo_choice(widgets.get("ftype"), FILTER_TYPES, "bandpass"), - wtype=QtDialogRenderer._combo_choice(widgets.get("wtype"), WINDOW_TYPES, "hamming"), - warg=QtDialogRenderer._widget_number(widgets.get("warg")), - ) - plotfresp(b, 1, fs=srate, dir="onepass-zerophase") - except ValueError as exc: - qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) - @staticmethod - def _sync_numeric(source: Any, target: Any, multiplier: float) -> None: - text = source.text().strip() +def _validate_pop_reref_dialog(spec: DialogSpec, widgets: dict[str, Any]) -> str | None: + if _widget_checked(widgets.get("huberef")): + huber_text = _widget_text(widgets.get("huberval")).strip() + if huber_text: + try: + float(huber_text) + except ValueError: + return f"could not convert string to float: '{huber_text}'" + + channel_labels = _callback_channels(spec, "refbr") + if _widget_checked(widgets.get("rerefstr")): + ref_text = _widget_text(widgets.get("reref")).strip() + if not ref_text: + return "Aborting: you must enter one or more reference channels" + message = _validate_channel_text(ref_text, channel_labels, "Channel") + if message: + return message + + exclude_text = _widget_text(widgets.get("exclude")).strip() + if exclude_text: + message = _validate_channel_text(exclude_text, channel_labels, "Channel") + if message: + return message + + refloc_text = _widget_text(widgets.get("refloc")).strip() + if refloc_text: + refloc_labels = _callback_channels(spec, "refloc_button") + return _validate_channel_text(refloc_text, refloc_labels, 
"Reference location") + return None + + +def _validate_pop_interp_dialog(spec: DialogSpec, widgets: dict[str, Any]) -> str | None: + if "chanlist" in widgets: + selection = _read_widget(widgets["chanlist"]) + if not isinstance(selection, dict) or not selection.get("chans"): + return "Select one or more channels to interpolate" + for control in spec.controls: + if control.callback is None or control.callback.name != "validate_numeric_range": + continue + widget = widgets.get(control.tag or "") + text = _widget_text(widget).strip() if not text: - target.setText("") - return + continue + params = control.callback.params try: - value = float(text) * multiplier + values = _parse_numeric_text(text) except ValueError: - return - target.setText(f"{value:g}") + return "Time/point range must contain numeric values" + if len(values) != int(params.get("columns", 2)): + return "Time/point range must contain 2 columns exactly" + if min(values) < float(params["lower"]): + return "Time/point range exceed lower data limits" + if math.floor(max(values)) > float(params["upper"]): + return "Time/point range exceed upper data limits" + return None + + +def _callback_channels(spec: DialogSpec, tag: str) -> tuple[str, ...]: + for control in spec.controls: + if control.tag == tag and control.callback is not None: + return tuple(str(value) for value in control.callback.params.get("channels", ())) + return () + + +def _validate_channel_text(text: str, labels: tuple[str, ...], label: str) -> str | None: + values = _parse_channel_text(text) + lower_labels = [value.lower() for value in labels] + for value in values: + if _is_int_text(value): + index = int(value) + if index < 0 or index >= len(labels): + return f"{label} index out of range" + continue + if value.lower() not in lower_labels: + return f"{label} '{value}' not found" + return None + + +def _parse_channel_text(text: str) -> list[str]: + text = text.strip() + if text.startswith("[") and text.endswith("]"): + text = text[1:-1] + if 
text.startswith("{") and text.endswith("}"): + text = text[1:-1] + tokens = re.findall(r"'([^']*)'|\"([^\"]*)\"|([^,\s]+)", text) + return [next(part for part in token if part) for token in tokens] + + +def _parse_numeric_text(text: str) -> list[float]: + cleaned = text.strip().strip("[]") + if not cleaned: + return [] + return [float(value) for value in re.split(r"[\s,]+", cleaned) if value] + + +def _widget_number(widget: Any) -> float | None: + try: + return numeric_or_none(_widget_text(widget)) + except ValueError: + return None + + +def _widget_vector(widget: Any) -> list[float] | None: + try: + return vector_or_none(_widget_text(widget)) + except ValueError: + return None + + +def _combo_choice(widget: Any, values: tuple[str, ...], default: str) -> str: + if widget is not None and hasattr(widget, "currentIndex"): + index = int(widget.currentIndex()) + if 0 <= index < len(values): + return values[index] + return default + + +def _is_int_text(value: str) -> bool: + return bool(re.fullmatch(r"[+-]?\d+", value.strip())) + + +def _widget_checked(widget: Any) -> bool: + return bool(widget is not None and hasattr(widget, "isChecked") and widget.isChecked()) + - @staticmethod - def _select_event_types(button: Any, target: Any, params: Mapping[str, Any]) -> None: - event_types = [str(value) for value in params.get("event_types", ())] - if not event_types: +def _widget_text(widget: Any) -> str: + if widget is None or not hasattr(widget, "text"): + return "" + return str(widget.text()) + + +def _plot_tf_cycle_calc(button: Any, widgets: dict[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + width_index = 0 + width_widget = widgets.get("widthpop") + if width_widget is not None and hasattr(width_widget, "currentIndex"): + width_index = int(width_widget.currentIndex()) + try: + tf_cycle_calc( + freqs=_widget_text(widgets.get("freqedit")), + width=_widget_text(widgets.get("widthedit")), + width_unit=WIDTH_UNITS[width_index], + 
log_spaced=_widget_checked(widgets.get("spacingcheck")), + plot=True, + ) + except (IndexError, ValueError) as exc: + qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) + + +def _estimate_fir_kaiser_beta(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + target = widgets.get(params.get("target", "")) + if target is None or not hasattr(target, "setText"): + return + dev_widget = widgets.get(params.get("dev", "")) + dev = _widget_number(dev_widget) + if dev is None: + dev, accepted = qt_widgets.QInputDialog.getDouble( + button, + "Estimate Kaiser window beta", + "Max passband deviation/ripple:", + 0.001, + 1e-12, + 1.0, + 6, + ) + if not accepted: return - current = target.text().strip() - _qt_core, qt_widgets = _require_qt() - value, accepted = qt_widgets.QInputDialog.getItem( + try: + beta = kaiserbeta(dev) + except ValueError as exc: + qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) + return + target.setText(f"{beta:g}") + if dev_widget is not None and hasattr(dev_widget, "setText"): + dev_widget.setText(f"{dev:g}") + + +def _estimate_firws_order(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + target = widgets.get(params.get("target", "")) + if target is None or not hasattr(target, "setText"): + return + srate = _widget_number(widgets.get(params.get("srate", ""))) + if srate is None: + srate = float(params.get("srate_value", 2)) + wtype = _combo_choice(widgets.get(params.get("wtype", "")), WINDOW_TYPES, "hamming") + dev_widget = widgets.get(params.get("dev", "")) + dev = _widget_number(dev_widget) + df_widget = widgets.get(params.get("df", "")) + df = _widget_number(df_widget) + if df is None: + df, accepted = qt_widgets.QInputDialog.getDouble( button, - "Select event type", - "Event type", - event_types, - 0, - editable=False, + "Estimate filter order", + "Transition bandwidth (Hz):", + max(1.0, srate / 100), + 
1e-12, + srate / 2, + 6, ) - if accepted and value: - target.setText((current + " " + value).strip()) - - @staticmethod - def _select_channels(button: Any, target: Any, params: Mapping[str, Any]) -> None: - channels = [str(value) for value in params.get("channels", ())] - if channels: - chanlist, value, _allchanstr = pop_chansel( - channels, - withindex="on", - select=target.text().strip(), - selectionmode=str(params.get("selectionmode", "multiple")), - parent=button, + if not accepted: + return + try: + order, out_dev = firwsord(wtype, srate, df, dev) + except ValueError as exc: + qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) + return + target.setText(str(int(order))) + if dev_widget is not None and hasattr(dev_widget, "setText") and out_dev is not None: + dev_widget.setText(f"{float(out_dev):g}") + + +def _estimate_firpm_order(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + fcutoff = _widget_vector(widgets.get(params.get("fcutoff", ""))) + ftrans = _widget_number(widgets.get(params.get("ftrans", ""))) + if not fcutoff or ftrans is None: + qt_widgets.QMessageBox.warning(button, "Warning", "Cutoff frequencies and transition width are required") + return + srate = _widget_number(widgets.get(params.get("srate", ""))) + if srate is None: + srate = float(params.get("srate_value", 2)) + ftype = _combo_choice(widgets.get(params.get("ftype", "")), FILTER_TYPES, "bandpass") + try: + edges, amplitudes = _firpm_order_shape(fcutoff, ftrans, ftype, srate) + order, wtpass, wtstop = pop_firpmord(edges, amplitudes, _firpm_default_devs(amplitudes), srate) + except ValueError as exc: + qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) + return + for tag, value in ( + (params.get("forder"), order), + (params.get("wtpass"), wtpass), + (params.get("wtstop"), wtstop), + ): + widget = widgets.get(tag or "") + if widget is not None and hasattr(widget, "setText"): + 
widget.setText(f"{float(value):g}") + + +def _plot_fir_response(button: Any, widgets: dict[str, Any], params: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + design = str(params.get("design", "firws")) + srate = _widget_number(widgets.get(params.get("srate", ""))) + if srate is None: + srate = float(params.get("srate_value", 2)) + try: + if design == "firpm": + b = design_firpm( + srate, + fcutoff=_widget_vector(widgets.get("fcutoff")), + ftrans=_widget_number(widgets.get("ftrans")), + ftype=_combo_choice(widgets.get("ftype"), FILTER_TYPES, "bandpass"), + forder=int(_widget_number(widgets.get("forder")) or 0), + wtpass=_widget_number(widgets.get("wtpass")), + wtstop=_widget_number(widgets.get("wtstop")), ) - if params.get("return_indices"): - value = " ".join(str(index) for index in chanlist) - accepted = bool(value) + elif design == "firma": + b = design_firma(forder=int(_widget_number(widgets.get("forder")) or 0)) else: - _qt_core, qt_widgets = _require_qt() - no_channels_message = str(params.get("no_channels_message", "")).strip() - if no_channels_message: - qt_widgets.QMessageBox.warning(button, "Warning", no_channels_message) - return - value, accepted = qt_widgets.QInputDialog.getText( - button, - "Select channel", - "Channel index or label", + b = design_firws( + srate, + fcutoff=_widget_vector(widgets.get("fcutoff")), + forder=int(_widget_number(widgets.get("forder")) or 0), + ftype=_combo_choice(widgets.get("ftype"), FILTER_TYPES, "bandpass"), + wtype=_combo_choice(widgets.get("wtype"), WINDOW_TYPES, "hamming"), + warg=_widget_number(widgets.get("warg")), ) - if not accepted or not value: - return - target.setText(value.strip()) - - @staticmethod - def _select_file(button: Any, target: Any, params: Mapping[str, Any], widgets: Mapping[str, Any]) -> None: + plotfresp(b, 1, fs=srate, dir="onepass-zerophase") + except ValueError as exc: + qt_widgets.QMessageBox.warning(button, "Warning", str(exc)) + + +def _sync_numeric(source: Any, 
target: Any, multiplier: float) -> None: + text = source.text().strip() + if not text: + target.setText("") + return + try: + value = float(text) * multiplier + except ValueError: + return + target.setText(f"{value:g}") + + +def _select_event_types(button: Any, target: Any, params: Mapping[str, Any]) -> None: + event_types = [str(value) for value in params.get("event_types", ())] + if not event_types: + return + current = target.text().strip() + _qt_core, qt_widgets = _require_qt() + value, accepted = qt_widgets.QInputDialog.getItem( + button, + "Select event type", + "Event type", + event_types, + 0, + editable=False, + ) + if accepted and value: + target.setText((current + " " + value).strip()) + + +def _select_channels(button: Any, target: Any, params: Mapping[str, Any]) -> None: + channels = [str(value) for value in params.get("channels", ())] + if channels: + chanlist, value, _allchanstr = pop_chansel( + channels, + withindex="on", + select=target.text().strip(), + selectionmode=str(params.get("selectionmode", "multiple")), + parent=button, + ) + if params.get("return_indices"): + value = " ".join(str(index) for index in chanlist) + accepted = bool(value) + else: _qt_core, qt_widgets = _require_qt() - caption = str(params.get("caption", "Select file")) - file_filter = str(params.get("filter", "All files (*)")) - if params.get("mode") == "save": - filename, _selected_filter = qt_widgets.QFileDialog.getSaveFileName(button, caption, "", file_filter) - else: - filename, _selected_filter = qt_widgets.QFileDialog.getOpenFileName(button, caption, "", file_filter) - if not filename: - return - if hasattr(target, "setText"): - target.setText(filename) - return - if hasattr(target, "setEditable") and hasattr(target, "setEditText"): - target.setEditable(True) - target.setEditText(filename) - target.setProperty(_VALUE_PROPERTY, filename) - transform_target = widgets.get(params.get("transform_target", "")) - if transform_target is not None and 
params.get("custom_transform") is not None: - transform_target.setText(str(params["custom_transform"])) - - @staticmethod - def _open_eegplot(parent: Any, params: Mapping[str, Any]) -> None: - eeg = params.get("eeg") - if not isinstance(eeg, dict): + no_channels_message = str(params.get("no_channels_message", "")).strip() + if no_channels_message: + qt_widgets.QMessageBox.warning(button, "Warning", no_channels_message) return - try: - eegplot( - eeg, - srate=eeg.get("srate", 256), - limits=[ - float(eeg.get("xmin", 0.0) or 0.0) * 1000.0, - float(eeg.get("xmax", 0.0) or 0.0) * 1000.0, - ], - events=eeg.get("event", []), - winlength=5, - xgrid="off", - eloc_file=eeg.get("chanlocs", []), - title=f"Scroll channel activities -- eegplot() -- {eeg.get('setname', '')}".rstrip(), - ) - except (RuntimeError, ValueError) as exc: - _qt_core, qt_widgets = _require_qt() - qt_widgets.QMessageBox.warning(parent, "Warning", str(exc)) - - @staticmethod - def _open_rejection_browser(parent: Any, widgets: Mapping[str, Any], params: Mapping[str, Any]) -> None: - eeg = params.get("eeg") - if not isinstance(eeg, dict): - return - try: - status_widget = widgets.get("rejstatus") - status = int(QtDialogRenderer._read_widget(status_widget) if status_widget is not None else 1) - superpose = max(0, min(status - 1, 2)) - - def accept(eeg_out: dict[str, Any], _command: str) -> None: - eeg.clear() - eeg.update(eeg_out) - for tag in params.get("count_tags", ()): - widget = widgets.get(tag) - field = params.get("count_fields", {}).get(tag) - if widget is not None and field is not None and hasattr(widget, "setText"): - widget.setText(str(int(np.asarray((eeg.get("reject") or {}).get(field, []), dtype=bool).sum()))) - - pop_eegplot( - eeg, - icacomp=int(params.get("icacomp", 1)), - superpose=superpose, - reject=0, - command_callback=accept, - ) - except (RuntimeError, ValueError) as exc: - _qt_core, qt_widgets = _require_qt() - qt_widgets.QMessageBox.warning(parent, "Warning", str(exc)) - - 
@staticmethod - def _set_headplot_setup_mode(widgets: Mapping[str, Any], params: Mapping[str, Any], checked: bool) -> None: - source = widgets.get(params["source"]) - if not checked: - if source is not None and hasattr(source, "blockSignals"): - source.blockSignals(True) - source.setChecked(True) - source.blockSignals(False) - return - peer = widgets.get(params["peer"]) - if peer is not None: - peer.blockSignals(True) - peer.setChecked(False) - peer.blockSignals(False) - load_enabled = params["mode"] == "load" - QtDialogRenderer._set_enabled( - [widgets[tag] for tag in params.get("load_targets", ()) if tag in widgets], - load_enabled, + value, accepted = qt_widgets.QInputDialog.getText( + button, + "Select channel", + "Channel index or label", ) - QtDialogRenderer._set_enabled( - [widgets[tag] for tag in params.get("setup_targets", ()) if tag in widgets], - not load_enabled, + if not accepted or not value: + return + target.setText(value.strip()) + + +def _select_file(button: Any, target: Any, params: Mapping[str, Any], widgets: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + caption = str(params.get("caption", "Select file")) + file_filter = str(params.get("filter", "All files (*)")) + if params.get("mode") == "save": + filename, _selected_filter = qt_widgets.QFileDialog.getSaveFileName(button, caption, "", file_filter) + else: + filename, _selected_filter = qt_widgets.QFileDialog.getOpenFileName(button, caption, "", file_filter) + if not filename: + return + if hasattr(target, "setText"): + target.setText(filename) + return + if hasattr(target, "setEditable") and hasattr(target, "setEditText"): + target.setEditable(True) + target.setEditText(filename) + target.setProperty(_VALUE_PROPERTY, filename) + transform_target = widgets.get(params.get("transform_target", "")) + if transform_target is not None and params.get("custom_transform") is not None: + transform_target.setText(str(params["custom_transform"])) + + +def _open_eegplot(parent: Any, 
params: Mapping[str, Any]) -> None: + eeg = params.get("eeg") + if not isinstance(eeg, dict): + return + try: + eegplot( + eeg, + srate=eeg.get("srate", 256), + limits=[ + float(eeg.get("xmin", 0.0) or 0.0) * 1000.0, + float(eeg.get("xmax", 0.0) or 0.0) * 1000.0, + ], + events=eeg.get("event", []), + winlength=5, + xgrid="off", + eloc_file=eeg.get("chanlocs", []), + title=f"Scroll channel activities -- eegplot() -- {eeg.get('setname', '')}".rstrip(), ) - - @staticmethod - def _set_headplot_mesh_choice(widgets: Mapping[str, Any], params: Mapping[str, Any], index: int) -> None: - reference_target = widgets.get(params.get("reference_target", "")) - if ( - reference_target is not None - and hasattr(reference_target, "setCurrentIndex") - and index < reference_target.count() - ): - reference_target.setCurrentIndex(index) - target = widgets.get(params.get("transform_target", "")) - transforms = tuple(str(value) for value in params.get("transform_choices", ())) - if target is not None and 0 <= index < len(transforms) and transforms[index]: - target.setText(transforms[index]) - - @staticmethod - def _run_headplot_manual_coreg(parent: Any, widgets: Mapping[str, Any], params: Mapping[str, Any]) -> None: - transform_target = widgets.get(params.get("transform_target", "")) - if transform_target is None or not hasattr(transform_target, "setText"): - return - try: - # Manual headplot co-registration is the only generic Qt dialog path - # that needs matplotlib's 3-D stack, so keep it out of normal - # inputgui imports and startup. 
- from eegprep.functions.guifunc.coregister import run_coregister_dialog - - meshfile = QtDialogRenderer._choice_or_text( - widgets.get(params.get("mesh_source", "")), - tuple(str(value) for value in params.get("mesh_choices", ())), - ) - reference = QtDialogRenderer._choice_or_text( - widgets.get(params.get("reference_source", "")), - tuple(str(value) for value in params.get("reference_choices", ())), - ) - transform = run_coregister_dialog( - params.get("chanlocs", ()), - reference, - chaninfo=dict(params.get("chaninfo") or {}), - meshfile=meshfile, - transform=QtDialogRenderer._widget_text(transform_target), - parent=parent, - title=str(params.get("title", "Co-registration plot for headplot mesh")), - ) - except (RuntimeError, OSError, ValueError) as exc: - _qt_core, qt_widgets = _require_qt() - qt_widgets.QMessageBox.warning(parent, "Warning", str(exc)) - return - if transform is not None: - transform_target.setText(" ".join(f"{value:.6g}" for value in transform)) - - @staticmethod - def _choice_or_text(widget: Any, choices: tuple[str, ...]) -> str: - stored_value = widget.property(_VALUE_PROPERTY) if widget is not None and hasattr(widget, "property") else None - if stored_value is not None: - return str(stored_value) - if widget is not None and hasattr(widget, "currentIndex"): - index = int(widget.currentIndex()) - if 0 <= index < len(choices): - return choices[index] - if widget is not None and hasattr(widget, "text"): - return str(widget.text()) - return choices[0] if choices else "" - - @staticmethod - def _show_callback_message(parent: Any, params: Mapping[str, Any]) -> None: + except (RuntimeError, ValueError) as exc: _qt_core, qt_widgets = _require_qt() - qt_widgets.QMessageBox.information(parent, str(params.get("title", "EEGPrep")), str(params.get("message", ""))) - - @staticmethod - def _edit_text(parent: Any, target: Any, params: Mapping[str, Any]) -> None: + qt_widgets.QMessageBox.warning(parent, "Warning", str(exc)) + + +def 
_open_rejection_browser(parent: Any, widgets: Mapping[str, Any], params: Mapping[str, Any]) -> None: + eeg = params.get("eeg") + if not isinstance(eeg, dict): + return + try: + status_widget = widgets.get("rejstatus") + status = int(_read_widget(status_widget) if status_widget is not None else 1) + superpose = max(0, min(status - 1, 2)) + + def accept(eeg_out: dict[str, Any], _command: str) -> None: + eeg.clear() + eeg.update(eeg_out) + for tag in params.get("count_tags", ()): + widget = widgets.get(tag) + field = params.get("count_fields", {}).get(tag) + if widget is not None and field is not None and hasattr(widget, "setText"): + widget.setText(str(int(np.asarray((eeg.get("reject") or {}).get(field, []), dtype=bool).sum()))) + + pop_eegplot( + eeg, + icacomp=int(params.get("icacomp", 1)), + superpose=superpose, + reject=0, + command_callback=accept, + ) + except (RuntimeError, ValueError) as exc: _qt_core, qt_widgets = _require_qt() - stored_value = target.property(_VALUE_PROPERTY) - current = stored_value if stored_value is not None else params.get("value", "") - value, accepted = qt_widgets.QInputDialog.getMultiLineText( - parent, - str(params.get("title", "Edit text")), - str(params.get("label", "Text")), - str(current), + qt_widgets.QMessageBox.warning(parent, "Warning", str(exc)) + + +def _set_headplot_setup_mode(widgets: Mapping[str, Any], params: Mapping[str, Any], checked: bool) -> None: + source = widgets.get(params["source"]) + if not checked: + if source is not None and hasattr(source, "blockSignals"): + source.blockSignals(True) + source.setChecked(True) + source.blockSignals(False) + return + peer = widgets.get(params["peer"]) + if peer is not None: + peer.blockSignals(True) + peer.setChecked(False) + peer.blockSignals(False) + load_enabled = params["mode"] == "load" + _set_enabled( + [widgets[tag] for tag in params.get("load_targets", ()) if tag in widgets], + load_enabled, + ) + _set_enabled( + [widgets[tag] for tag in params.get("setup_targets", 
()) if tag in widgets], + not load_enabled, + ) + + +def _set_headplot_mesh_choice(widgets: Mapping[str, Any], params: Mapping[str, Any], index: int) -> None: + reference_target = widgets.get(params.get("reference_target", "")) + if ( + reference_target is not None + and hasattr(reference_target, "setCurrentIndex") + and index < reference_target.count() + ): + reference_target.setCurrentIndex(index) + target = widgets.get(params.get("transform_target", "")) + transforms = tuple(str(value) for value in params.get("transform_choices", ())) + if target is not None and 0 <= index < len(transforms) and transforms[index]: + target.setText(transforms[index]) + + +def _run_headplot_manual_coreg(parent: Any, widgets: Mapping[str, Any], params: Mapping[str, Any]) -> None: + transform_target = widgets.get(params.get("transform_target", "")) + if transform_target is None or not hasattr(transform_target, "setText"): + return + try: + # Manual headplot co-registration is the only generic Qt dialog path + # that needs matplotlib's 3-D stack, so keep it out of normal + # inputgui imports and startup. 
+ from eegprep.functions.guifunc.coregister import run_coregister_dialog + + meshfile = _choice_or_text( + widgets.get(params.get("mesh_source", "")), + tuple(str(value) for value in params.get("mesh_choices", ())), ) - if accepted: - target.setProperty(_VALUE_PROPERTY, str(value)) - - @staticmethod - def _select_interp_channels(button: Any, target: Any, params: Mapping[str, Any]) -> None: - source = str(params.get("source", "")).lower() - chanlocs = [dict(chan) for chan in params.get("chanlocs", ())] - removedchans = [dict(chan) for chan in params.get("removedchans", ())] - alleeg = [dict(eeg) for eeg in params.get("alleeg", ())] - - if source in {"removedchans", "nondatchan"}: - labels = [str(chan.get("labels", "")) for chan in removedchans] - chanlist, chanliststr, _allchanstr = pop_chansel(labels, parent=button) - if not chanlist: - return - selected = [removedchans[index - 1] for index in chanlist] - chanstr = "EEG.chaninfo.removedchans([" + " ".join(str(index) for index in chanlist) + "])" - QtDialogRenderer._store_interp_selection(target, selected, chanstr, chanliststr) - return + reference = _choice_or_text( + widgets.get(params.get("reference_source", "")), + tuple(str(value) for value in params.get("reference_choices", ())), + ) + transform = run_coregister_dialog( + params.get("chanlocs", ()), + reference, + chaninfo=dict(params.get("chaninfo") or {}), + meshfile=meshfile, + transform=_widget_text(transform_target), + parent=parent, + title=str(params.get("title", "Co-registration plot for headplot mesh")), + ) + except (RuntimeError, OSError, ValueError) as exc: + _qt_core, qt_widgets = _require_qt() + qt_widgets.QMessageBox.warning(parent, "Warning", str(exc)) + return + if transform is not None: + transform_target.setText(" ".join(f"{value:.6g}" for value in transform)) + + +def _choice_or_text(widget: Any, choices: tuple[str, ...]) -> str: + stored_value = widget.property(_VALUE_PROPERTY) if widget is not None and hasattr(widget, "property") else 
None + if stored_value is not None: + return str(stored_value) + if widget is not None and hasattr(widget, "currentIndex"): + index = int(widget.currentIndex()) + if 0 <= index < len(choices): + return choices[index] + if widget is not None and hasattr(widget, "text"): + return str(widget.text()) + return choices[0] if choices else "" - if source == "datchan": - labels = [str(chan.get("labels", "")) for chan in chanlocs] - chanlist, chanliststr, _allchanstr = pop_chansel(labels, parent=button) - if not chanlist: - return - selected = [index - 1 for index in chanlist] - chanstr = "[" + " ".join(str(index) for index in chanlist) + "]" - QtDialogRenderer._store_interp_selection(target, selected, chanstr, chanliststr) - return - _qt_core, qt_widgets = _require_qt() - dataset_index, accepted = qt_widgets.QInputDialog.getInt( - button, - "Choose dataset", - "Dataset index", - 1, - 1, - max(1, len(alleeg)), - ) - if not accepted: - return - if dataset_index < 1 or dataset_index > len(alleeg): - qt_widgets.QMessageBox.warning(button, "Warning", "Wrong index") - return +def _show_callback_message(parent: Any, params: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + qt_widgets.QMessageBox.information(parent, str(params.get("title", "EEGPrep")), str(params.get("message", ""))) - other = alleeg[dataset_index - 1] - other_chanlocs = [dict(chan) for chan in other.get("chanlocs", ())] - if source == "selectchan": - labels = [str(chan.get("labels", "")) for chan in other_chanlocs] - chanlist, _chanliststr, _allchanstr = pop_chansel(labels, parent=button) - else: - chanlist = list(range(1, len(other_chanlocs) + 1)) + +def _edit_text(parent: Any, target: Any, params: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + stored_value = target.property(_VALUE_PROPERTY) + current = stored_value if stored_value is not None else params.get("value", "") + value, accepted = qt_widgets.QInputDialog.getMultiLineText( + parent, + str(params.get("title", 
"Edit text")), + str(params.get("label", "Text")), + str(current), + ) + if accepted: + target.setProperty(_VALUE_PROPERTY, str(value)) + + +def _select_interp_channels(button: Any, target: Any, params: Mapping[str, Any]) -> None: + source = str(params.get("source", "")).lower() + chanlocs = [dict(chan) for chan in params.get("chanlocs", ())] + removedchans = [dict(chan) for chan in params.get("removedchans", ())] + alleeg = [dict(eeg) for eeg in params.get("alleeg", ())] + + if source in {"removedchans", "nondatchan"}: + labels = [str(chan.get("labels", "")) for chan in removedchans] + chanlist, chanliststr, _allchanstr = pop_chansel(labels, parent=button) if not chanlist: return - - current_labels = {str(chan.get("labels", "")).lower() for chan in chanlocs} - selected_indices = [ - index - for index in chanlist - if str(other_chanlocs[index - 1].get("labels", "")).lower() not in current_labels - ] - if not selected_indices: - qt_widgets.QMessageBox.warning(button, "Warning", "No new channels selected") + selected = [removedchans[index - 1] for index in chanlist] + chanstr = "EEG.chaninfo.removedchans([" + " ".join(str(index) for index in chanlist) + "])" + _store_interp_selection(target, selected, chanstr, chanliststr) + return + + if source == "datchan": + labels = [str(chan.get("labels", "")) for chan in chanlocs] + chanlist, chanliststr, _allchanstr = pop_chansel(labels, parent=button) + if not chanlist: return + selected = [index - 1 for index in chanlist] + chanstr = "[" + " ".join(str(index) for index in chanlist) + "]" + _store_interp_selection(target, selected, chanstr, chanliststr) + return + + _qt_core, qt_widgets = _require_qt() + dataset_index, accepted = qt_widgets.QInputDialog.getInt( + button, + "Choose dataset", + "Dataset index", + 1, + 1, + max(1, len(alleeg)), + ) + if not accepted: + return + if dataset_index < 1 or dataset_index > len(alleeg): + qt_widgets.QMessageBox.warning(button, "Warning", "Wrong index") + return + + other = 
alleeg[dataset_index - 1] + other_chanlocs = [dict(chan) for chan in other.get("chanlocs", ())] + if source == "selectchan": + labels = [str(chan.get("labels", "")) for chan in other_chanlocs] + chanlist, _chanliststr, _allchanstr = pop_chansel(labels, parent=button) + else: + chanlist = list(range(1, len(other_chanlocs) + 1)) + if not chanlist: + return + + current_labels = {str(chan.get("labels", "")).lower() for chan in chanlocs} + selected_indices = [ + index for index in chanlist if str(other_chanlocs[index - 1].get("labels", "")).lower() not in current_labels + ] + if not selected_indices: + qt_widgets.QMessageBox.warning(button, "Warning", "No new channels selected") + return + + if len(chanlist) == len(other_chanlocs): + selected = other_chanlocs + chanstr = f"ALLEEG({dataset_index}).chanlocs" + else: + selected_indices = sorted(selected_indices) + selected = [other_chanlocs[index - 1] for index in selected_indices] + chanstr = f"ALLEEG({dataset_index}).chanlocs([" + " ".join(str(index) for index in selected_indices) + "])" + display = " ".join(str(other_chanlocs[index - 1].get("labels", "")) for index in selected_indices) + _store_interp_selection(target, selected, chanstr, display) + + +def _store_interp_selection(target: Any, chans: Any, chanstr: str, display: str) -> None: + target.setProperty(_VALUE_PROPERTY, {"chans": chans, "chanstr": chanstr}) + target.setText(display.strip()) + + +def _set_reref_mode(widgets: dict[str, Any], mode: str, checked: bool) -> None: + if not checked: + if mode == "channels" and "ave" in widgets: + widgets["ave"].setChecked(True) + return + + average_mode = mode in {"average", "huber"} + for tag in ("ave", "huberef", "rerefstr"): + if tag in widgets: + widgets[tag].blockSignals(True) + widgets[tag].setChecked( + (tag == "ave" and mode == "average") + or (tag == "huberef" and mode == "huber") + or (tag == "rerefstr" and mode == "channels") + ) + widgets[tag].blockSignals(False) - if len(chanlist) == len(other_chanlocs): - 
selected = other_chanlocs - chanstr = f"ALLEEG({dataset_index}).chanlocs" - else: - selected_indices = sorted(selected_indices) - selected = [other_chanlocs[index - 1] for index in selected_indices] - chanstr = f"ALLEEG({dataset_index}).chanlocs([" + " ".join(str(index) for index in selected_indices) + "])" - display = " ".join(str(other_chanlocs[index - 1].get("labels", "")) for index in selected_indices) - QtDialogRenderer._store_interp_selection(target, selected, chanstr, display) - - @staticmethod - def _store_interp_selection(target: Any, chans: Any, chanstr: str, display: str) -> None: - target.setProperty(_VALUE_PROPERTY, {"chans": chans, "chanstr": chanstr}) - target.setText(display.strip()) - - @staticmethod - def _set_reref_mode(widgets: dict[str, Any], mode: str, checked: bool) -> None: - if not checked: - if mode == "channels" and "ave" in widgets: - widgets["ave"].setChecked(True) - return + for tag in ("reref", "refbr", "keepref"): + if tag in widgets: + widgets[tag].setEnabled(not average_mode) + if average_mode and "keepref" in widgets: + widgets["keepref"].setChecked(False) + + +def _set_enabled(widgets: list[Any], enabled: bool) -> None: + for widget in widgets: + widget.setEnabled(enabled) - average_mode = mode in {"average", "huber"} - for tag in ("ave", "huberef", "rerefstr"): - if tag in widgets: - widgets[tag].blockSignals(True) - widgets[tag].setChecked( - (tag == "ave" and mode == "average") - or (tag == "huberef" and mode == "huber") - or (tag == "rerefstr" and mode == "channels") - ) - widgets[tag].blockSignals(False) - - for tag in ("reref", "refbr", "keepref"): - if tag in widgets: - widgets[tag].setEnabled(not average_mode) - if average_mode and "keepref" in widgets: - widgets["keepref"].setChecked(False) - - @staticmethod - def _set_enabled(widgets: list[Any], enabled: bool) -> None: - for widget in widgets: - widget.setEnabled(enabled) - - @staticmethod - def _show_help(_qt_widgets: Any, dialog: Any, spec: DialogSpec) -> None: - 
dialog._eegprep_help_dialog = pophelp(spec.help_text or spec.function_name, parent=dialog) - - @staticmethod - def _read_widget(widget: Any) -> Any: - stored_value = widget.property(_VALUE_PROPERTY) - if stored_value is not None: - return stored_value - if hasattr(widget, "isChecked"): - return widget.isChecked() - if widget.property(_MULTI_SELECT_PROPERTY) and hasattr(widget, "selectedIndexes"): - return sorted({index.row() + 1 for index in widget.selectedIndexes()}) - if hasattr(widget, "currentRow"): - return widget.currentRow() + 1 - if hasattr(widget, "currentIndex"): - return widget.currentIndex() + 1 - if hasattr(widget, "toPlainText"): - return widget.toPlainText() - if hasattr(widget, "text"): - return widget.text() - return None + +def _show_help(_qt_widgets: Any, dialog: Any, spec: DialogSpec) -> None: + dialog._eegprep_help_dialog = pophelp(spec.help_text or spec.function_name, parent=dialog) + + +def _read_widget(widget: Any) -> Any: + stored_value = widget.property(_VALUE_PROPERTY) + if stored_value is not None: + return stored_value + if hasattr(widget, "isChecked"): + return widget.isChecked() + if widget.property(_MULTI_SELECT_PROPERTY) and hasattr(widget, "selectedIndexes"): + return sorted({index.row() + 1 for index in widget.selectedIndexes()}) + if hasattr(widget, "currentRow"): + return widget.currentRow() + 1 + if hasattr(widget, "currentIndex"): + return widget.currentIndex() + 1 + if hasattr(widget, "toPlainText"): + return widget.toPlainText() + if hasattr(widget, "text"): + return widget.text() + return None def _is_sequence_value(value: Any) -> bool: @@ -1309,3 +1356,56 @@ def _firpm_order_shape( def _firpm_default_devs(amplitudes: list[float]) -> list[float]: return [0.01 if value == 1 else 0.001 for value in amplitudes] + + +_QT_RENDERER_STATIC_HELPERS = ( + "_QDialog", + "_apply_eeglab_style", + "_row_weights", + "_row_stretch", + "_spacer_row_height", + "_add_buttons", + "_apply_spec_size", + "_apply_font_hints", + 
"_apply_widget_size_policy", + "_accept_if_valid", + "_validation_message", + "_validate_pop_reref_dialog", + "_validate_pop_interp_dialog", + "_callback_channels", + "_validate_channel_text", + "_parse_channel_text", + "_parse_numeric_text", + "_widget_number", + "_widget_vector", + "_combo_choice", + "_is_int_text", + "_widget_checked", + "_widget_text", + "_plot_tf_cycle_calc", + "_estimate_fir_kaiser_beta", + "_estimate_firws_order", + "_estimate_firpm_order", + "_plot_fir_response", + "_sync_numeric", + "_select_event_types", + "_select_channels", + "_select_file", + "_open_eegplot", + "_open_rejection_browser", + "_set_headplot_setup_mode", + "_set_headplot_mesh_choice", + "_run_headplot_manual_coreg", + "_choice_or_text", + "_show_callback_message", + "_edit_text", + "_select_interp_channels", + "_store_interp_selection", + "_set_reref_mode", + "_set_enabled", + "_show_help", + "_read_widget", +) +for _helper_name in _QT_RENDERER_STATIC_HELPERS: + setattr(QtDialogRenderer, _helper_name, staticmethod(globals()[_helper_name])) +del _helper_name diff --git a/src/eegprep/functions/guifunc/session.py b/src/eegprep/functions/guifunc/session.py index 951081ab..bbdc57f4 100644 --- a/src/eegprep/functions/guifunc/session.py +++ b/src/eegprep/functions/guifunc/session.py @@ -373,6 +373,34 @@ def add_history(self, command: str | None, *, notify: bool = True) -> None: if notify: self.notify_changed() + def clear_history(self, *, notify: bool = True) -> None: + """Clear command history and LASTCOM as one session mutation.""" + self.ALLCOM.clear() + self.LASTCOM = "" + if notify: + self.notify_changed() + + def remove_history(self, count: int, *, notify: bool = True) -> None: + """Remove the most recent ``count`` command-history entries.""" + remove_count = min(max(int(count), 0), len(self.ALLCOM)) + if remove_count: + del self.ALLCOM[-remove_count:] + self.LASTCOM = self.ALLCOM[-1] if self.ALLCOM else "" + if notify: + self.notify_changed() + + def 
history_command_at(self, index: int) -> str: + """Return the 1-based command from most recent history first.""" + if index < 1 or index > len(self.ALLCOM): + return "" + return list(reversed(self.ALLCOM))[index - 1] + + def clear_last_command(self, *, notify: bool = True) -> None: + """Clear LASTCOM without deleting ALLCOM.""" + self.LASTCOM = "" + if notify: + self.notify_changed() + def _append_current_dataset_history(self, command: str | None) -> None: if not command: return diff --git a/src/eegprep/functions/miscfunc/event_utils.py b/src/eegprep/functions/miscfunc/event_utils.py new file mode 100644 index 00000000..1b96d65a --- /dev/null +++ b/src/eegprep/functions/miscfunc/event_utils.py @@ -0,0 +1,59 @@ +"""Shared EEGLAB-style event helpers.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from eegprep.functions.adminfunc.eeg_options import EEG_OPTIONS + + +def boundary_event_indices(EEG: Any) -> list[int]: + """Return 0-based indices of EEGLAB boundary events.""" + events = _event_records(EEG) + if not events or not isinstance(events[0], dict) or "type" not in events[0]: + return [] + first_type = events[0].get("type") + if isinstance(first_type, bytes): + first_type = first_type.decode("utf-8") + if isinstance(first_type, str): + return [ + index for index, event in enumerate(events) if _string_event_type(event.get("type")).startswith("boundary") + ] + if EEG_OPTIONS["option_boundary99"]: + return [index for index, event in enumerate(events) if _numeric_event_type(event.get("type")) == -99] + return [] + + +def is_boundary_event(event: dict[str, Any]) -> bool: + """Return whether one event matches EEGLAB boundary semantics.""" + return bool(boundary_event_indices([event])) + + +def _event_records(EEG: Any) -> list[Any]: + if EEG is None: + return [] + if isinstance(EEG, dict) and "event" in EEG and "setname" in EEG: + EEG = EEG["event"] + if isinstance(EEG, np.ndarray): + EEG = EEG.tolist() + if isinstance(EEG, 
dict): + return [EEG] + if isinstance(EEG, (list, tuple)): + return list(EEG) + return [] + + +def _string_event_type(value: Any) -> str: + if isinstance(value, bytes): + return value.decode("utf-8") + if isinstance(value, str): + return value + return "" + + +def _numeric_event_type(value: Any) -> float | None: + if isinstance(value, (int, float, np.integer, np.floating)): + return float(value) + return None diff --git a/src/eegprep/functions/miscfunc/value_parsing.py b/src/eegprep/functions/miscfunc/value_parsing.py new file mode 100644 index 00000000..8b642580 --- /dev/null +++ b/src/eegprep/functions/miscfunc/value_parsing.py @@ -0,0 +1,138 @@ +"""Shared EEGLAB-style value parsing helpers.""" + +from __future__ import annotations + +from collections.abc import Iterable +import re +from typing import Any + +import numpy as np + + +_TOKEN_PATTERN = re.compile(r"'([^']*)'|\"([^\"]*)\"|([^,\s]+)") +_RANGE_TOKEN = re.compile(r"^([^:]+):([^:]+)(?::([^:]+))?$") + + +def parse_key_value_args( + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + *, + lowercase_keys: bool = True, + lowercase_kwargs: bool = False, +) -> dict[str, Any]: + """Parse EEGLAB-style key/value positional arguments.""" + if len(args) % 2: + raise ValueError("Key/value arguments must be in pairs") + options: dict[str, Any] = {} + for key, value in (kwargs or {}).items(): + parsed_key = str(key).lower() if lowercase_kwargs else key + options[parsed_key] = value + for index in range(0, len(args), 2): + key = args[index] + if isinstance(key, bytes): + key = key.decode("utf-8") + if not isinstance(key, str): + raise ValueError("Keys must be strings") + parsed_key = key.lower() if lowercase_keys else key + options[parsed_key] = args[index + 1] + return options + + +def parse_text_tokens(text: Any, *, parse_ints: bool = False) -> list[Any]: + """Parse MATLAB text/cell-list token strings used by GUI dialogs.""" + tokens = _TOKEN_PATTERN.findall(str(text).strip().strip("{}")) + values = 
[next(part for part in token if part) for token in tokens] + if not parse_ints: + return values + parsed = [] + for value in values: + try: + parsed.append(int(value)) + except ValueError: + parsed.append(value) + return parsed + + +def parse_numeric_sequence(value: Any, *, dtype: type = float) -> list[Any]: + """Parse EEGLAB-style numeric vectors, including ``start:stop`` ranges.""" + if value is None: + return [] + if isinstance(value, np.ndarray): + return [dtype(item) for item in value.ravel().tolist()] + if isinstance(value, (int, float, np.integer, np.floating)): + return [dtype(value)] + if isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)): + values: list[Any] = [] + for item in value: + if isinstance(item, str): + values.extend(parse_numeric_sequence(item, dtype=dtype)) + else: + values.append(dtype(item)) + return values + text = str(value).strip().strip("[]") + if not text: + return [] + values = [] + for token in re.split(r"[\s,;]+", text): + if not token: + continue + match = _RANGE_TOKEN.match(token) + if match: + values.extend(_parse_range_token(match, dtype=dtype)) + else: + values.append(dtype(_parse_numeric_atom(token))) + return values + + +def is_empty_value(value: Any) -> bool: + """Return whether a GUI/history value means empty in EEGLAB dialogs.""" + if value is None: + return True + if isinstance(value, str): + return value.strip().strip("[]{}") == "" + if isinstance(value, np.ndarray): + return value.size == 0 + return isinstance(value, (list, tuple, set, dict)) and len(value) == 0 + + +def is_on(value: Any) -> bool: + """Normalize EEGLAB-style on/off values.""" + if isinstance(value, str): + return value.strip().lower() in {"1", "on", "true", "yes"} + if isinstance(value, np.ndarray): + return bool(value.size and np.asarray(value).ravel()[0]) + if isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)): + values = list(value) + return bool(values and values[0]) + return bool(value) + + +def 
_parse_range_token(match: re.Match[str], *, dtype: type) -> list[Any]: + first = _parse_numeric_atom(match.group(1)) + second = _parse_numeric_atom(match.group(2)) + third = match.group(3) + if third is None: + start, stop = first, second + step = 1.0 if stop >= start else -1.0 + else: + start, step, stop = first, second, _parse_numeric_atom(third) + if step == 0 or not np.all(np.isfinite([start, step, stop])): + raise ValueError("Invalid colon range") + if (stop - start) * step < 0: + return [] + count = int(np.floor((stop - start) / step + 1e-9)) + 1 + values = [start + index * step for index in range(max(count, 0))] + if values and np.isclose(values[-1], stop, rtol=0.0, atol=max(abs(step), 1.0) * 1e-12): + values[-1] = stop + return [dtype(value) for value in values] + + +def _parse_numeric_atom(value: Any) -> float: + text = str(value).strip() + if text.lower() == "nan": + return float("nan") + if text.lower() == "inf": + return float("inf") + if text.lower() == "-inf": + return float("-inf") + return float(text) diff --git a/src/eegprep/functions/popfunc/_eegplot_rejection.py b/src/eegprep/functions/popfunc/_eegplot_rejection.py index a4a9e129..d3178cc3 100644 --- a/src/eegprep/functions/popfunc/_eegplot_rejection.py +++ b/src/eegprep/functions/popfunc/_eegplot_rejection.py @@ -15,17 +15,21 @@ rejection_data, update_reject_fields, ) -from eegprep.functions.popfunc.pop_eegplot import ( - DEFAULT_REJECTION_COLORS, - MANUAL_REJECTION_COLOR, - REJECTION_FAMILIES, - pad_rejection_rows, - rejection_row_count, -) from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch from eegprep.functions.sigprocfunc.eegplot import eegplot, eegplot2trial, trial2eegplot +MANUAL_REJECTION_COLOR = (1.0, 0.9, 0.9) +DEFAULT_REJECTION_COLORS = { + "manual": (1.0000, 1.0000, 0.7830), + "thresh": (0.8487, 1.0000, 0.5008), + "const": (0.6940, 1.0000, 0.7008), + "jp": (1.0000, 0.6991, 0.7537), + "kurt": (0.6880, 0.7042, 1.0000), + "freq": (0.9596, 0.7193, 1.0000), +} 
+REJECTION_FAMILIES = ("manual", "thresh", "const", "jp", "kurt", "freq") + # Autorej shares manual-color browser marks and is not superposed as a separate EEGLAB family. _AUTO_REJECTION_COLOR = MANUAL_REJECTION_COLOR @@ -100,6 +104,49 @@ def run_epoched_rejection( ) +def run_epoched_mark_rejection( + EEG: dict[str, Any], + icacomp: int | bool, + elecrange: Any, + superpose: int | bool, + reject: int | bool, + *, + marks_fn: Callable[[dict[str, Any], np.ndarray, list[int]], tuple[np.ndarray, np.ndarray, Any]], + kind: str, + error_message: str, + command_fn: Callable[[list[int], Any], str], + display: bool = False, + command_callback: Any | None = None, + show: bool = True, +) -> tuple[dict[str, Any], list[int], str, Any]: + """Run an epoched rejection method that already computes trial/row marks.""" + out = copy_eeg(EEG) + data, row_count = rejection_data(out, icacomp) + if int(out.get("trials", data.shape[2]) or data.shape[2]) <= 1: + raise ValueError(error_message) + elecrange = one_based_indices(elecrange, limit=row_count, default_all=True) + marks, marks_e, payload = marks_fn(out, data, elecrange) + update_reject_fields(out, icacomp=icacomp, kind=kind, reject=marks, reject_e=marks_e) + rejected = (np.flatnonzero(marks) + 1).tolist() + command = command_fn(elecrange, payload) + if display: + open_epoched_rejection_browser( + out, + data=data, + icacomp=icacomp, + elecrange=elecrange, + kind=kind, + superpose=superpose, + reject=reject, + command=command, + command_callback=command_callback, + show=show, + ) + elif int(bool(reject)) and rejected: + out = pop_rejepoch(out, rejected, 0) + return out, rejected, command, payload + + def vistype_from_gui(value: Any) -> int: """Map an EEGLAB visualization-mode popup value to a vistype flag.""" if isinstance(value, str): @@ -290,6 +337,58 @@ def ensure_rejection_defaults(EEG: dict[str, Any]) -> None: reject.setdefault(f"rej{family}col", np.asarray(color, dtype=float)) +def pad_rejection_rows(values: Any, row_count: 
int, trials: int) -> np.ndarray: + """Zero-pad/crop a row-mask array to ``(row_count, trials)``.""" + out = np.zeros((row_count, trials), dtype=bool) + arr = np.asarray(values, dtype=bool) + if arr.ndim == 1 and arr.size: + arr = arr.reshape(1, -1) + if arr.ndim == 2: + rows = min(row_count, arr.shape[0]) + cols = min(trials, arr.shape[1]) + out[:rows, :cols] = arr[:rows, :cols] + return out + + +def rejection_row_count(EEG: dict[str, Any], icacomp: int | bool) -> int: + """Return the number of channel or component rows for rejection marks.""" + if int(bool(icacomp)): + return int(EEG.get("nbchan", np.asarray(EEG.get("data")).shape[0]) or 0) + weights = np.asarray(EEG.get("icaweights", [])) + return int(weights.shape[0]) if weights.ndim == 2 else 0 + + +def manual_rejection_color(EEG: dict[str, Any]) -> tuple[float, float, float]: + """Return the EEG manual rejection color.""" + return rejection_family_color(EEG.get("reject") or {}, "manual", MANUAL_REJECTION_COLOR) + + +def displayed_rejection_families(reject: dict[str, Any]) -> tuple[str, ...]: + """Return the EEGLAB rejection families currently displayed in a browser.""" + disprej = reject.get("disprej") + if disprej is not None and np.asarray(disprej, dtype=object).size: + return tuple(str(item) for item in np.asarray(disprej, dtype=object).ravel() if str(item) in REJECTION_FAMILIES) + return tuple(family for family in REJECTION_FAMILIES if has_rejection_family(reject, family)) + + +def has_rejection_family(reject: dict[str, Any], family: str) -> bool: + """Return whether data or component marks exist for a rejection family.""" + return any(np.asarray(reject.get(field, [])).size for field in (f"rej{family}", f"icarej{family}")) + + +def rejection_family_color( + reject: dict[str, Any], + family: str, + default: tuple[float, float, float] | None = None, +) -> tuple[float, float, float]: + """Return a normalized RGB color for a rejection family.""" + fallback = default if default is not None else 
DEFAULT_REJECTION_COLORS.get(family, MANUAL_REJECTION_COLOR) + values = np.asarray(reject.get(f"rej{family}col", fallback), dtype=float).ravel() + if values.size < 3: + return fallback + return tuple(float(item) for item in values[:3]) + + def _apply_superposed_family_rows( EEG: dict[str, Any], rows: np.ndarray, @@ -302,7 +401,7 @@ def _apply_superposed_family_rows( pnts: int, ) -> None: reject = EEG.setdefault("reject", {}) - families = set(_displayed_families(reject)) + families = set(displayed_rejection_families(reject)) families.add(_family_from_kind(kind)) for family in sorted(families): family_kind = f"rej{family}" @@ -390,14 +489,7 @@ def _trial_marks(value: Any, trials: int) -> np.ndarray: def _displayed_families(reject: dict[str, Any]) -> tuple[str, ...]: - disprej = reject.get("disprej") - if disprej is not None and np.asarray(disprej, dtype=object).size: - return tuple(str(item) for item in np.asarray(disprej, dtype=object).ravel() if str(item) in REJECTION_FAMILIES) - return tuple(family for family in REJECTION_FAMILIES if _has_family_marks(reject, family)) - - -def _has_family_marks(reject: dict[str, Any], family: str) -> bool: - return any(np.asarray(reject.get(field, [])).size for field in (f"rej{family}", f"icarej{family}")) + return displayed_rejection_families(reject) def _kind_color(EEG: dict[str, Any], kind: str) -> tuple[float, float, float]: @@ -409,11 +501,7 @@ def _kind_color(EEG: dict[str, Any], kind: str) -> tuple[float, float, float]: def _family_color(reject: dict[str, Any], family: str) -> tuple[float, float, float]: - default = DEFAULT_REJECTION_COLORS.get(family, MANUAL_REJECTION_COLOR) - values = np.asarray(reject.get(f"rej{family}col", default), dtype=float).ravel() - if values.size < 3: - return default - return tuple(float(item) for item in values[:3]) + return rejection_family_color(reject, family) def _family_from_kind(kind: str) -> str: diff --git a/src/eegprep/functions/popfunc/_event_utils.py 
b/src/eegprep/functions/popfunc/_event_utils.py index e70b1d85..4c5aec1d 100644 --- a/src/eegprep/functions/popfunc/_event_utils.py +++ b/src/eegprep/functions/popfunc/_event_utils.py @@ -7,7 +7,8 @@ import numpy as np -from eegprep.functions.popfunc._pop_utils import is_empty_value as _is_empty +from eegprep.functions.miscfunc.event_utils import is_boundary_event as _is_boundary +from eegprep.functions.miscfunc.value_parsing import is_empty_value as _is_empty def events_as_list(events: Any) -> list[dict[str, Any]]: @@ -92,10 +93,7 @@ def sort_events(events: list[dict[str, Any]]) -> list[dict[str, Any]]: def is_boundary_event(event: dict[str, Any]) -> bool: """Return true for EEGLAB boundary events.""" - event_type = event.get("type", "") - if isinstance(event_type, (int, float, np.integer, np.floating)) and float(event_type) == -1: - return True - return str(event_type).lower() == "boundary" + return _is_boundary(event) def event_value_for_history(value: Any) -> Any: diff --git a/src/eegprep/functions/popfunc/_ica_utils.py b/src/eegprep/functions/popfunc/_ica_utils.py index f0738664..86c7305d 100644 --- a/src/eegprep/functions/popfunc/_ica_utils.py +++ b/src/eegprep/functions/popfunc/_ica_utils.py @@ -1,5 +1,8 @@ import numpy as np +from eegprep.functions.miscfunc.misc import finite_matmul, finite_pinv +from eegprep.functions.miscfunc.pinv import pinv + def flatten_ica_data(data): """Flatten channel-major EEG data using EEGLAB/MATLAB epoch ordering.""" @@ -20,6 +23,8 @@ def finalize_ica_fields(EEG, *, sortcomps='off', posact='off'): ``EEG['icawinv']`` and returns ``EEG``. Shared by the runica, AMICA, and Picard backends so the post-decomposition behavior stays identical. 
""" + EEG['icawinv'] = finite_pinv(finite_matmul(EEG['icaweights'], EEG['icasphere']), solver=pinv) + # Optionally sort components by mean descending activation variance if sortcomps in ('on', True): # Flatten icaact to 2D for variance computation @@ -52,4 +57,5 @@ def finalize_ica_fields(EEG, *, sortcomps='off', posact='off'): EEG['icawinv'][:, r] = -EEG['icawinv'][:, r] EEG['icaweights'][r, :] = -EEG['icaweights'][r, :] + EEG['icawinv'] = finite_pinv(finite_matmul(EEG['icaweights'], EEG['icasphere']), solver=pinv) return EEG diff --git a/src/eegprep/functions/popfunc/_pop_utils.py b/src/eegprep/functions/popfunc/_pop_utils.py index a30702d1..8133f79f 100644 --- a/src/eegprep/functions/popfunc/_pop_utils.py +++ b/src/eegprep/functions/popfunc/_pop_utils.py @@ -2,107 +2,18 @@ from __future__ import annotations -from collections.abc import Iterable -import re from pathlib import PurePath from typing import Any, Callable import numpy as np - -_TOKEN_PATTERN = re.compile(r"'([^']*)'|\"([^\"]*)\"|([^,\s]+)") -_RANGE_TOKEN = re.compile(r"^([^:]+):([^:]+)(?::([^:]+))?$") - - -def parse_key_value_args( - args: tuple[Any, ...], - kwargs: dict[str, Any] | None = None, - *, - lowercase_keys: bool = True, - lowercase_kwargs: bool = False, -) -> dict[str, Any]: - """Parse EEGLAB-style key/value positional arguments.""" - if len(args) % 2: - raise ValueError("Key/value arguments must be in pairs") - options: dict[str, Any] = {} - for key, value in (kwargs or {}).items(): - parsed_key = str(key).lower() if lowercase_kwargs else key - options[parsed_key] = value - for index in range(0, len(args), 2): - key = args[index] - if isinstance(key, bytes): - key = key.decode("utf-8") - if not isinstance(key, str): - raise ValueError("Keys must be strings") - parsed_key = key.lower() if lowercase_keys else key - options[parsed_key] = args[index + 1] - return options - - -def parse_text_tokens(text: Any, *, parse_ints: bool = False) -> list[Any]: - """Parse MATLAB text/cell-list token 
strings used by pop dialogs.""" - tokens = _TOKEN_PATTERN.findall(str(text).strip().strip("{}")) - values = [next(part for part in token if part) for token in tokens] - if not parse_ints: - return values - parsed = [] - for value in values: - try: - parsed.append(int(value)) - except ValueError: - parsed.append(value) - return parsed - - -def parse_numeric_sequence(value: Any, *, dtype: type = float) -> list[Any]: - """Parse EEGLAB-style numeric vectors, including ``start:stop`` ranges.""" - if value is None: - return [] - if isinstance(value, np.ndarray): - return [dtype(item) for item in value.ravel().tolist()] - if isinstance(value, (int, float, np.integer, np.floating)): - return [dtype(value)] - if isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)): - values: list[Any] = [] - for item in value: - if isinstance(item, str): - values.extend(parse_numeric_sequence(item, dtype=dtype)) - else: - values.append(dtype(item)) - return values - text = str(value).strip().strip("[]") - if not text: - return [] - values = [] - for token in re.split(r"[\s,]+", text): - if not token: - continue - match = _RANGE_TOKEN.match(token) - if match: - values.extend(_parse_range_token(match, dtype=dtype)) - else: - values.append(dtype(_parse_numeric_atom(token))) - return values - - -def is_empty_value(value: Any) -> bool: - """Return whether a GUI/history value means empty in EEGLAB pop dialogs.""" - if value is None: - return True - if isinstance(value, str): - return value.strip().strip("[]{}") == "" - if isinstance(value, np.ndarray): - return value.size == 0 - return isinstance(value, (list, tuple, set, dict)) and len(value) == 0 - - -def is_on(value: Any) -> bool: - """Normalize EEGLAB-style on/off values.""" - if isinstance(value, str): - return value.strip().lower() in {"1", "on", "true", "yes"} - if isinstance(value, np.ndarray): - return bool(value.size and np.asarray(value).ravel()[0]) - return bool(value) +from 
eegprep.functions.miscfunc.value_parsing import ( + is_empty_value as is_empty_value, + is_on as is_on, + parse_key_value_args as parse_key_value_args, + parse_numeric_sequence as parse_numeric_sequence, + parse_text_tokens as parse_text_tokens, +) def format_history_value( @@ -202,34 +113,3 @@ def _format_history_number(value: Any, formatter: Callable[[Any], str] | None) - if value.is_integer(): return str(int(value)) return str(value) - - -def _parse_range_token(match: re.Match[str], *, dtype: type) -> list[Any]: - first = _parse_numeric_atom(match.group(1)) - second = _parse_numeric_atom(match.group(2)) - third = match.group(3) - if third is None: - start, stop = first, second - step = 1.0 if stop >= start else -1.0 - else: - start, step, stop = first, second, _parse_numeric_atom(third) - if step == 0 or not np.all(np.isfinite([start, step, stop])): - raise ValueError("Invalid colon range") - if (stop - start) * step < 0: - return [] - count = int(np.floor((stop - start) / step + 1e-9)) + 1 - values = [start + index * step for index in range(max(count, 0))] - if values and np.isclose(values[-1], stop, rtol=0.0, atol=max(abs(step), 1.0) * 1e-12): - values[-1] = stop - return [dtype(value) for value in values] - - -def _parse_numeric_atom(value: Any) -> float: - text = str(value).strip() - if text.lower() == "nan": - return float("nan") - if text.lower() == "inf": - return float("inf") - if text.lower() == "-inf": - return float("-inf") - return float(text) diff --git a/src/eegprep/functions/popfunc/eeg_amica.py b/src/eegprep/functions/popfunc/eeg_amica.py index b4ab5399..dfc7d626 100644 --- a/src/eegprep/functions/popfunc/eeg_amica.py +++ b/src/eegprep/functions/popfunc/eeg_amica.py @@ -140,4 +140,4 @@ def load_amica_model(EEG, mods, model_num=0): EEG['icaact'] = (EEG['icaweights'] @ EEG['icasphere']) @ data EEG['icaact'] = reshape_ica_activations(EEG['icaact'], EEG['pnts'], EEG['trials']) - return EEG + return finalize_ica_fields(EEG) diff --git 
a/src/eegprep/functions/popfunc/eeg_eegrej.py b/src/eegprep/functions/popfunc/eeg_eegrej.py index 3e841a0f..1ea0c289 100644 --- a/src/eegprep/functions/popfunc/eeg_eegrej.py +++ b/src/eegprep/functions/popfunc/eeg_eegrej.py @@ -4,24 +4,14 @@ from typing import List, Dict, Optional, Tuple import numpy as np from copy import deepcopy +from eegprep.functions.miscfunc.event_utils import boundary_event_indices +from eegprep.functions.miscfunc.event_utils import is_boundary_event as _is_boundary_event from ..miscfunc.misc import round_mat logger = logging.getLogger(__name__) -def _is_boundary_event(event: Dict) -> bool: - t = event.get("type") - if isinstance(t, str): - return t.lower() == "boundary" - if isinstance(t, (int, float)): - try: - return int(t) == -99 - except Exception: - return False - return False - - def _eegrej( indata, regions, timelength, events: Optional[List[Dict]] = None ) -> Tuple[np.ndarray, float, List[Dict], np.ndarray]: @@ -373,14 +363,7 @@ def _combine_regions(regs): def _find_boundary_event_indices(events): - idx = [] - for i, ev in enumerate(events): - t = ev.get("type") - if isinstance(t, str) and t.lower() == "boundary": - idx.append(i) - elif isinstance(t, (int, float)) and int(t) == -99: - idx.append(i) - return np.array(idx, dtype=int) + return np.asarray(boundary_event_indices(events), dtype=int) def _insert_boundaries(events, old_pnts, regions): diff --git a/src/eegprep/functions/popfunc/eeg_findboundaries.py b/src/eegprep/functions/popfunc/eeg_findboundaries.py index 9bccc9d7..9b887a8d 100644 --- a/src/eegprep/functions/popfunc/eeg_findboundaries.py +++ b/src/eegprep/functions/popfunc/eeg_findboundaries.py @@ -1,6 +1,6 @@ """EEG boundary finding functions.""" -from eegprep.functions.adminfunc.eeg_options import EEG_OPTIONS +from eegprep.functions.miscfunc.event_utils import boundary_event_indices def eeg_findboundaries(*, EEG): @@ -22,41 +22,4 @@ def eeg_findboundaries(*, EEG): # In MATLAB: help eeg_findboundaries; return return [] - 
boundaries = [] - # isempty(EEG) - if EEG == {} or EEG == []: - return boundaries - - # Determine tmpevent - if isinstance(EEG, dict) and ('event' in EEG and 'setname' in EEG): - tmpevent = EEG['event'] - else: - tmpevent = EEG - - # If tmpevent lacks 'type' - if isinstance(tmpevent, list): - if len(tmpevent) == 0 or not isinstance(tmpevent[0], dict) or 'type' not in tmpevent[0]: - return boundaries - elif isinstance(tmpevent, dict): - if 'type' not in tmpevent: - return boundaries - # Normalize to list for unified handling - tmpevent = [tmpevent] - else: - return boundaries - - first_type = tmpevent[0].get('type') - if isinstance(first_type, str): - # boundaries = strmatch('boundary', { tmpevent.type }); - boundaries = [ - i - for i, ev in enumerate(tmpevent) - if isinstance(ev.get('type'), str) and ev.get('type', '').startswith('boundary') - ] - elif EEG_OPTIONS['option_boundary99']: - # boundaries = find([ tmpevent.type ] == -99); - boundaries = [i for i, ev in enumerate(tmpevent) if ev.get('type') == -99] - else: - boundaries = [] - - return boundaries + return boundary_event_indices(EEG) diff --git a/src/eegprep/functions/popfunc/pop_adjustevents.py b/src/eegprep/functions/popfunc/pop_adjustevents.py index 7a2d89bd..74c53d5e 100644 --- a/src/eegprep/functions/popfunc/pop_adjustevents.py +++ b/src/eegprep/functions/popfunc/pop_adjustevents.py @@ -10,8 +10,8 @@ import numpy as np from eegprep.functions.adminfunc.eeg_checkset import eeg_checkset -from eegprep.functions.adminfunc.eeg_options import EEG_OPTIONS from eegprep.functions.guifunc.spec import CallbackSpec, ControlSpec, DialogSpec +from eegprep.functions.miscfunc.event_utils import is_boundary_event from eegprep.functions.popfunc._pop_utils import ( format_history_value, parse_key_value_args, @@ -210,13 +210,7 @@ def _check_force(EEG: dict, events: list[dict[str, Any]], force: str) -> None: def _has_boundary_event(events: list[dict[str, Any]]) -> bool: - for event in events: - event_type = 
event.get("type") - if isinstance(event_type, str) and event_type.startswith("boundary"): - return True - if EEG_OPTIONS.get("option_boundary99") and event_type == -99: - return True - return False + return any(is_boundary_event(event) for event in events) def _unique_event_types(events: list[dict[str, Any]]) -> list[Any]: diff --git a/src/eegprep/functions/popfunc/pop_eegplot.py b/src/eegprep/functions/popfunc/pop_eegplot.py index 5ebedd52..1335de37 100644 --- a/src/eegprep/functions/popfunc/pop_eegplot.py +++ b/src/eegprep/functions/popfunc/pop_eegplot.py @@ -7,11 +7,19 @@ from eegprep.functions.popfunc._plot_utils import history_command import numpy as np +from eegprep.functions.popfunc._eegplot_rejection import ( + DEFAULT_REJECTION_COLORS, + MANUAL_REJECTION_COLOR, + displayed_rejection_families, + manual_rejection_color, + pad_rejection_rows, + rejection_family_color, + rejection_row_count, +) from eegprep.functions.popfunc._rejection import copy_eeg from eegprep.functions.popfunc.eeg_eegrej import eeg_eegrej from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch from eegprep.functions.sigprocfunc.eegplot import ( - DEFAULT_WINREJ_COLOR, eegplot, eegplot2event, eegplot2trial, @@ -21,16 +29,6 @@ ) -MANUAL_REJECTION_COLOR = (1.0, 0.9, 0.9) -DEFAULT_REJECTION_COLORS = { - "manual": (1.0000, 1.0000, 0.7830), - "thresh": (0.8487, 1.0000, 0.5008), - "const": (0.6940, 1.0000, 0.7008), - "jp": (1.0000, 0.6991, 0.7537), - "kurt": (0.6880, 0.7042, 1.0000), - "freq": (0.9596, 0.7193, 1.0000), -} -REJECTION_FAMILIES = ("manual", "thresh", "const", "jp", "kurt", "freq") CONTINUOUS_MANUAL_WINREJ = "rejmanualwinrej" CONTINUOUS_ICA_MANUAL_WINREJ = "icarejmanualwinrej" @@ -65,7 +63,7 @@ def pop_eegplot( ) options.setdefault("events", EEG.get("event", [])) accept_callback = options.pop("command_callback", None) - options.setdefault("wincolor", _manual_color(EEG)) + options.setdefault("wincolor", manual_rejection_color(EEG)) options.setdefault("butlabel", "Reject" if 
int(bool(reject)) else "Update Marks") trials = int(EEG.get("trials", 1) or 1) if trials > 1: @@ -149,7 +147,7 @@ def apply_eegplot_rejections( pnts = int(out.get("pnts", np.asarray(out.get("data")).shape[1])) if rows.size: - trial_marks, row_marks = eegplot2trial(rows, pnts, trials, _manual_color(out), None) + trial_marks, row_marks = eegplot2trial(rows, pnts, trials, manual_rejection_color(out), None) store_superpose = superpose else: trial_marks = np.zeros(trials, dtype=bool) @@ -168,16 +166,16 @@ def _initial_epoch_winrej(EEG: dict[str, Any], icacomp: int, superpose: int) -> row_count = rejection_row_count(EEG, icacomp) manual, manual_e = _reject_arrays(reject, "rejmanual", trials, row_count, icacomp=icacomp) if int(superpose) == 0: - return trial2eegplot(manual, manual_e, pnts, _manual_color(EEG)) + return trial2eegplot(manual, manual_e, pnts, manual_rejection_color(EEG)) if int(superpose) == 2: return _superposed_epoch_winrej(EEG, reject, icacomp, trials, row_count, pnts, manual, manual_e) rows = [] - old_color = tuple(min(component + 0.15, 1.0) for component in _manual_color(EEG)) + old_color = tuple(min(component + 0.15, 1.0) for component in manual_rejection_color(EEG)) old, old_e = _reject_arrays(reject, "rejglobal", trials, row_count, icacomp=1) old_rows = trial2eegplot(old, old_e, pnts, old_color) if old_rows.size: rows.append(old_rows) - manual_rows = trial2eegplot(manual, manual_e, pnts, _manual_color(EEG)) + manual_rows = trial2eegplot(manual, manual_e, pnts, manual_rejection_color(EEG)) if manual_rows.size: rows.append(manual_rows) return np.vstack(rows) if rows else np.zeros((0, 5 + row_count), dtype=float) @@ -207,9 +205,9 @@ def _superposed_epoch_winrej( manual_e: np.ndarray, ) -> np.ndarray: rows = [] - manual_color = _manual_color(EEG) - for family in _displayed_rejection_families(reject): - color = _reject_color(reject, family, DEFAULT_REJECTION_COLORS.get(family, manual_color)) + manual_color = manual_rejection_color(EEG) + for family in 
displayed_rejection_families(reject): + color = rejection_family_color(reject, family, DEFAULT_REJECTION_COLORS.get(family, manual_color)) if tuple(color) == tuple(manual_color): continue marks, marks_e = _reject_arrays(reject, f"rej{family}", trials, row_count, icacomp=icacomp) @@ -296,27 +294,6 @@ def _reject_arrays( return out, row_marks -def pad_rejection_rows(values: np.ndarray, row_count: int, trials: int) -> np.ndarray: - """Zero-pad/crop a row-mask array to ``(row_count, trials)``.""" - out = np.zeros((row_count, trials), dtype=bool) - arr = np.asarray(values, dtype=bool) - if arr.ndim == 1 and arr.size: - arr = arr.reshape(1, -1) - if arr.ndim == 2: - rows = min(row_count, arr.shape[0]) - cols = min(trials, arr.shape[1]) - out[:rows, :cols] = arr[:rows, :cols] - return out - - -def rejection_row_count(EEG: dict[str, Any], icacomp: int) -> int: - """Return the number of channel or component rows for rejection marks.""" - if int(bool(icacomp)): - return int(EEG.get("nbchan", np.asarray(EEG.get("data")).shape[0]) or 0) - weights = np.asarray(EEG.get("icaweights", [])) - return int(weights.shape[0]) if weights.ndim == 2 else 0 - - def _require_ica(EEG: dict[str, Any]) -> None: if _nonempty_array(EEG.get("icaact")): return @@ -338,35 +315,4 @@ def _dataset_pnts(EEG: dict[str, Any]) -> int: return 0 -def _displayed_rejection_families(reject: dict[str, Any]) -> tuple[str, ...]: - disprej = reject.get("disprej") - if disprej is None or np.asarray(disprej, dtype=object).size == 0: - return tuple(family for family in REJECTION_FAMILIES if _has_rejection_family(reject, family)) - values = np.asarray(disprej, dtype=object).ravel().tolist() - return tuple(str(value) for value in values if str(value) in REJECTION_FAMILIES) - - -def _has_rejection_family(reject: dict[str, Any], family: str) -> bool: - data_marks = reject.get(f"rej{family}") - component_marks = reject.get(f"icarej{family}") - return (data_marks is not None and np.asarray(data_marks).size > 0) or ( - 
component_marks is not None and np.asarray(component_marks).size > 0 - ) - - -def _manual_color(EEG: dict[str, Any]) -> tuple[float, float, float]: - reject = EEG.get("reject") or {} - return _reject_color(reject, "manual", MANUAL_REJECTION_COLOR) - - -def _reject_color( - reject: dict[str, Any], family: str, default: tuple[float, float, float] -) -> tuple[float, float, float]: - color = reject.get(f"rej{family}col", default) - values = np.asarray(color if color is not None else DEFAULT_WINREJ_COLOR, dtype=float).ravel() - if values.size < 3: - return default - return tuple(float(item) for item in values[:3]) - - __all__ = ["apply_eegplot_rejections", "eegplot_accept_creates_dataset", "pop_eegplot"] diff --git a/src/eegprep/functions/popfunc/pop_eegthresh.py b/src/eegprep/functions/popfunc/pop_eegthresh.py index 93a8d640..a67d42a4 100644 --- a/src/eegprep/functions/popfunc/pop_eegthresh.py +++ b/src/eegprep/functions/popfunc/pop_eegthresh.py @@ -8,17 +8,9 @@ from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._eegplot_rejection import open_epoched_rejection_browser +from eegprep.functions.popfunc._eegplot_rejection import run_epoched_mark_rejection from eegprep.functions.popfunc._pop_utils import format_history_value -from eegprep.functions.popfunc._rejection import ( - copy_eeg, - eegthresh_marks, - one_based_indices, - parse_numeric_sequence, - rejection_data, - update_reject_fields, -) -from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch +from eegprep.functions.popfunc._rejection import eegthresh_marks, parse_numeric_sequence def pop_eegthresh( @@ -166,45 +158,50 @@ def _apply_one( command_callback: Any | None = None, show: bool = True, ) -> tuple[dict[str, Any], list[int], str]: - out = copy_eeg(EEG) - data, row_count = rejection_data(out, icacomp) - trials = int(out.get("trials", data.shape[2]) or data.shape[2]) - if trials <= 1: - raise 
ValueError("pop_eegthresh requires epoched data") - elecrange = one_based_indices(elecrange, limit=row_count, default_all=True) - negthresh = [-10.0] if negthresh is None else negthresh - posthresh = [10.0] if posthresh is None else posthresh - starttime = [float(out.get("xmin", 0.0))] if starttime is None else starttime - endtime = [float(out.get("xmax", 0.0))] if endtime is None else endtime - marks, marks_e = eegthresh_marks( - data, + def _marks(out: dict[str, Any], data: np.ndarray, normalized_elecrange: list[int]): + normalized = { + "negthresh": [-10.0] if negthresh is None else negthresh, + "posthresh": [10.0] if posthresh is None else posthresh, + "starttime": [float(out.get("xmin", 0.0))] if starttime is None else starttime, + "endtime": [float(out.get("xmax", 0.0))] if endtime is None else endtime, + } + marks, marks_e = eegthresh_marks( + data, + normalized_elecrange, + normalized["negthresh"], + normalized["posthresh"], + (float(out.get("xmin", 0.0)), float(out.get("xmax", 0.0))), + normalized["starttime"], + normalized["endtime"], + ) + return marks, marks_e, normalized + + def _command(normalized_elecrange: list[int], normalized: dict[str, Any]) -> str: + return _history_command( + icacomp, + normalized_elecrange, + normalized["negthresh"], + normalized["posthresh"], + normalized["starttime"], + normalized["endtime"], + superpose, + int(bool(reject)), + ) + + out, rejected, command, _normalized = run_epoched_mark_rejection( + EEG, + icacomp, elecrange, - negthresh, - posthresh, - (float(out.get("xmin", 0.0)), float(out.get("xmax", 0.0))), - starttime, - endtime, - ) - update_reject_fields(out, icacomp=icacomp, kind="rejthresh", reject=marks, reject_e=marks_e) - rejected = (np.flatnonzero(marks) + 1).tolist() - command = _history_command( - icacomp, elecrange, negthresh, posthresh, starttime, endtime, superpose, int(bool(reject)) + superpose, + reject, + marks_fn=_marks, + kind="rejthresh", + error_message="pop_eegthresh requires epoched data", + 
command_fn=_command, + display=display, + command_callback=command_callback, + show=show, ) - if display: - open_epoched_rejection_browser( - out, - data=data, - icacomp=icacomp, - elecrange=elecrange, - kind="rejthresh", - superpose=superpose, - reject=reject, - command=command, - command_callback=command_callback, - show=show, - ) - elif int(bool(reject)) and rejected: - out = pop_rejepoch(out, rejected, 0) return out, rejected, command diff --git a/src/eegprep/functions/popfunc/pop_load_frombids.py b/src/eegprep/functions/popfunc/pop_load_frombids.py index ebf0b27c..6373d132 100644 --- a/src/eegprep/functions/popfunc/pop_load_frombids.py +++ b/src/eegprep/functions/popfunc/pop_load_frombids.py @@ -2,7 +2,6 @@ import os import copy -from importlib.resources import files from typing import Dict, Any, Tuple, Union, Optional import logging import warnings @@ -12,11 +11,12 @@ chanlocs_to_coords, clear_chanloc, coords_ALS_to_angular, - coords_any_to_RAS, coords_RAS_to_ALS, coords_to_mm, ) -from eegprep.functions.miscfunc.misc import ExceptionUnlessDebug, ToolError, round_mat +from eegprep.plugins.EEG_BIDS.montage import apply_montage_inference +from eegprep.plugins.EEG_BIDS.raw import load_raw_eeg_file +from eegprep.functions.miscfunc.misc import ExceptionUnlessDebug, round_mat import numpy as np @@ -30,13 +30,6 @@ event_timing_columns = ['onset', 'duration', 'sample'] -# remove matching leading/trailing quotes in pairs (repeat if nested) -def _strip_matching_quotes(name: str) -> str: - while len(name) >= 2 and name[0] == name[-1] and name[0] in ("'", '"'): - name = name[1:-1] - return name - - def pop_load_frombids( filename: str, *, @@ -124,7 +117,7 @@ def error(msg: str): logger.error(msg) report['Errors'].append(msg) - path, ext = os.path.splitext(filename) + _path, ext = os.path.splitext(filename) ext = ext.lower() root = root_for_fpath(filename) @@ -132,320 +125,14 @@ def error(msg: str): if verbose: logger.info(f"Loading EEG data from {filename}...") basename = 
os.path.basename(filename) - if ext == '.set': - from .pop_loadset import pop_loadset - - EEG = pop_loadset(filename) - EEG['data'] = EEG['data'].astype(dtype) - report['ImporterUsed'] = 'pop_loadset' - Fs = EEG['srate'] - times_sec = EEG['times'] / 1000.0 - elif ext in ['.edf', '.bdf', '.vhdr']: - from neo import NeoReadWriteError - - if ext == '.vhdr': - from neo.rawio.brainvisionrawio import BrainVisionRawIO as NeoIO - - report['ImporterUsed'] = 'neo.rawio.brainvisionrawio.BrainVisionRawIO' - elif ext in ['.edf', '.bdf']: - from neo.rawio.edfrawio import EDFRawIO as NeoIO - - report['ImporterUsed'] = 'neo.rawio.edfrawio.EDFRawIO' - else: - # if you're getting this, there's an elif statement missing here for one of - # the formats allowed above - raise ValueError(f"Unexpected file format: {ext}. Please add support for this format if needed.") - # load from NEO - io = NeoIO(filename) - try: - io.parse_header() - except NeoReadWriteError as e: - classname = io.__class__.__name__ - raise ToolError( - f"Encountered error with NEO {classname} importer on {filename!r}: {e}. Skipping file." - ) from e - if (nStreams := io.signal_streams_count()) > 1: - warning( - f"The raw data file {filename} appears to contain more than one stream; using only the first stream." - ) - elif not nStreams: - raise ValueError(f"The raw data file {filename} does not contain any data.") - if (nBlocks := io.block_count()) > 1: - warning( - f"The raw data file {filename} appears to contain " - f"more than one recording; this is not meaningful " - f"in a BIDS context; using only the first block." - ) - elif not nBlocks: - raise ValueError(f"The raw data file {filename} does not contain any data.") - if (nSegments := io.segment_count(0)) > 1: - raise NotImplementedError( - f"The raw data file {filename} appears to contain " - f"more than one segment; This importer currently " - f"only supports continuous EEG data." 
- ) - elif not nSegments: - raise ValueError(f"The raw data file {filename} does not contain any data.") - - nChannels = io.signal_channels_count(0) - nSamples = io.get_signal_size(0, 0, 0) - chnIdxs = list(range(nChannels)) - - report['NumStreams'] = nStreams - report['NumBlocks'] = nBlocks - report['NumSegments'] = nSegments - - if verbose: - logger.info(" retrieving EEG data from file...") - data_T = io.get_analogsignal_chunk( - block_index=0, seg_index=0, channel_indexes=chnIdxs, i_start=None, i_stop=None - ) - old_scale = np.std(data_T, axis=0) - data_T = io.rescale_signal_raw_to_float(data_T, dtype=dtype, channel_indexes=chnIdxs) - new_scale = np.std(data_T, axis=0) - scale_ratios = new_scale / old_scale - uq_ratios = np.unique(scale_ratios) - if len(uq_ratios) == 1: - report['ScaleApplied'] = uq_ratios.item() - else: - report['ScalesApplied'] = scale_ratios.tolist() - - # data time codes - Fs = io.get_signal_sampling_rate(0) - t0 = io.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) - report['RawStartTime'] = t0 - time_ofs = getattr(io, '_global_time', 0.0) # default to 0 if not set - report['StartTimeOffset'] = time_ofs - t0 += time_ofs - report['CombinedStartTime'] = t0 - times_sec = t0 + np.arange(0, nSamples, dtype=float) / Fs - - # construct the chanlocs data structure - chns = io.header['signal_channels'] - # get the units for all channels - try: - units = chns['units'].tolist() - except KeyError: - units = ['uV'] * nChannels - uq_unit = np.unique(units) - if len(uq_unit) == 1 and uq_unit[0] not in ('uV', 'microvolts'): - warning( - f"Your channel unit does not appear to be in microvolts (uV) " - f"but is documented instead as {uq_unit[0]}. EEG scale might be incorrect. 
" - ) - - labels = chns['name'].tolist() - - # other available per-channel fields from neo: - # - id - # - sampling_rate (assumed to be uniform across all channels) - # - dtype - # - gain (accounted for in rescaling) - # - offset (accounted for in rescaling) - # - stream_id - # - buffer_id - - # preinitialize data structure - chanlocs = np.asarray( - [ - { - 'labels': lab, - 'sph_radius': numeric_null, - 'sph_theta': numeric_null, - 'sph_phi': numeric_null, - 'theta': numeric_null, - 'radius': numeric_null, - 'X': numeric_null, - 'Y': numeric_null, - 'Z': numeric_null, - 'type': 'EEG', - 'ref': numeric_null, - # 'urchan': numeric_null --> not present if urchanlocs not populated - } - for lab in labels - ] - ) - - # try to read out channel coordinates from side-channel info, if any - if ext == '.vhdr': - if verbose: - logger.info(" parsing VHDR-specific channel locations...") - try: - annots = io.raw_annotations['blocks'][0]['segments'][0]['signals'][0]['__array_annotations__'] - sph_radius, theta, phi = annots['coordinates_0'], annots['coordinates_1'], annots['coordinates_2'] - valid = (sph_radius != 0) | (theta != 0) | (phi != 0) - sph_theta = phi - 90 * np.sign(theta) - sph_phi = -np.abs(theta) + 90 - except KeyError: - warning(f"Channel coordinates not found in {filename}. Using default values for channel locations.") - valid = np.zeros(nChannels, dtype=bool) - elif ext in ['.edf', '.bdf']: - # EDF/BDF files do not have channel coordinates, so we use default values - valid = np.zeros(nChannels, dtype=bool) - else: - raise ValueError( - f"Unsupported file format for channel coordinates extraction: {ext}. " - f"Supported formats are .edf, .bdf, .vhdr." 
- ) - - if np.any(valid): - if verbose: - logger.info(" applying channel locations from EEG file...") - # set the channel locations to the extent that we have them - for loc, val, sph_r, sph_p, sph_t in zip(chanlocs, valid, sph_radius, sph_phi, sph_theta): - if val: - # write coordinates in - loc['sph_radius'] = sph_r - loc['sph_theta'] = sph_t - loc['sph_phi'] = sph_p - # also derive topo coords (sph2topo) - az = sph_p - horiz = sph_t - angle = -horiz - radius = 0.5 - az / 180 - loc['theta'] = angle - loc['radius'] = radius - # and derive cartesian coordinates (sph2cart) - az = np.deg2rad(sph_t) - elev = np.deg2rad(sph_p) - z = sph_r * np.sin(elev) - x = sph_r * np.cos(elev) * np.cos(az) - y = sph_r * np.cos(elev) * np.sin(az) - loc['X'] = x - loc['Y'] = y - loc['Z'] = z - - # construct the events data structure - if (nEvtChns := io.event_channels_count()) > 0: - if verbose: - logger.info(" reading in event data from EEG file...") - ev_all_times = [] - ev_all_durs = [] - ev_all_channels = [] - ev_all_data = [] - # All channels containing events get collapsed into a single axis of instances. - # The instance 'label' contains the original channel name. - # The instance 'data' contains the original event marker. - for ev_ch_ix in range(nEvtChns): - ev_times, ev_durs, ev_labels = io.get_event_timestamps( - block_index=0, - seg_index=0, - event_channel_index=ev_ch_ix, - t_start=None, - t_stop=None, - # (no other args) - ) - ev_all_times.extend(io.rescale_event_timestamp(ev_times)) - if ev_durs is not None: - ev_all_durs.extend(ev_durs) - else: - ev_all_durs.extend([1] * len(ev_times)) - ev_all_channels.extend(np.repeat(io.header['event_channels'][ev_ch_ix]['name'], len(ev_times))) - ev_all_data.extend(ev_labels) - # apply heuristics to deduce the event type - if ext == '.vhdr': - # BrainVision has the event name in the data, but when that's empty, - # we use the channel name as the event type. 
- ev_types = ev_all_data - ev_codes = ev_all_channels - elif ext in ['.edf', '.bdf']: - ev_types = [str(d) for d in ev_all_data] - ev_codes = [str(chn) for chn in ev_all_channels] - else: - # if you get this you need to add support for this file format here - raise ValueError( - f"Unsupported file format for event extraction: {ext}. Supported formats are .edf, .bdf, .vhdr." - ) - ev_lats = np.searchsorted(times_sec, ev_all_times) # +1 for MATLAB format compatibility (1-based index) - ev_durs = np.array(ev_all_durs, dtype=float) - ev_urevts = np.arange(len(ev_all_times)) - events = np.array( - [ - { - 'duration': dur, - 'latency': lat, - 'type': typ or ('boundary' if code == 'New Segment' else ''), - 'code': code, - 'urevent': chn, - } - for dur, lat, typ, code, chn in zip(ev_durs, ev_lats, ev_types, ev_codes, ev_urevts) - ] - ) - else: - events = numeric_null - - # this isn't really encoded in Neo's data structure, nor does pop_loadbv() seem - # to read it out, even though .vhdr CAN have it annotated (either the [Comments] section - # of the channel infos in the .vhdr file, or separately in each channel under [Channel Infos] - reference = 'unknown' - - EEG = { - 'setname': '', - 'filename': basename, - 'filepath': os.path.dirname(filename), - # these will be set from BIDS - 'subject': '', - 'group': '', - 'condition': '', - 'session': numeric_null, - 'comments': '', - # raw data array - 'nbchan': nChannels, - 'trials': 1, # assuming single trial for raw EEG datain - 'pnts': nSamples, - 'srate': Fs, - 'xmin': times_sec[0], - 'xmax': times_sec[-1], - 'times': times_sec * 1000, # in ms - 'data': data_T.T, - # ICA data structures - 'icaact': numeric_null, - 'icawinv': numeric_null, - 'icasphere': numeric_null, - 'icaweights': numeric_null, - 'icachansind': numeric_null, - # channel info - 'chanlocs': chanlocs, - 'urchanlocs': numeric_null, - 'chaninfo': { - 'plotrad': numeric_null, - 'shrink': numeric_null, - 'nosedir': '+X', - 'nodatchans': numeric_null, - 
'icachansind': numeric_null, - }, - 'ref': reference, - # event data structures - 'event': events, - 'urevent': copy.deepcopy(events), - 'eventdescription': [], - # epoch info - 'epoch': numeric_null, - 'epochdescription': [], - # rejection info (note: could pre-populate) - 'reject': {}, - 'stats': {}, - # spectral data (not used) - 'specdata': numeric_null, - 'specicaact': numeric_null, - # spline fil - 'splinefile': '', - 'icasplinefile': '', - # DIPFIT info - 'dipfit': numeric_null, - # history info - 'history': '', - 'saved': 'justloaded', - # additional metadata - 'etc': {}, - 'run': numeric_null, - } - elif ext in ['.fdt', '.vmrk', '.eeg']: - raise ValueError( - f"pop_load_frombids should be called with the main data file, but was called on a sidecar file: {filename}." - ) - else: - raise ValueError(f"Unsupported file format: {ext}. Supported formats are .set, .edf, .bdf, .vhdr.") + EEG, Fs, times_sec, raw_report = load_raw_eeg_file( + filename, + dtype=dtype, + numeric_null=numeric_null, + warning=warning, + verbose=verbose, + ) + report.update(raw_report) report['EEGFileHadLocations'] = sum(chanloc_has_coords(ch) for ch in EEG['chanlocs']) report['ChanlocsFrom'] = os.path.relpath(filename, root) @@ -903,178 +590,14 @@ def error(msg: str): if infer_locations is None: infer_locations = not have_coords # only if no coordinates are present - if infer_locations: - from scipy.io.matlab import loadmat - # Portions of this code are Copyright (c) 2015-2025 Syntrogi Inc. dba Intheon; - # used under the terms of the BSD 2-Clause License. 
- - # set nosedir to +X (ALS) since that's the only coord system that we convert to here - EEG['chaninfo']['nosedir'] = '+X' - - # find best-matching montage file out of available options - # we're scoring by coverage of data channels first, and coverage in - # locfile second (the latter because we want to use the smallest locfile - # that covers the cap since sometimes there's one that has a superset - # of the names, but with different locations, e.g., 128ch vs 256ch) - datalabels = [cl['labels'].lower() for cl in EEG['chanlocs']] - - # remove channel prefixes if any - chanprefixes = ['brainvision rda_', 'rda_', 'eeg ', 'eeg-', 'eeg'] - for prefix in chanprefixes: - datalabels = [label.replace(prefix, '') for label in datalabels] - - # remove suffixes after minus sign (if reference is present in the channel label) - datalabels = [label.split('-')[0] for label in datalabels] - datalabels = [_strip_matching_quotes(label) for label in datalabels] - - opt_score, best_data, best_cap = (0, 0), None, '(not set)' - fractions = [] - caplabels = [] - - # Determine montage path and files to check. Resolve the packaged - # montages directory through importlib.resources so the lookup does not - # depend on this module's location on disk. - montage_path = str(files("eegprep").joinpath("resources").joinpath("montages")) - - if not os.path.isdir(montage_path): - raise RuntimeError( - f"Could not find montages directory at {montage_path}. This may indicate a corrupted installation." 
- ) - - if isinstance(infer_locations, str): - # Custom montage file specified - if os.path.isabs(infer_locations): - # Absolute path provided - override montage_path - montage_path = os.path.dirname(infer_locations) - filenames = [os.path.basename(infer_locations)] - else: - # Relative path - use standard montage directory - filenames = [infer_locations] - else: - # Use all available montage files - filenames = sorted(os.listdir(montage_path)) - - for filename in filenames: - # skip non-montage files - if not filename.endswith('.locs'): - continue - try: - data = loadmat(os.path.join(montage_path, filename), squeeze_me=True) - except Exception: - raise ValueError( - f"Failed to load montage file {filename}. " - f"Make sure it is a valid .locs file (MATLAB v7 .mat format)." - ) - caplabels = [label.lower() for label in data['labels']] - fraction_in_data = np.mean([n in caplabels for n in datalabels]) - fraction_in_locfile = np.mean([n in datalabels for n in caplabels]) - # bonus score for 10-20 preference - if {'c3', 'cz', 'fcz', 'c4'}.issubset(caplabels): - bonus1020 = 1 - else: - bonus1020 = 0 - score = (fraction_in_data, bonus1020, fraction_in_locfile) - if score > opt_score: - opt_score = score - best_data = data - best_cap = filename - fractions.append(fraction_in_data) - fractions = sorted(fractions, reverse=True) - best_fraction = opt_score[0] - - if best_data is None: - if isinstance(infer_locations, str): - raise RuntimeError( - f'The channel labels in your data do not match the specified montage file ({infer_locations}).' 
- ) - else: - raise RuntimeError('Channel labels do not match any known or specified montage.') - - # additional diagnostics - skip_locations = False - percent_found = int(100 * best_fraction) - if best_fraction < 0.25: - error( - "The given data has a very poor match to all " - "known montages (%s percent of channels found); " - "not assigning locations (got: %s)" % (percent_found, datalabels) - ) - skip_locations = True - elif best_fraction < 0.5: - if len(fractions) > 1 and best_fraction / 1.5 < fractions[1]: - warning( - "The given data has a poor match and multiple " - "montages are partially matching potentially " - "ambiguously (%s percent of channels found); " - "please double-check assigned locations." % percent_found - ) - else: - warning( - "The given data has a poor match to all known " - "montages (%s percent of channels found); please " - "double-check assigned locations." % percent_found - ) - elif best_fraction < 0.75 and len(fractions) > 1 and best_fraction / 1.5 < fractions[1]: - warning( - "The given data has a reasonable match to known " - "montages but multiple montages are potentially " - "matching (%s percent of channels found); " - "locations may be wrong." 
% percent_found - ) - elif best_fraction < 1.0: - warning( - "Not all channel locations could be matched to a " - "known montage; some channels may be non-EEG " - "channels ({} percent of channels found).".format(percent_found) - ) - - if not skip_locations: - report['ChanlocsFrom'] = os.path.basename(best_cap) - if '10-5' in best_cap: - labeling = '10-20' # normalize to 10-20 - else: - labeling, ext = os.path.splitext(os.path.basename(best_cap)) - EEG['etc']['labelscheme'] = labeling - - # transform coordinates from file into the EEGLAB coordinate system - # unit=millimeters, x=A (front), y=L (left), z=S (up) - unit = best_data['meta']['unit'][()] - x = best_data['meta']['x'][()] - y = best_data['meta']['y'][()] - z = best_data['meta']['z'][()] - coords = best_data['coordinates'] - coords = coords_to_mm(coords, unit) - coords = coords_any_to_RAS(coords, x, y, z) - coords = coords_RAS_to_ALS(coords) - sph_theta, sph_phi, sph_radius, polar_theta, polar_radius = coords_ALS_to_angular(coords) - - # cross-reference location indices from best montage - caplabels = [label.lower() for label in best_data['labels']] - for di, dl in enumerate(datalabels): - rec = EEG['chanlocs'][di] - for ci, cl in enumerate(caplabels): - if dl == cl: - xyz = coords[ci, :] - rec['X'] = xyz[0] - rec['Y'] = xyz[1] - rec['Z'] = xyz[2] - rec['sph_radius'] = sph_radius[ci] - rec['sph_theta'] = sph_theta[ci] - rec['sph_phi'] = sph_phi[ci] - rec['theta'] = polar_theta[ci] - rec['radius'] = polar_radius[ci] - break - else: - # otherwise clear the locs to invalid - clear_chanloc(rec, numeric_null) - else: - # unambiguously 10-20 locations (excl. 
for example C3/C4 or F3/F4 since these - # are also in the Biosemi montage and likely some others) - candidate_locs = {"fp1", "fp2", "fz", "t3", "cz", "t4", "t5", "p3", "pz", "p4", "t6", "o1", "o2"} - if any(ch['labels'].lower() in candidate_locs for ch in EEG['chanlocs']): - EEG['etc']['labelscheme'] = '10-20' - else: - EEG['etc']['labelscheme'] = 'unknown' + apply_montage_inference( + EEG, + infer_locations, + numeric_null=numeric_null, + report=report, + warning=warning, + error=error, + ) EEG = eeg_checkset(EEG) try: diff --git a/src/eegprep/functions/popfunc/pop_rejspec.py b/src/eegprep/functions/popfunc/pop_rejspec.py index 99171bdd..18c925cc 100644 --- a/src/eegprep/functions/popfunc/pop_rejspec.py +++ b/src/eegprep/functions/popfunc/pop_rejspec.py @@ -8,17 +8,12 @@ from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._eegplot_rejection import open_epoched_rejection_browser +from eegprep.functions.popfunc._eegplot_rejection import run_epoched_mark_rejection from eegprep.functions.popfunc._pop_utils import format_history_value, parse_key_value_args from eegprep.functions.popfunc._rejection import ( - copy_eeg, - one_based_indices, parse_numeric_sequence, - rejection_data, spectrum_marks, - update_reject_fields, ) -from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch def pop_rejspec( @@ -169,46 +164,44 @@ def _apply_one( command_callback: Any | None = None, show: bool = True, ) -> tuple[dict[str, Any], list[int], str]: - out = copy_eeg(EEG) - data, row_count = rejection_data(out, icacomp) - if int(out.get("trials", data.shape[2]) or data.shape[2]) <= 1: - raise ValueError("pop_rejspec requires epoched data") - elecrange = one_based_indices(options.get("elecrange"), limit=row_count, default_all=True) - threshold = options.get("threshold", [-30, 30]) - freqlimits = options.get("freqlimits", [15, 30]) - method = str(options.get("method", 
"multitaper")).lower() - marks, marks_e, spectra = spectrum_marks( - data, elecrange, float(out.get("srate", 1.0)), threshold, freqlimits, method - ) - if int(bool(icacomp)): - out["specdata"] = spectra - else: - out["specicaact"] = spectra - update_reject_fields(out, icacomp=icacomp, kind="rejfreq", reject=marks, reject_e=marks_e) - rejected = (np.flatnonzero(marks) + 1).tolist() - normalized_options = dict(options) - normalized_options["elecrange"] = elecrange - normalized_options["method"] = method - normalized_options.setdefault("threshold", threshold) - normalized_options.setdefault("freqlimits", freqlimits) - normalized_options.setdefault("eegplotplotallrej", 0) - normalized_options.setdefault("eegplotreject", 0) - command = _history_command(icacomp, normalized_options) - if display: - open_epoched_rejection_browser( - out, - data=data, - icacomp=icacomp, - elecrange=elecrange, - kind="rejfreq", - superpose=_int_option(normalized_options.get("eegplotplotallrej", 0)), - reject=int(bool(_int_option(normalized_options.get("eegplotreject", 0)))), - command=command, - command_callback=command_callback, - show=show, + def _marks(out: dict[str, Any], data: np.ndarray, elecrange: list[int]): + threshold = options.get("threshold", [-30, 30]) + freqlimits = options.get("freqlimits", [15, 30]) + method = str(options.get("method", "multitaper")).lower() + marks, marks_e, spectra = spectrum_marks( + data, elecrange, float(out.get("srate", 1.0)), threshold, freqlimits, method ) - elif int(bool(options.get("eegplotreject", 0))) and rejected: - out = pop_rejepoch(out, rejected, 0) + if int(bool(icacomp)): + out["specdata"] = spectra + else: + out["specicaact"] = spectra + normalized_options = dict(options) + normalized_options["elecrange"] = elecrange + normalized_options["method"] = method + normalized_options.setdefault("threshold", threshold) + normalized_options.setdefault("freqlimits", freqlimits) + normalized_options.setdefault("eegplotplotallrej", 0) + 
normalized_options.setdefault("eegplotreject", 0) + return marks, marks_e, normalized_options + + def _command(_elecrange: list[int], normalized_options: dict[str, Any]) -> str: + return _history_command(icacomp, normalized_options) + + normalized_reject = int(bool(_int_option(options.get("eegplotreject", 0)))) + out, rejected, command, _normalized_options = run_epoched_mark_rejection( + EEG, + icacomp, + options.get("elecrange"), + _int_option(options.get("eegplotplotallrej", 0)), + normalized_reject, + marks_fn=_marks, + kind="rejfreq", + error_message="pop_rejspec requires epoched data", + command_fn=_command, + display=display, + command_callback=command_callback, + show=show, + ) return out, rejected, command diff --git a/src/eegprep/functions/popfunc/pop_rejtrend.py b/src/eegprep/functions/popfunc/pop_rejtrend.py index 24042879..11b5643c 100644 --- a/src/eegprep/functions/popfunc/pop_rejtrend.py +++ b/src/eegprep/functions/popfunc/pop_rejtrend.py @@ -8,17 +8,12 @@ from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._eegplot_rejection import open_epoched_rejection_browser +from eegprep.functions.popfunc._eegplot_rejection import run_epoched_mark_rejection from eegprep.functions.popfunc._pop_utils import format_history_value from eegprep.functions.popfunc._rejection import ( - copy_eeg, - one_based_indices, parse_numeric_sequence, - rejection_data, trend_marks, - update_reject_fields, ) -from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch def pop_rejtrend( @@ -151,33 +146,46 @@ def _apply_one( command_callback: Any | None = None, show: bool = True, ) -> tuple[dict[str, Any], str]: - out = copy_eeg(EEG) - data, row_count = rejection_data(out, icacomp) - if int(out.get("trials", data.shape[2]) or data.shape[2]) <= 1: - raise ValueError("pop_rejtrend requires epoched data") - elecrange = one_based_indices(elecrange, limit=row_count, default_all=True) 
- winsize = int(parse_numeric_sequence(winsize if winsize is not None else [data.shape[1]], dtype=float)[0]) - minslope = float(parse_numeric_sequence(minslope, dtype=float)[0]) - minstd = float(parse_numeric_sequence(minstd, dtype=float)[0]) - marks, marks_e = trend_marks(data, elecrange, winsize, minslope, minstd) - update_reject_fields(out, icacomp=icacomp, kind="rejconst", reject=marks, reject_e=marks_e) - rejected = (np.flatnonzero(marks) + 1).tolist() - command = _history_command(icacomp, elecrange, winsize, minslope, minstd, superpose, reject) - if display: - open_epoched_rejection_browser( - out, - data=data, - icacomp=icacomp, - elecrange=elecrange, - kind="rejconst", - superpose=superpose, - reject=reject, - command=command, - command_callback=command_callback, - show=show, + def _marks(_out: dict[str, Any], data: np.ndarray, normalized_elecrange: list[int]): + normalized = { + "winsize": int(parse_numeric_sequence(winsize if winsize is not None else [data.shape[1]], dtype=float)[0]), + "minslope": float(parse_numeric_sequence(minslope, dtype=float)[0]), + "minstd": float(parse_numeric_sequence(minstd, dtype=float)[0]), + } + marks, marks_e = trend_marks( + data, + normalized_elecrange, + normalized["winsize"], + normalized["minslope"], + normalized["minstd"], ) - elif int(bool(reject)) and rejected: - out = pop_rejepoch(out, rejected, 0) + return marks, marks_e, normalized + + def _command(normalized_elecrange: list[int], normalized: dict[str, Any]) -> str: + return _history_command( + icacomp, + normalized_elecrange, + normalized["winsize"], + normalized["minslope"], + normalized["minstd"], + superpose, + reject, + ) + + out, _rejected, command, _normalized = run_epoched_mark_rejection( + EEG, + icacomp, + elecrange, + superpose, + reject, + marks_fn=_marks, + kind="rejconst", + error_message="pop_rejtrend requires epoched data", + command_fn=_command, + display=display, + command_callback=command_callback, + show=show, + ) return out, command diff 
--git a/src/eegprep/functions/popfunc/pop_resample.py b/src/eegprep/functions/popfunc/pop_resample.py index 73f8c147..75a64df1 100644 --- a/src/eegprep/functions/popfunc/pop_resample.py +++ b/src/eegprep/functions/popfunc/pop_resample.py @@ -12,9 +12,9 @@ from scipy.signal.windows import kaiser from eegprep.functions.adminfunc.eeglabcompat import get_eeglab -from eegprep.functions.adminfunc.eeg_options import EEG_OPTIONS from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import CallbackSpec, ControlSpec, DialogSpec +from eegprep.functions.miscfunc.event_utils import is_boundary_event as _shared_is_boundary_event from eegprep.functions.popfunc._file_io import events_to_records from eegprep.plugins.firfilt.firws import firws from eegprep.plugins.firfilt.firwsord import firwsord @@ -230,10 +230,7 @@ def _segment_bounds(EEG, old_pnts): def _is_boundary_event(event): - event_type = event.get("type") if isinstance(event, dict) else None - if isinstance(event_type, str): - return event_type.lower().startswith("boundary") - return bool(EEG_OPTIONS.get("option_boundary99")) and event_type == -99 + return isinstance(event, dict) and _shared_is_boundary_event(event) def _resample_segment(segment, p, q, *, method, fc, df): diff --git a/src/eegprep/functions/popfunc/pop_topoplot.py b/src/eegprep/functions/popfunc/pop_topoplot.py index c777d0ba..40585c80 100644 --- a/src/eegprep/functions/popfunc/pop_topoplot.py +++ b/src/eegprep/functions/popfunc/pop_topoplot.py @@ -13,7 +13,8 @@ from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec from eegprep.functions.miscfunc.misc import round_mat from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._plot_utils import component_map_data, python_literal +from eegprep.functions.popfunc._plot_utils import component_map_data +from eegprep.functions.popfunc._plot_utils import history_command as plot_history_command from 
eegprep.functions.popfunc._pop_utils import is_on as _is_on from eegprep.functions.popfunc._pop_utils import parse_key_value_args, parse_numeric_sequence, parse_text_tokens from eegprep.functions.sigprocfunc.topoplot import topoplot @@ -430,17 +431,15 @@ def _history_command( plotdip: int, options: dict[str, Any], ) -> str: - pieces = [ - "EEG", - f"typeplot={int(typeplot)}", - f"items={python_literal(items)}", - f"topotitle={python_literal(topotitle)}", - f"rowcols={python_literal(list(rowcols))}", - f"plotdip={int(plotdip)}", - ] - for key, value in options.items(): - pieces.append(f"{key}={python_literal(value)}") - return f"pop_topoplot({', '.join(pieces)})" + kwargs = { + "typeplot": int(typeplot), + "items": items, + "topotitle": topotitle, + "rowcols": list(rowcols), + "plotdip": int(plotdip), + } + kwargs.update(options) + return plot_history_command("pop_topoplot", **kwargs) __all__ = ["plot_channel_locations", "pop_topoplot", "pop_topoplot_dialog_spec"] diff --git a/src/eegprep/functions/sigprocfunc/coregister.py b/src/eegprep/functions/sigprocfunc/coregister.py index c6b0eeae..f52b4b8e 100644 --- a/src/eegprep/functions/sigprocfunc/coregister.py +++ b/src/eegprep/functions/sigprocfunc/coregister.py @@ -11,8 +11,8 @@ import numpy as np from scipy.optimize import least_squares +from eegprep.functions.miscfunc.value_parsing import parse_numeric_sequence as _parse_numeric_sequence_value from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence as _parse_numeric_sequence_value DEFAULT_COREGISTER_TRANSFORM = np.asarray([0.0, -10.0, 0.0, -0.1, 0.0, -1.6, 1100.0, 1100.0, 1100.0]) _SCALE_LOWER_BOUND = 1e-9 diff --git a/src/eegprep/functions/sigprocfunc/eegrej.py b/src/eegprep/functions/sigprocfunc/eegrej.py index f329e3d8..c6aadb4e 100644 --- a/src/eegprep/functions/sigprocfunc/eegrej.py +++ b/src/eegprep/functions/sigprocfunc/eegrej.py @@ -2,18 +2,10 @@ import numpy as np from 
typing import List, Dict, Optional, Tuple +from eegprep.functions.miscfunc.event_utils import is_boundary_event as _is_boundary_event from ..miscfunc.misc import round_mat -def _is_boundary_event(event: Dict) -> bool: - t = event.get("type") - if isinstance(t, str): - return t.lower() == "boundary" - if isinstance(t, (int, float)): - return int(t) == -99 - return False - - def eegrej( indata, regions, timelength, events: Optional[List[Dict]] = None ) -> Tuple[np.ndarray, float, List[Dict], np.ndarray]: diff --git a/src/eegprep/functions/sigprocfunc/headplot.py b/src/eegprep/functions/sigprocfunc/headplot.py index 35a464a2..9c88c390 100644 --- a/src/eegprep/functions/sigprocfunc/headplot.py +++ b/src/eegprep/functions/sigprocfunc/headplot.py @@ -16,9 +16,9 @@ from scipy.special import eval_legendre from eegprep.functions.miscfunc.misc import finite_matmul, finite_pinv +from eegprep.functions.miscfunc.value_parsing import is_on as _is_on_value +from eegprep.functions.miscfunc.value_parsing import parse_numeric_sequence as _parse_numeric_sequence_value from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._pop_utils import is_on as _is_on_value -from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence as _parse_numeric_sequence_value from eegprep.functions.sigprocfunc.coregister import ( DEFAULT_COREGISTER_TRANSFORM, apply_coregistration_transform, diff --git a/src/eegprep/functions/sigprocfunc/readlocs.py b/src/eegprep/functions/sigprocfunc/readlocs.py index 1e2418c6..bbcf4cbd 100644 --- a/src/eegprep/functions/sigprocfunc/readlocs.py +++ b/src/eegprep/functions/sigprocfunc/readlocs.py @@ -10,8 +10,8 @@ import scipy.io from scipy.io.matlab import mat_struct +from eegprep.functions.miscfunc.value_parsing import is_empty_value, parse_key_value_args, parse_numeric_sequence from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._pop_utils import is_empty_value, 
parse_key_value_args, parse_numeric_sequence from eegprep.functions.sigprocfunc.convertlocs import convertlocs diff --git a/src/eegprep/functions/sigprocfunc/writelocs.py b/src/eegprep/functions/sigprocfunc/writelocs.py index a0b5c78f..d6b4b281 100644 --- a/src/eegprep/functions/sigprocfunc/writelocs.py +++ b/src/eegprep/functions/sigprocfunc/writelocs.py @@ -7,8 +7,8 @@ import numpy as np +from eegprep.functions.miscfunc.value_parsing import is_empty_value, parse_key_value_args, parse_numeric_sequence from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._pop_utils import is_empty_value, parse_key_value_args, parse_numeric_sequence from eegprep.functions.sigprocfunc.convertlocs import convertlocs from eegprep.functions.sigprocfunc.readlocs import CHANNEL_FORMATS diff --git a/src/eegprep/functions/statistics/__init__.py b/src/eegprep/functions/statistics/__init__.py index b366ca8a..fb78a2be 100644 --- a/src/eegprep/functions/statistics/__init__.py +++ b/src/eegprep/functions/statistics/__init__.py @@ -2,8 +2,6 @@ from importlib import import_module -from eegprep.functions.statistics import _core as _statistics_core - # Import same-name thin modules before binding package callables. Without this, # a later ``import eegprep.functions.statistics.fdr`` can replace # ``statistics.fdr`` with the submodule object. 
@@ -26,28 +24,30 @@ for _module_name in _THIN_MODULES: import_module(f"{__name__}.{_module_name}") -ConcatenatedData = _statistics_core.ConcatenatedData -FDRResult = _statistics_core.FDRResult -StatcondResult = _statistics_core.StatcondResult -SurrogateDistribution = _statistics_core.SurrogateDistribution -TwoWayAnovaResult = _statistics_core.TwoWayAnovaResult -TwoWayEffects = _statistics_core.TwoWayEffects -anova1_cell = _statistics_core.anova1_cell -anova1rm_cell = _statistics_core.anova1rm_cell -anova2_cell = _statistics_core.anova2_cell -anova2rm_cell = _statistics_core.anova2rm_cell -concatdata = _statistics_core.concatdata -corrcoef_cell = _statistics_core.corrcoef_cell -fdr = _statistics_core.fdr -stat_surrogate_ci = _statistics_core.stat_surrogate_ci -stat_surrogate_pvals = _statistics_core.stat_surrogate_pvals -statcond = _statistics_core.statcond -surrogdistrib = _statistics_core.surrogdistrib -teststat = _statistics_core.teststat -ttest2_cell = _statistics_core.ttest2_cell -ttest_cell = _statistics_core.ttest_cell +_MODULES = {_module_name: import_module(f"{__name__}.{_module_name}") for _module_name in _THIN_MODULES} + +ConcatenatedData = _MODULES["concatdata"].ConcatenatedData +FDRResult = _MODULES["fdr"].FDRResult +StatcondResult = _MODULES["statcond"].StatcondResult +SurrogateDistribution = _MODULES["surrogdistrib"].SurrogateDistribution +TwoWayAnovaResult = _MODULES["anova2_cell"].TwoWayAnovaResult +TwoWayEffects = _MODULES["statcond"].TwoWayEffects +anova1_cell = _MODULES["anova1_cell"].anova1_cell +anova1rm_cell = _MODULES["anova1rm_cell"].anova1rm_cell +anova2_cell = _MODULES["anova2_cell"].anova2_cell +anova2rm_cell = _MODULES["anova2rm_cell"].anova2rm_cell +concatdata = _MODULES["concatdata"].concatdata +corrcoef_cell = _MODULES["corrcoef_cell"].corrcoef_cell +fdr = _MODULES["fdr"].fdr +stat_surrogate_ci = _MODULES["stat_surrogate_ci"].stat_surrogate_ci +stat_surrogate_pvals = _MODULES["stat_surrogate_pvals"].stat_surrogate_pvals +statcond = 
_MODULES["statcond"].statcond +surrogdistrib = _MODULES["surrogdistrib"].surrogdistrib +teststat = _MODULES["teststat"].teststat +ttest2_cell = _MODULES["ttest2_cell"].ttest2_cell +ttest_cell = _MODULES["ttest_cell"].ttest_cell -del import_module, _module_name, _THIN_MODULES, _statistics_core +del import_module, _module_name, _MODULES, _THIN_MODULES __all__ = [ "ConcatenatedData", diff --git a/src/eegprep/functions/statistics/_core.py b/src/eegprep/functions/statistics/_core.py index c99aee11..07d99897 100644 --- a/src/eegprep/functions/statistics/_core.py +++ b/src/eegprep/functions/statistics/_core.py @@ -1,1043 +1,45 @@ -"""NumPy implementations of EEGLAB-style statistics helpers.""" - -from __future__ import annotations - -from collections.abc import Iterator, Sequence -from dataclasses import dataclass -from typing import Any - -import numpy as np -from scipy import stats as scipy_stats - - -@dataclass(frozen=True) -class FDRResult: - """False discovery rate threshold and mask.""" - - threshold: np.ndarray | float - mask: np.ndarray - - def __iter__(self) -> Iterator[np.ndarray | float]: - yield self.threshold - yield self.mask - - -@dataclass(frozen=True) -class ConcatenatedData: - """Data concatenated across condition case axes.""" - - data: np.ndarray - lengths: np.ndarray - grid_shape: tuple[int, int] - - def __iter__(self) -> Iterator[np.ndarray | tuple[int, int]]: - yield self.data - yield self.lengths - yield self.grid_shape - - -@dataclass(frozen=True) -class TwoWayEffects: - """Row, column, and interaction values for a two-way design.""" - - rows: Any - columns: Any - interaction: Any - - def __iter__(self) -> Iterator[Any]: - yield self.rows - yield self.columns - yield self.interaction - - -@dataclass(frozen=True) -class TwoWayAnovaResult: - """Two-way ANOVA statistics and degrees of freedom.""" - - rows: np.ndarray - columns: np.ndarray - interaction: np.ndarray - df_rows: tuple[int, int] - df_columns: tuple[int, int] - df_interaction: tuple[int, 
int] - - def as_effects(self) -> TwoWayEffects: - return TwoWayEffects(self.rows, self.columns, self.interaction) - - def df_effects(self) -> TwoWayEffects: - return TwoWayEffects(self.df_rows, self.df_columns, self.df_interaction) - - -@dataclass(frozen=True) -class SurrogateDistribution: - """Surrogate condition grids produced by permutation or bootstrap.""" - - samples: tuple[tuple[tuple[np.ndarray, ...], ...], ...] - - def __iter__(self) -> Iterator[tuple[tuple[np.ndarray, ...], ...]]: - return iter(self.samples) - - def __len__(self) -> int: - return len(self.samples) - - -@dataclass(frozen=True) -class StatcondResult: - """Result returned by :func:`statcond`.""" - - stat: Any - df: Any - pvalue: Any - surrogate: Any - method: str - paired: bool - ci: Any = None - mask: Any = None - - def __iter__(self) -> Iterator[Any]: - yield self.stat - yield self.df - yield self.pvalue - yield self.surrogate - - -def fdr(pvals: Any, q: float | None = None, fdr_type: str = "parametric") -> FDRResult: - """Compute Benjamini-Hochberg or Benjamini-Yekutieli FDR thresholds. - - Args: - pvals: Numeric p-value array with values in the closed interval [0, 1]. - q: Desired false discovery rate. If omitted, an EEGLAB-style array of - corrected thresholds is returned. - fdr_type: ``"parametric"`` for Benjamini-Hochberg or - ``"nonparametric"``/``"nonParametric"`` for Benjamini-Yekutieli. - - Returns: - Threshold and boolean mask with the same shape as ``pvals``. 
- """ - - values = np.asarray(pvals) - if not np.issubdtype(values.dtype, np.number): - raise TypeError("pvals must be numeric") - if values.size == 0: - return FDRResult(np.array([], dtype=float), np.array([], dtype=bool)) - finite_mask = np.isfinite(values) - finite_values = values[finite_mask] - if np.any((finite_values < 0) | (finite_values > 1)): - raise ValueError("pvals must contain probabilities between 0 and 1") - - if q is None: - threshold = np.ones(values.shape, dtype=float) - thresholds = np.exp(np.linspace(np.log(0.1), np.log(0.000001), 1000)) - for current in thresholds: - current_result = fdr(values, float(current), fdr_type=fdr_type) - threshold[current_result.mask] = current - return FDRResult(threshold, finite_mask & (values <= threshold)) - - q_value = float(q) - if not 0 <= q_value <= 1: - raise ValueError("q must be between 0 and 1") - - fdr_type_name = fdr_type.lower() - if fdr_type_name not in {"parametric", "nonparametric"}: - raise ValueError("fdr_type must be 'parametric' or 'nonparametric'") - - if finite_values.size == 0: - return FDRResult(0.0, np.zeros(values.shape, dtype=bool)) - - flat = np.sort(finite_values.reshape(-1)) - count = flat.size - indices = np.arange(1, count + 1, dtype=float) - correction = 1.0 if fdr_type_name == "parametric" else float(np.sum(1.0 / indices)) - accepted = flat <= indices / count * q_value / correction - threshold_value = float(flat[np.flatnonzero(accepted).max()]) if np.any(accepted) else 0.0 - return FDRResult(threshold_value, finite_mask & (values <= threshold_value)) - - -def stat_surrogate_pvals(distribution: Any, observed: Any, tail: str = "both") -> np.ndarray: - """Compute empirical p-values against a surrogate distribution. - - Args: - distribution: Surrogate statistic array whose last axis stores - surrogate replications. - observed: Observed statistic array matching ``distribution.shape[:-1]``. - tail: ``"right"``/``"upper"``/``"one"``, ``"left"``/``"lower"``, or - ``"both"``. 
- """ - - surrogates = _as_numeric_array(distribution, "distribution") - observed_values = _as_numeric_array(observed, "observed", require_axis=False) - if surrogates.ndim < 1: - raise ValueError("distribution must have a surrogate axis") - if observed_values.shape != surrogates.shape[:-1]: - raise ValueError("observed shape must match distribution without its last axis") - - n_samples = surrogates.shape[-1] - expanded = np.expand_dims(observed_values, axis=-1) - p_right = np.sum(surrogates >= expanded, axis=-1) / n_samples - tail_name = tail.lower() - if tail_name in {"right", "upper", "one"}: - return p_right - - p_left = 1 - p_right + np.sum(surrogates == expanded, axis=-1) / n_samples - if tail_name in {"left", "lower"}: - return p_left - if tail_name == "both": - return np.minimum(2 * np.minimum(p_right, p_left), 1.0) - raise ValueError("tail must be 'right', 'upper', 'left', 'lower', 'one', or 'both'") - - -def stat_surrogate_ci(distribution: Any, alpha: float = 0.05, tail: str = "both") -> np.ndarray: - """Compute surrogate confidence intervals along the last axis. - - Args: - distribution: Surrogate statistic array whose last axis stores - surrogate replications. - alpha: Type-I error rate. - tail: ``"upper"``, ``"lower"``, ``"one"``, or ``"both"``. 
- """ - - values = _as_numeric_array(distribution, "distribution") - if values.ndim < 1: - raise ValueError("distribution must have a surrogate axis") - alpha_value = float(alpha) - if not 0 <= alpha_value <= 1: - raise ValueError("alpha must be between 0 and 1") - - sorted_values = np.sort(values, axis=-1) - n_samples = sorted_values.shape[-1] - ci_alpha = alpha_value / 2 if tail.lower() == "both" else alpha_value - low = int(np.floor(ci_alpha * n_samples + 0.5)) - high = n_samples - low - low_index = min(max(low, 0), n_samples - 1) - high_index = min(max(high - 1, 0), n_samples - 1) - mean_values = np.mean(sorted_values, axis=-1) - - tail_name = tail.lower() - if tail_name == "upper": - lower = mean_values - upper = sorted_values[..., high_index] - elif tail_name == "lower": - lower = sorted_values[..., low_index] - upper = mean_values - elif tail_name in {"both", "one"}: - lower = sorted_values[..., low_index] - upper = sorted_values[..., high_index] - else: - raise ValueError("tail must be 'upper', 'lower', 'one', or 'both'") - - return np.stack((lower, upper), axis=0) - - -def concatdata(data: Any, *, axis: int = -1) -> ConcatenatedData: - """Concatenate condition arrays along their case axis. - - Args: - data: One- or two-dimensional sequence of condition arrays. - axis: Axis in each condition array that stores cases. 
- """ - - grid = _condition_grid(data, axis=axis, min_cases=1) - arrays = _flatten_grid(grid) - feature_shape = arrays[0].shape[:-1] - for index, array in enumerate(arrays): - if array.shape[:-1] != feature_shape: - raise ValueError(f"condition {index} has feature shape {array.shape[:-1]}, expected {feature_shape}") - - lengths = np.zeros(len(arrays) + 1, dtype=int) - lengths[1:] = np.cumsum([array.shape[-1] for array in arrays]) - concatenated = np.concatenate(arrays, axis=-1) - return ConcatenatedData(concatenated, lengths, (len(grid), len(grid[0]))) - - -def corrcoef_cell(a: Any, b: Any | None = None, *, axis: int = -1) -> np.ndarray: - """Compute pairwise correlations along a case axis.""" - - first, second = _two_arrays(a, b, "corrcoef_cell", axis=axis) - if first.shape != second.shape: - raise ValueError("corrcoef_cell requires arrays with identical shapes") - if first.shape[-1] < 2: - raise ValueError("corrcoef_cell requires at least two cases") - - first_centered = first - np.mean(first, axis=-1, keepdims=True) - second_centered = second - np.mean(second, axis=-1, keepdims=True) - covariance = np.sum(first_centered * second_centered, axis=-1) - first_power = np.sum(first_centered * first_centered, axis=-1) - second_power = np.sum(second_centered * second_centered, axis=-1) - with np.errstate(divide="ignore", invalid="ignore"): - return covariance / np.sqrt(first_power * second_power) - - -def ttest_cell(a: Any, b: Any | None = None, *, axis: int = -1) -> tuple[np.ndarray, int]: - """Compute paired t-statistics across the case axis.""" - - first, second = _two_arrays(a, b, "ttest_cell", axis=axis) - if first.shape != second.shape: - raise ValueError("ttest_cell requires paired arrays with identical shapes") - n_cases = first.shape[-1] - if n_cases < 2: - raise ValueError("ttest_cell requires at least two paired cases") - - difference = first - second - mean_difference = _stat_mean(difference, axis=-1) - sd_difference = _stat_std(difference, axis=-1) - with 
np.errstate(divide="ignore", invalid="ignore"): - t_values = mean_difference / sd_difference * np.sqrt(n_cases) - return t_values, n_cases - 1 - - -def ttest2_cell( - a: Any, - b: Any | None = None, - variance: str = "homogenous", - *, - axis: int = -1, -) -> tuple[np.ndarray, np.ndarray | int]: - """Compute unpaired t-statistics across the case axis.""" - - if isinstance(b, str): - variance = b - b = None - first, second = _two_arrays(a, b, "ttest2_cell", axis=axis) - if first.shape[:-1] != second.shape[:-1]: - raise ValueError("ttest2_cell requires matching feature shapes before the case axis") - if first.shape[-1] < 2 or second.shape[-1] < 2: - raise ValueError("ttest2_cell requires at least two cases in each group") - - variance_name = variance.lower() - if variance_name not in {"homogenous", "inhomogenous"}: - raise ValueError("variance must be 'homogenous' or 'inhomogenous'") - - first_n = first.shape[-1] - second_n = second.shape[-1] - first_mean = _stat_mean(first, axis=-1) - second_mean = _stat_mean(second, axis=-1) - if variance_name == "inhomogenous": - first_scaled = np.var(first, axis=-1, ddof=1) / first_n - second_scaled = np.var(second, axis=-1, ddof=1) / second_n - standard_error = np.sqrt(first_scaled + second_scaled) - with np.errstate(divide="ignore", invalid="ignore"): - t_values = (first_mean - second_mean) / standard_error - df = (first_scaled + second_scaled) ** 2 / ( - first_scaled**2 / (first_n - 1) + second_scaled**2 / (second_n - 1) - ) - return t_values, df - - first_sd = _stat_std(first, axis=-1) - second_sd = _stat_std(second, axis=-1) - pooled_sd = np.sqrt(((first_n - 1) * first_sd**2 + (second_n - 1) * second_sd**2) / (first_n + second_n - 2)) - with np.errstate(divide="ignore", invalid="ignore"): - t_values = (first_mean - second_mean) / pooled_sd / np.sqrt(1 / first_n + 1 / second_n) - return t_values, first_n + second_n - 2 - - -def anova1_cell(data: Any, *, axis: int = -1) -> tuple[np.ndarray, tuple[int, int]]: - """Compute 
one-way unpaired ANOVA F-statistics across condition arrays.""" - - arrays = _one_way_arrays(data, axis=axis, paired=False) - if len(arrays) < 2: - raise ValueError("anova1_cell requires at least two conditions") - feature_shape = arrays[0].shape[:-1] - for index, array in enumerate(arrays): - if array.shape[:-1] != feature_shape: - raise ValueError(f"condition {index} has feature shape {array.shape[:-1]}, expected {feature_shape}") - if array.shape[-1] < 2: - raise ValueError("anova1_cell requires at least two cases in each condition") - - counts = np.array([array.shape[-1] for array in arrays], dtype=float) - means = np.stack([_stat_mean(array, axis=-1) for array in arrays], axis=-1) - total_n = int(np.sum(counts)) - grand_mean = np.sum(means * counts, axis=-1) / total_n - ss_between = np.sum(counts * (means - np.expand_dims(grand_mean, -1)) ** 2, axis=-1) - ss_within = np.zeros(feature_shape, dtype=float) - for array, mean in zip(arrays, np.moveaxis(means, -1, 0), strict=True): - ss_within = ss_within + _sum_square_residuals(array, mean, axis=-1) - - df_between = len(arrays) - 1 - df_within = total_n - len(arrays) - with np.errstate(divide="ignore", invalid="ignore"): - f_values = (ss_between / df_between) / (ss_within / df_within) - return f_values, (df_between, df_within) - - -def anova1rm_cell(data: Any, *, axis: int = -1) -> tuple[np.ndarray, tuple[int, int]]: - """Compute one-way repeated-measures ANOVA F-statistics.""" - - arrays = _one_way_arrays(data, axis=axis, paired=True) - if len(arrays) < 2: - raise ValueError("anova1rm_cell requires at least two conditions") - _require_same_shapes(arrays, "anova1rm_cell") - n_cases = arrays[0].shape[-1] - if n_cases < 2: - raise ValueError("anova1rm_cell requires at least two repeated cases") - - stacked = np.stack(arrays, axis=-2) - values = _anova_values(stacked) - n_conditions = len(arrays) - condition_subject = values - condition_sums = np.sum(condition_subject, axis=-1) - subject_sums = 
np.sum(condition_subject, axis=-2) - total_sum = np.sum(condition_sums, axis=-1) - - df_condition = n_conditions - 1 - df_error = (n_conditions - 1) * (n_cases - 1) - expected_condition = np.sum(condition_sums**2, axis=-1) / n_cases - expected_subject = np.sum(subject_sums**2, axis=-1) / n_conditions - expected_condition_subject = np.sum(condition_subject**2, axis=(-2, -1)) - expected_total = total_sum**2 / (n_conditions * n_cases) - - ss_condition = expected_condition - expected_total - ss_error = expected_condition_subject - expected_condition - expected_subject + expected_total - with np.errstate(divide="ignore", invalid="ignore"): - f_values = (ss_condition / df_condition) / (ss_error / df_error) - return f_values, (df_condition, df_error) - - -def anova2_cell(data: Any, *, axis: int = -1) -> TwoWayAnovaResult: - """Compute balanced two-way unpaired ANOVA F-statistics.""" - - stacked = _two_way_stack(data, axis=axis, name="anova2_cell") - values = _anova_values(stacked) - n_rows = stacked.shape[-3] - n_columns = stacked.shape[-2] - n_cases = stacked.shape[-1] - if n_rows < 2 or n_columns < 2: - raise ValueError("anova2_cell requires at least two rows and two columns") - if n_cases < 2: - raise ValueError("anova2_cell requires at least two cases in each cell") - - means = _stat_mean(stacked, axis=-1) - residual_ss = np.sum((values - np.expand_dims(means, -1)) ** 2, axis=-1) - error_ss = np.sum(residual_ss, axis=(-2, -1)) - grand = np.mean(means, axis=(-2, -1)) - row_means = np.mean(means, axis=-1) - column_means = np.mean(means, axis=-2) - - row_ss = n_columns * n_cases * np.sum((row_means - np.expand_dims(grand, -1)) ** 2, axis=-1) - column_ss = n_rows * n_cases * np.sum((column_means - np.expand_dims(grand, -1)) ** 2, axis=-1) - interaction_terms = ( - means - - np.expand_dims(row_means, -1) - - np.expand_dims(column_means, -2) - + np.expand_dims(np.expand_dims(grand, -1), -1) - ) - interaction_ss = n_cases * np.sum(interaction_terms**2, axis=(-2, -1)) - - 
df_error = n_rows * n_columns * (n_cases - 1) - df_rows = (n_rows - 1, df_error) - df_columns = (n_columns - 1, df_error) - df_interaction = ((n_rows - 1) * (n_columns - 1), df_error) - with np.errstate(divide="ignore", invalid="ignore"): - row_f = (row_ss / df_rows[0]) / (error_ss / df_error) - column_f = (column_ss / df_columns[0]) / (error_ss / df_error) - interaction_f = (interaction_ss / df_interaction[0]) / (error_ss / df_error) - return TwoWayAnovaResult(row_f, column_f, interaction_f, df_rows, df_columns, df_interaction) - - -def anova2rm_cell(data: Any, *, axis: int = -1) -> TwoWayAnovaResult: - """Compute two-way repeated-measures ANOVA F-statistics.""" - - stacked = _two_way_stack(data, axis=axis, name="anova2rm_cell") - values = _anova_values(stacked) - n_rows = stacked.shape[-3] - n_columns = stacked.shape[-2] - n_subjects = stacked.shape[-1] - if n_rows < 2 or n_columns < 2: - raise ValueError("anova2rm_cell requires at least two rows and two columns") - if n_subjects < 2: - raise ValueError("anova2rm_cell requires at least two repeated cases") - - ab_sums = np.sum(values, axis=-1) - row_subject_sums = np.sum(values, axis=-2) - column_subject_sums = np.sum(values, axis=-3) - row_sums = np.sum(ab_sums, axis=-1) - column_sums = np.sum(ab_sums, axis=-2) - subject_sums = np.sum(row_subject_sums, axis=-2) - total_sum = np.sum(row_sums, axis=-1) - - df_rows_num = n_rows - 1 - df_columns_num = n_columns - 1 - df_interaction_num = (n_rows - 1) * (n_columns - 1) - df_row_subject = (n_rows - 1) * (n_subjects - 1) - df_column_subject = (n_columns - 1) * (n_subjects - 1) - df_interaction_subject = (n_rows - 1) * (n_columns - 1) * (n_subjects - 1) - - expected_rows = np.sum(row_sums**2, axis=-1) / (n_columns * n_subjects) - expected_columns = np.sum(column_sums**2, axis=-1) / (n_rows * n_subjects) - expected_ab = np.sum(ab_sums**2, axis=(-2, -1)) / n_subjects - expected_subjects = np.sum(subject_sums**2, axis=-1) / (n_rows * n_columns) - expected_row_subject = 
np.sum(row_subject_sums**2, axis=(-2, -1)) / n_columns - expected_column_subject = np.sum(column_subject_sums**2, axis=(-2, -1)) / n_rows - expected_y = np.sum(values**2, axis=(-3, -2, -1)) - expected_total = total_sum**2 / (n_rows * n_columns * n_subjects) - - ss_rows = expected_rows - expected_total - ss_columns = expected_columns - expected_total - ss_interaction = expected_ab - expected_rows - expected_columns + expected_total - ss_subjects = expected_subjects - expected_total - ss_row_subject = expected_row_subject - expected_rows - expected_subjects + expected_total - ss_column_subject = expected_column_subject - expected_columns - expected_subjects + expected_total - ss_interaction_subject = ( - expected_y - - expected_ab - - expected_row_subject - - expected_column_subject - + expected_rows - + expected_columns - + expected_subjects - - expected_total - ) - del ss_subjects - - with np.errstate(divide="ignore", invalid="ignore"): - row_f = (ss_rows / df_rows_num) / (ss_row_subject / df_row_subject) - column_f = (ss_columns / df_columns_num) / (ss_column_subject / df_column_subject) - interaction_f = (ss_interaction / df_interaction_num) / (ss_interaction_subject / df_interaction_subject) - return TwoWayAnovaResult( - row_f, - column_f, - interaction_f, - (df_rows_num, df_row_subject), - (df_columns_num, df_column_subject), - (df_interaction_num, df_interaction_subject), - ) - - -def surrogdistrib( - data: Any, - *, - method: str = "perm", - pairing: str = "on", - naccu: int = 1, - axis: int = -1, - rng: np.random.Generator | int | None = None, -) -> SurrogateDistribution: - """Build bootstrap or permutation surrogate condition grids. - - Args: - data: One- or two-dimensional sequence of condition arrays. - method: ``"perm"``/``"permutation"`` or ``"bootstrap"``. - pairing: ``"on"`` to preserve case identity across conditions or - ``"off"`` to resample from the pooled case axis. - naccu: Number of surrogate grids to generate. 
- axis: Axis in each condition array that stores cases. - rng: Optional NumPy generator or seed for deterministic resampling. - """ - - method_name = _normalize_method(method) - if method_name == "param": - raise ValueError("surrogdistrib only supports permutation or bootstrap methods") - pairing_name = pairing.lower() - if pairing_name not in {"on", "off"}: - raise ValueError("pairing must be 'on' or 'off'") - count = int(naccu) - if count < 1: - raise ValueError("naccu must be at least 1") - - generator = _rng(rng) - grid = _condition_grid(data, axis=axis, min_cases=1) - samples = tuple( - _resampled_grid(grid, bootstrap=method_name == "bootstrap", paired=pairing_name == "on", rng=generator) - for _ in range(count) - ) - return SurrogateDistribution(samples) - - -def statcond( - data: Any, - *, - paired: str | bool = "auto", - method: str = "param", - mode: str | None = None, - naccu: int = 200, - variance: str = "homogenous", - forceanova: bool = False, - tail: str = "both", - axis: int = -1, - rng: np.random.Generator | int | None = None, - alpha: float | None = None, - surrog: Any = None, - stats: Any = None, - return_resampling_array: bool = False, -) -> StatcondResult | SurrogateDistribution: - """Compare condition arrays using EEGLAB-style t-tests or ANOVAs. - - Args: - data: One- or two-dimensional sequence of condition arrays. The case - dimension is the last axis by default. - paired: ``"auto"``, ``"on"``/``True``, or ``"off"``/``False``. - method: ``"param"``, ``"perm"``, or ``"bootstrap"``. - mode: Legacy alias for ``method``. - naccu: Number of surrogate samples for nonparametric methods. - variance: ``"homogenous"`` or ``"inhomogenous"`` for unpaired t-tests. - forceanova: Use one-way ANOVA instead of a two-condition t-test. - tail: Empirical-tail mode for supplied or computed surrogates. - axis: Axis in each condition array that stores cases. - rng: Optional NumPy generator or seed for deterministic resampling. 
- alpha: Optional threshold for confidence intervals and masks; requires - a nonparametric method or supplied surrogate statistics. - surrog: Precomputed surrogate statistic array. - stats: Observed statistic to pair with ``surrog``. - return_resampling_array: Return surrogate condition grids instead of - computing statistics. - """ - - method_name = _normalize_method(mode or method) - grid = _condition_grid(data, axis=axis, min_cases=2) - paired_flag = _paired_flag(grid, paired) - if return_resampling_array: - if method_name == "param": - raise ValueError("return_resampling_array requires 'perm' or 'bootstrap'") - return surrogdistrib( - grid, - method=method_name, - pairing="on" if paired_flag else "off", - naccu=naccu, - rng=rng, - ) - - if surrog is not None: - if stats is None: - raise ValueError("stats must be supplied when surrog is supplied") - observed_stat = stats - observed_df = None - surrogate_stat = surrog - pvalue = _surrogate_pvalues(surrogate_stat, observed_stat, tail) - ci = None - mask = None - if alpha is not None: - ci = _surrogate_ci(surrogate_stat, alpha, _ci_tail(tail)) - mask = _effect_map(pvalue, lambda value: value < alpha) - return StatcondResult( - observed_stat, observed_df, pvalue, surrogate_stat, method_name, paired_flag, ci=ci, mask=mask - ) - - observed_stat, observed_df, statistic_kind = _compute_statistic( - grid, - paired=paired_flag, - variance=variance, - forceanova=forceanova, - ) - surrogate_stat = None - if method_name == "param": - pvalue = _parametric_pvalues(observed_stat, observed_df, statistic_kind) - else: - surrogate_stat = _compute_surrogate_statistics( - grid, - paired=paired_flag, - method=method_name, - naccu=naccu, - variance=variance, - forceanova=forceanova, - rng=rng, - ) - empirical_tail = "one" if statistic_kind.startswith("f") else tail - pvalue = _surrogate_pvalues(surrogate_stat, observed_stat, empirical_tail) - - ci = None - mask = None - if alpha is not None: - if surrogate_stat is None: - raise 
ValueError("alpha confidence intervals require a nonparametric method or supplied surrogates") - empirical_tail = "one" if statistic_kind.startswith("f") else tail - ci = _surrogate_ci(surrogate_stat, alpha, _ci_tail(empirical_tail)) - mask = _effect_map(pvalue, lambda value: value < alpha) - - return StatcondResult( - observed_stat, observed_df, pvalue, surrogate_stat, method_name, paired_flag, ci=ci, mask=mask - ) - - -def teststat(seed: int = 0) -> dict[str, float]: - """Run deterministic smoke checks for the EEGPrep statistics package.""" - - rng = np.random.default_rng(seed) - first = rng.normal(size=(3, 12)) - second = first + rng.normal(loc=0.25, scale=0.2, size=(3, 12)) - paired_result = statcond([first, second], paired="on", method="param") - if not isinstance(paired_result, StatcondResult): - raise AssertionError("paired statcond unexpectedly returned surrogate grids") - t_values, df = ttest_cell(first, second) - np.testing.assert_allclose(paired_result.stat, t_values) - if paired_result.df != df: - raise AssertionError("paired t-test degrees of freedom changed") - - groups = [rng.normal(loc=offset, size=(3, 10)) for offset in (0.0, 0.2, 0.5)] - one_way = statcond(groups, paired="off", method="param") - if not isinstance(one_way, StatcondResult): - raise AssertionError("one-way statcond unexpectedly returned surrogate grids") - direct_one_way, one_way_df = anova1_cell(groups) - np.testing.assert_allclose(one_way.stat, direct_one_way) - if one_way.df != one_way_df: - raise AssertionError("one-way ANOVA degrees of freedom changed") - - grid = ( - (rng.normal(size=(2, 9)), rng.normal(loc=0.1, size=(2, 9))), - (rng.normal(loc=0.2, size=(2, 9)), rng.normal(loc=0.4, size=(2, 9))), - ) - two_way = statcond(grid, paired="on", method="param") - if not isinstance(two_way, StatcondResult): - raise AssertionError("two-way statcond unexpectedly returned surrogate grids") - if not isinstance(two_way.stat, TwoWayEffects): - raise AssertionError("two-way statcond did not 
return factor effects") - - return { - "paired_t_mean": float(np.mean(paired_result.stat)), - "one_way_f_mean": float(np.mean(one_way.stat)), - "two_way_interaction_mean": float(np.mean(two_way.stat.interaction)), - } - - -def _as_numeric_array(value: Any, name: str, *, axis: int = -1, require_axis: bool = True) -> np.ndarray: - array = np.asarray(value) - if not np.issubdtype(array.dtype, np.number): - raise TypeError(f"{name} must be numeric") - if require_axis and array.ndim == 0: - raise ValueError(f"{name} must have at least one dimension") - if np.any(~np.isfinite(array)): - raise ValueError(f"{name} must contain only finite values") - if not require_axis: - return array.astype(np.complex128 if np.iscomplexobj(array) else np.float64, copy=False) - - if not -array.ndim <= axis < array.ndim: - raise ValueError(f"axis {axis} is out of bounds for array with {array.ndim} dimensions") - normalized_axis = axis % array.ndim - if normalized_axis != array.ndim - 1: - array = np.moveaxis(array, normalized_axis, -1) - return array.astype(np.complex128 if np.iscomplexobj(array) else np.float64, copy=False) - - -def _condition_grid(data: Any, *, axis: int = -1, min_cases: int = 1) -> tuple[tuple[np.ndarray, ...], ...]: - raw_grid = _raw_condition_grid(data) - grid: list[tuple[np.ndarray, ...]] = [] - for row_index, row in enumerate(raw_grid): - converted_row = [] - for column_index, value in enumerate(row): - array = _as_numeric_array(value, f"condition ({row_index}, {column_index})", axis=axis) - if array.shape[-1] < min_cases: - raise ValueError( - f"condition ({row_index}, {column_index}) must have at least {min_cases} cases on the case axis" - ) - converted_row.append(array) - grid.append(tuple(converted_row)) - - if len(grid) > 1 and len(grid[0]) == 1: - grid = [tuple(row[0] for row in grid)] - return tuple(grid) - - -def _raw_condition_grid(data: Any) -> tuple[tuple[Any, ...], ...]: - if isinstance(data, np.ndarray) and data.dtype == object: - if data.ndim == 1: - 
return (tuple(data.tolist()),) - if data.ndim == 2: - return tuple(tuple(row) for row in data.tolist()) - raise ValueError("object-array condition grids must be one- or two-dimensional") - if not isinstance(data, Sequence) or isinstance(data, str | bytes): - raise TypeError("data must be a sequence of condition arrays") - if len(data) == 0: - raise ValueError("data must contain at least one condition") - - first = data[0] - if _looks_like_condition_row(first): - rows = [] - expected_width: int | None = None - for row in data: - if not _looks_like_condition_row(row): - raise ValueError("data rows must all be condition sequences") - row_tuple = tuple(row) - if len(row_tuple) == 0: - raise ValueError("condition rows must not be empty") - if expected_width is None: - expected_width = len(row_tuple) - elif len(row_tuple) != expected_width: - raise ValueError("all condition rows must have the same length") - rows.append(row_tuple) - return tuple(rows) - return (tuple(data),) - - -def _looks_like_condition_row(value: Any) -> bool: - return isinstance(value, Sequence) and not isinstance(value, np.ndarray | str | bytes) - - -def _flatten_grid(grid: tuple[tuple[np.ndarray, ...], ...]) -> list[np.ndarray]: - return [array for row in grid for array in row] - - -def _two_arrays(a: Any, b: Any | None, caller: str, *, axis: int) -> tuple[np.ndarray, np.ndarray]: - if b is None: - grid = _condition_grid(a, axis=axis, min_cases=1) - arrays = _flatten_grid(grid) - if len(arrays) != 2: - raise ValueError(f"{caller} requires exactly two arrays") - return arrays[0], arrays[1] - return ( - _as_numeric_array(a, "a", axis=axis), - _as_numeric_array(b, "b", axis=axis), - ) - - -def _one_way_arrays(data: Any, *, axis: int, paired: bool) -> list[np.ndarray]: - grid = _condition_grid(data, axis=axis, min_cases=2 if paired else 1) - if len(grid) != 1: - raise ValueError("one-way ANOVA helpers require a one-dimensional condition sequence") - return list(grid[0]) - - -def 
_require_same_shapes(arrays: Sequence[np.ndarray], name: str) -> None: - expected_shape = arrays[0].shape - for index, array in enumerate(arrays): - if array.shape != expected_shape: - raise ValueError(f"{name} requires identical shapes; condition {index} has shape {array.shape}") - - -def _two_way_stack(data: Any, *, axis: int, name: str) -> np.ndarray: - grid = _condition_grid(data, axis=axis, min_cases=2) - if len(grid) < 1 or len(grid[0]) < 1: - raise ValueError(f"{name} requires a non-empty condition grid") - expected_shape = grid[0][0].shape - for row_index, row in enumerate(grid): - for column_index, array in enumerate(row): - if array.shape != expected_shape: - raise ValueError( - f"{name} requires balanced cell shapes; condition ({row_index}, {column_index}) has " - f"shape {array.shape}, expected {expected_shape}" - ) - return np.stack([np.stack(row, axis=-2) for row in grid], axis=-3) - - -def _stat_mean(array: np.ndarray, *, axis: int) -> np.ndarray: - result = np.mean(array, axis=axis) - if np.iscomplexobj(array): - return np.abs(result) - return result - - -def _stat_std(array: np.ndarray, *, axis: int) -> np.ndarray: - if np.iscomplexobj(array): - return np.std(np.abs(array), axis=axis, ddof=1) - return np.sqrt(np.sum((array - np.mean(array, axis=axis, keepdims=True)) ** 2, axis=axis) / (array.shape[axis] - 1)) - - -def _anova_values(array: np.ndarray) -> np.ndarray: - if np.iscomplexobj(array): - return np.abs(array) - return array - - -def _sum_square_residuals(array: np.ndarray, mean: np.ndarray, *, axis: int) -> np.ndarray: - values = _anova_values(array) - return np.sum((values - np.expand_dims(mean, axis=axis)) ** 2, axis=axis) - - -def _normalize_method(method: str) -> str: - method_name = method.lower() - if method_name == "parametric": - return "param" - if method_name == "permutation": - return "perm" - if method_name in {"param", "perm", "bootstrap"}: - return method_name - raise ValueError("method must be 'param', 'perm', 'permutation', 
'parametric', or 'bootstrap'") - - -def _paired_flag(grid: tuple[tuple[np.ndarray, ...], ...], paired: str | bool) -> bool: - counts = [array.shape[-1] for array in _flatten_grid(grid)] - can_pair = len(set(counts)) == 1 - if paired == "auto": - return can_pair - if isinstance(paired, str): - paired_name = paired.lower() - if paired_name not in {"on", "off"}: - raise ValueError("paired must be 'auto', 'on', 'off', True, or False") - requested = paired_name == "on" - else: - requested = bool(paired) - if requested and not can_pair: - raise ValueError("paired statistics require the same number of cases in every condition") - return requested - - -def _compute_statistic( - grid: tuple[tuple[np.ndarray, ...], ...], - *, - paired: bool, - variance: str, - forceanova: bool, -) -> tuple[Any, Any, str]: - rows = len(grid) - columns = len(grid[0]) - if rows == 1: - if columns == 2 and not forceanova: - if paired: - stat, df = ttest_cell(grid[0][0], grid[0][1]) - else: - stat, df = ttest2_cell(grid[0][0], grid[0][1], variance=variance) - return stat, df, "t" - if paired: - stat, df = anova1rm_cell(grid[0]) - else: - stat, df = anova1_cell(grid[0]) - return stat, df, "f_one_way" - - anova = anova2rm_cell(grid) if paired else anova2_cell(grid) - return anova.as_effects(), anova.df_effects(), "f_two_way" - - -def _parametric_pvalues(stat: Any, df: Any, statistic_kind: str) -> Any: - if statistic_kind == "t": - return 2 * scipy_stats.t.sf(np.abs(stat), df) - if isinstance(stat, TwoWayEffects): - return TwoWayEffects( - scipy_stats.f.sf(stat.rows, df.rows[0], df.rows[1]), - scipy_stats.f.sf(stat.columns, df.columns[0], df.columns[1]), - scipy_stats.f.sf(stat.interaction, df.interaction[0], df.interaction[1]), - ) - return scipy_stats.f.sf(stat, df[0], df[1]) - - -def _compute_surrogate_statistics( - grid: tuple[tuple[np.ndarray, ...], ...], - *, - paired: bool, - method: str, - naccu: int, - variance: str, - forceanova: bool, - rng: np.random.Generator | int | None, -) -> Any: - 
distribution = surrogdistrib( - grid, - method=method, - pairing="on" if paired else "off", - naccu=naccu, - rng=rng, - ) - stats = [] - for sample in distribution: - sample_stat, _sample_df, _kind = _compute_statistic( - sample, - paired=paired, - variance=variance, - forceanova=forceanova, - ) - stats.append(sample_stat) - return _stack_effects(stats) - - -def _stack_effects(values: Sequence[Any]) -> Any: - first = values[0] - if isinstance(first, TwoWayEffects): - return TwoWayEffects( - np.stack([value.rows for value in values], axis=-1), - np.stack([value.columns for value in values], axis=-1), - np.stack([value.interaction for value in values], axis=-1), - ) - return np.stack(values, axis=-1) - - -def _surrogate_pvalues(surrogate: Any, observed: Any, tail: str) -> Any: - if isinstance(surrogate, TwoWayEffects): - return TwoWayEffects( - stat_surrogate_pvals(surrogate.rows, observed.rows, tail), - stat_surrogate_pvals(surrogate.columns, observed.columns, tail), - stat_surrogate_pvals(surrogate.interaction, observed.interaction, tail), - ) - return stat_surrogate_pvals(surrogate, observed, tail) - - -def _surrogate_ci(surrogate: Any, alpha: float, tail: str) -> Any: - if isinstance(surrogate, TwoWayEffects): - return TwoWayEffects( - stat_surrogate_ci(surrogate.rows, alpha, tail), - stat_surrogate_ci(surrogate.columns, alpha, tail), - stat_surrogate_ci(surrogate.interaction, alpha, tail), - ) - return stat_surrogate_ci(surrogate, alpha, tail) - - -def _ci_tail(tail: str) -> str: - tail_name = tail.lower() - if tail_name == "right": - return "upper" - if tail_name == "left": - return "lower" - return tail_name - - -def _effect_map(effect: Any, func: Any) -> Any: - if isinstance(effect, TwoWayEffects): - return TwoWayEffects(func(effect.rows), func(effect.columns), func(effect.interaction)) - return func(effect) - - -def _rng(rng: np.random.Generator | int | None) -> np.random.Generator: - if isinstance(rng, np.random.Generator): - return rng - return 
np.random.default_rng(rng) - - -def _resampled_grid( - grid: tuple[tuple[np.ndarray, ...], ...], - *, - bootstrap: bool, - paired: bool, - rng: np.random.Generator, -) -> tuple[tuple[np.ndarray, ...], ...]: - arrays = _flatten_grid(grid) - feature_shape = arrays[0].shape[:-1] - for index, array in enumerate(arrays): - if array.shape[:-1] != feature_shape: - raise ValueError(f"condition {index} has feature shape {array.shape[:-1]}, expected {feature_shape}") - counts = [array.shape[-1] for array in arrays] - - if paired: - if len(set(counts)) != 1: - raise ValueError("paired surrogate resampling requires equal case counts") - resampled = _paired_resample(arrays, bootstrap=bootstrap, rng=rng) - else: - resampled = _unpaired_resample(arrays, counts, bootstrap=bootstrap, rng=rng) - - iterator = iter(resampled) - return tuple(tuple(next(iterator) for _column in row) for row in grid) - - -def _paired_resample( - arrays: Sequence[np.ndarray], - *, - bootstrap: bool, - rng: np.random.Generator, -) -> list[np.ndarray]: - n_conditions = len(arrays) - n_cases = arrays[0].shape[-1] - output = [np.empty_like(array) for array in arrays] - for case_index in range(n_cases): - source_conditions = ( - rng.integers(0, n_conditions, size=n_conditions) if bootstrap else rng.permutation(n_conditions) - ) - for target_condition, source_condition in enumerate(source_conditions): - output[target_condition][..., case_index] = arrays[int(source_condition)][..., case_index] - return output - - -def _unpaired_resample( - arrays: Sequence[np.ndarray], - counts: Sequence[int], - *, - bootstrap: bool, - rng: np.random.Generator, -) -> list[np.ndarray]: - pooled = np.concatenate(arrays, axis=-1) - total_cases = pooled.shape[-1] - if bootstrap: - indices = rng.integers(0, total_cases, size=total_cases) - else: - indices = rng.permutation(total_cases) - - output = [] - start = 0 - for count in counts: - stop = start + count - output.append(np.take(pooled, indices[start:stop], axis=-1)) - start = 
stop - return output +"""Compatibility re-exports for statistics helpers. + +Implementations live in the same-name public modules. This private facade is +kept so older internal imports from ``statistics._core`` continue to resolve +while the large implementation module is retired. +""" + +from eegprep.functions.statistics._shared import TwoWayAnovaResult, TwoWayEffects +from eegprep.functions.statistics.anova1_cell import anova1_cell +from eegprep.functions.statistics.anova1rm_cell import anova1rm_cell +from eegprep.functions.statistics.anova2_cell import anova2_cell +from eegprep.functions.statistics.anova2rm_cell import anova2rm_cell +from eegprep.functions.statistics.concatdata import ConcatenatedData, concatdata +from eegprep.functions.statistics.corrcoef_cell import corrcoef_cell +from eegprep.functions.statistics.fdr import FDRResult, fdr +from eegprep.functions.statistics.stat_surrogate_ci import stat_surrogate_ci +from eegprep.functions.statistics.stat_surrogate_pvals import stat_surrogate_pvals +from eegprep.functions.statistics.statcond import StatcondResult, statcond +from eegprep.functions.statistics.surrogdistrib import SurrogateDistribution, surrogdistrib +from eegprep.functions.statistics.teststat import teststat +from eegprep.functions.statistics.ttest2_cell import ttest2_cell +from eegprep.functions.statistics.ttest_cell import ttest_cell + +__all__ = [ + "ConcatenatedData", + "FDRResult", + "StatcondResult", + "SurrogateDistribution", + "TwoWayAnovaResult", + "TwoWayEffects", + "anova1_cell", + "anova1rm_cell", + "anova2_cell", + "anova2rm_cell", + "concatdata", + "corrcoef_cell", + "fdr", + "stat_surrogate_ci", + "stat_surrogate_pvals", + "statcond", + "surrogdistrib", + "teststat", + "ttest2_cell", + "ttest_cell", +] diff --git a/src/eegprep/functions/statistics/_shared.py b/src/eegprep/functions/statistics/_shared.py new file mode 100644 index 00000000..cc0facca --- /dev/null +++ b/src/eegprep/functions/statistics/_shared.py @@ -0,0 +1,224 @@ 
+"""Shared implementation helpers for EEGLAB-style statistics functions.""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +import numpy as np + + +@dataclass(frozen=True) +class TwoWayEffects: + """Row, column, and interaction values for a two-way design.""" + + rows: Any + columns: Any + interaction: Any + + def __iter__(self): + yield self.rows + yield self.columns + yield self.interaction + + +@dataclass(frozen=True) +class TwoWayAnovaResult: + """Two-way ANOVA statistics and degrees of freedom.""" + + rows: np.ndarray + columns: np.ndarray + interaction: np.ndarray + df_rows: tuple[int, int] + df_columns: tuple[int, int] + df_interaction: tuple[int, int] + + def as_effects(self) -> TwoWayEffects: + return TwoWayEffects(self.rows, self.columns, self.interaction) + + def df_effects(self) -> TwoWayEffects: + return TwoWayEffects(self.df_rows, self.df_columns, self.df_interaction) + + +def as_numeric_array(value: Any, name: str, *, axis: int = -1, require_axis: bool = True) -> np.ndarray: + array = np.asarray(value) + if not np.issubdtype(array.dtype, np.number): + raise TypeError(f"{name} must be numeric") + if require_axis and array.ndim == 0: + raise ValueError(f"{name} must have at least one dimension") + if np.any(~np.isfinite(array)): + raise ValueError(f"{name} must contain only finite values") + if not require_axis: + return array.astype(np.complex128 if np.iscomplexobj(array) else np.float64, copy=False) + + if not -array.ndim <= axis < array.ndim: + raise ValueError(f"axis {axis} is out of bounds for array with {array.ndim} dimensions") + normalized_axis = axis % array.ndim + if normalized_axis != array.ndim - 1: + array = np.moveaxis(array, normalized_axis, -1) + return array.astype(np.complex128 if np.iscomplexobj(array) else np.float64, copy=False) + + +def condition_grid(data: Any, *, axis: int = -1, min_cases: int = 1) -> tuple[tuple[np.ndarray, ...], ...]: + 
raw_grid = raw_condition_grid(data) + grid: list[tuple[np.ndarray, ...]] = [] + for row_index, row in enumerate(raw_grid): + converted_row = [] + for column_index, value in enumerate(row): + array = as_numeric_array(value, f"condition ({row_index}, {column_index})", axis=axis) + if array.shape[-1] < min_cases: + raise ValueError( + f"condition ({row_index}, {column_index}) must have at least {min_cases} cases on the case axis" + ) + converted_row.append(array) + grid.append(tuple(converted_row)) + + if len(grid) > 1 and len(grid[0]) == 1: + grid = [tuple(row[0] for row in grid)] + return tuple(grid) + + +def raw_condition_grid(data: Any) -> tuple[tuple[Any, ...], ...]: + if isinstance(data, np.ndarray) and data.dtype == object: + if data.ndim == 1: + return (tuple(data.tolist()),) + if data.ndim == 2: + return tuple(tuple(row) for row in data.tolist()) + raise ValueError("object-array condition grids must be one- or two-dimensional") + if not isinstance(data, Sequence) or isinstance(data, str | bytes): + raise TypeError("data must be a sequence of condition arrays") + if len(data) == 0: + raise ValueError("data must contain at least one condition") + + first = data[0] + if looks_like_condition_row(first): + rows = [] + expected_width: int | None = None + for row in data: + if not looks_like_condition_row(row): + raise ValueError("data rows must all be condition sequences") + row_tuple = tuple(row) + if len(row_tuple) == 0: + raise ValueError("condition rows must not be empty") + if expected_width is None: + expected_width = len(row_tuple) + elif len(row_tuple) != expected_width: + raise ValueError("all condition rows must have the same length") + rows.append(row_tuple) + return tuple(rows) + return (tuple(data),) + + +def looks_like_condition_row(value: Any) -> bool: + return isinstance(value, Sequence) and not isinstance(value, np.ndarray | str | bytes) + + +def flatten_grid(grid: tuple[tuple[np.ndarray, ...], ...]) -> list[np.ndarray]: + return [array for row in 
grid for array in row] + + +def two_arrays(a: Any, b: Any | None, caller: str, *, axis: int) -> tuple[np.ndarray, np.ndarray]: + if b is None: + grid = condition_grid(a, axis=axis, min_cases=1) + arrays = flatten_grid(grid) + if len(arrays) != 2: + raise ValueError(f"{caller} requires exactly two arrays") + return arrays[0], arrays[1] + return ( + as_numeric_array(a, "a", axis=axis), + as_numeric_array(b, "b", axis=axis), + ) + + +def one_way_arrays(data: Any, *, axis: int, paired: bool) -> list[np.ndarray]: + grid = condition_grid(data, axis=axis, min_cases=2 if paired else 1) + if len(grid) != 1: + raise ValueError("one-way ANOVA helpers require a one-dimensional condition sequence") + return list(grid[0]) + + +def require_same_shapes(arrays: Sequence[np.ndarray], name: str) -> None: + expected_shape = arrays[0].shape + for index, array in enumerate(arrays): + if array.shape != expected_shape: + raise ValueError(f"{name} requires identical shapes; condition {index} has shape {array.shape}") + + +def two_way_stack(data: Any, *, axis: int, name: str) -> np.ndarray: + grid = condition_grid(data, axis=axis, min_cases=2) + if len(grid) < 1 or len(grid[0]) < 1: + raise ValueError(f"{name} requires a non-empty condition grid") + expected_shape = grid[0][0].shape + for row_index, row in enumerate(grid): + for column_index, array in enumerate(row): + if array.shape != expected_shape: + raise ValueError( + f"{name} requires balanced cell shapes; condition ({row_index}, {column_index}) has " + f"shape {array.shape}, expected {expected_shape}" + ) + return np.stack([np.stack(row, axis=-2) for row in grid], axis=-3) + + +def stat_mean(array: np.ndarray, *, axis: int) -> np.ndarray: + result = np.mean(array, axis=axis) + if np.iscomplexobj(array): + return np.abs(result) + return result + + +def stat_std(array: np.ndarray, *, axis: int) -> np.ndarray: + if np.iscomplexobj(array): + return np.std(np.abs(array), axis=axis, ddof=1) + return np.sqrt(np.sum((array - np.mean(array, 
axis=axis, keepdims=True)) ** 2, axis=axis) / (array.shape[axis] - 1)) + + +def anova_values(array: np.ndarray) -> np.ndarray: + if np.iscomplexobj(array): + return np.abs(array) + return array + + +def sum_square_residuals(array: np.ndarray, mean: np.ndarray, *, axis: int) -> np.ndarray: + values = anova_values(array) + return np.sum((values - np.expand_dims(mean, axis=axis)) ** 2, axis=axis) + + +def normalize_method(method: str) -> str: + method_name = method.lower() + if method_name == "parametric": + return "param" + if method_name == "permutation": + return "perm" + if method_name in {"param", "perm", "bootstrap"}: + return method_name + raise ValueError("method must be 'param', 'perm', 'permutation', 'parametric', or 'bootstrap'") + + +def paired_flag(grid: tuple[tuple[np.ndarray, ...], ...], paired: str | bool) -> bool: + counts = [array.shape[-1] for array in flatten_grid(grid)] + can_pair = len(set(counts)) == 1 + if paired == "auto": + return can_pair + if isinstance(paired, str): + paired_name = paired.lower() + if paired_name not in {"on", "off"}: + raise ValueError("paired must be 'auto', 'on', 'off', True, or False") + requested = paired_name == "on" + else: + requested = bool(paired) + if requested and not can_pair: + raise ValueError("paired statistics require the same number of cases in every condition") + return requested + + +def effect_map(effect: Any, func: Any) -> Any: + if isinstance(effect, TwoWayEffects): + return TwoWayEffects(func(effect.rows), func(effect.columns), func(effect.interaction)) + return func(effect) + + +def rng_from_seed(rng: np.random.Generator | int | None) -> np.random.Generator: + if isinstance(rng, np.random.Generator): + return rng + return np.random.default_rng(rng) diff --git a/src/eegprep/functions/statistics/anova1_cell.py b/src/eegprep/functions/statistics/anova1_cell.py index 167e7457..aa0c3961 100644 --- a/src/eegprep/functions/statistics/anova1_cell.py +++ b/src/eegprep/functions/statistics/anova1_cell.py @@ 
def anova1_cell(data: Any, *, axis: int = -1) -> tuple[np.ndarray, tuple[int, int]]:
    """Compute one-way unpaired ANOVA F-statistics across condition arrays.

    Args:
        data: One-dimensional sequence of condition arrays.
        axis: Axis in each condition array that stores cases.

    Returns:
        The F-values and the ``(between, within)`` degrees of freedom.

    Raises:
        ValueError: If fewer than two conditions are given, the feature shapes
            differ, or a condition has fewer than two cases.
    """
    conditions = one_way_arrays(data, axis=axis, paired=False)
    n_conditions = len(conditions)
    if n_conditions < 2:
        raise ValueError("anova1_cell requires at least two conditions")

    feature_shape = conditions[0].shape[:-1]
    for index, condition in enumerate(conditions):
        if condition.shape[:-1] != feature_shape:
            raise ValueError(f"condition {index} has feature shape {condition.shape[:-1]}, expected {feature_shape}")
        if condition.shape[-1] < 2:
            raise ValueError("anova1_cell requires at least two cases in each condition")

    counts = np.array([condition.shape[-1] for condition in conditions], dtype=float)
    condition_means = [stat_mean(condition, axis=-1) for condition in conditions]
    means = np.stack(condition_means, axis=-1)
    total_n = int(np.sum(counts))

    # Between-condition spread around the case-weighted grand mean.
    grand_mean = np.sum(means * counts, axis=-1) / total_n
    ss_between = np.sum(counts * (means - np.expand_dims(grand_mean, -1)) ** 2, axis=-1)

    # Within-condition spread, accumulated one condition at a time.
    ss_within = np.zeros(feature_shape, dtype=float)
    for condition, condition_mean in zip(conditions, condition_means, strict=True):
        ss_within = ss_within + sum_square_residuals(condition, condition_mean, axis=-1)

    df_between = n_conditions - 1
    df_within = total_n - n_conditions
    # Zero within-condition variance yields inf/nan instead of a warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        f_values = (ss_between / df_between) / (ss_within / df_within)
    return f_values, (df_between, df_within)


def anova1rm_cell(data: Any, *, axis: int = -1) -> tuple[np.ndarray, tuple[int, int]]:
    """Compute one-way repeated-measures ANOVA F-statistics.

    Args:
        data: One-dimensional sequence of identically shaped condition arrays.
        axis: Axis in each condition array that stores the repeated cases.

    Returns:
        The F-values and the ``(condition, error)`` degrees of freedom.

    Raises:
        ValueError: If fewer than two conditions or two cases are given, or
            the condition shapes differ.
    """
    conditions = one_way_arrays(data, axis=axis, paired=True)
    n_conditions = len(conditions)
    if n_conditions < 2:
        raise ValueError("anova1rm_cell requires at least two conditions")
    require_same_shapes(conditions, "anova1rm_cell")
    n_cases = conditions[0].shape[-1]
    if n_cases < 2:
        raise ValueError("anova1rm_cell requires at least two repeated cases")

    # Conditions on axis -2, cases on axis -1.
    scores = anova_values(np.stack(conditions, axis=-2))
    condition_totals = np.sum(scores, axis=-1)
    case_totals = np.sum(scores, axis=-2)
    grand_total = np.sum(condition_totals, axis=-1)

    # Bracket terms of the computational sum-of-squares formulas.
    term_condition = np.sum(condition_totals**2, axis=-1) / n_cases
    term_case = np.sum(case_totals**2, axis=-1) / n_conditions
    term_cells = np.sum(scores**2, axis=(-2, -1))
    term_total = grand_total**2 / (n_conditions * n_cases)

    df_condition = n_conditions - 1
    df_error = (n_conditions - 1) * (n_cases - 1)
    ss_condition = term_condition - term_total
    ss_error = term_cells - term_condition - term_case + term_total
    with np.errstate(divide="ignore", invalid="ignore"):
        f_values = (ss_condition / df_condition) / (ss_error / df_error)
    return f_values, (df_condition, df_error)
def anova2_cell(data: Any, *, axis: int = -1) -> TwoWayAnovaResult:
    """Compute balanced two-way unpaired ANOVA F-statistics.

    Args:
        data: Two-dimensional grid of identically shaped condition arrays.
        axis: Axis in each condition array that stores cases.

    Returns:
        Row, column, and interaction F-values, each paired with its
        ``(effect, error)`` degrees of freedom.

    Raises:
        ValueError: If the grid has fewer than two rows or columns, or a cell
            has fewer than two cases.
    """
    cube = two_way_stack(data, axis=axis, name="anova2_cell")
    n_rows, n_columns, n_cases = cube.shape[-3], cube.shape[-2], cube.shape[-1]
    if n_rows < 2 or n_columns < 2:
        raise ValueError("anova2_cell requires at least two rows and two columns")
    if n_cases < 2:
        raise ValueError("anova2_cell requires at least two cases in each cell")

    scores = anova_values(cube)
    cell_means = stat_mean(cube, axis=-1)

    # Error term: squared deviations of each case from its own cell mean.
    within_cell_ss = np.sum((scores - np.expand_dims(cell_means, -1)) ** 2, axis=-1)
    error_ss = np.sum(within_cell_ss, axis=(-2, -1))

    grand_mean = np.mean(cell_means, axis=(-2, -1))
    row_means = np.mean(cell_means, axis=-1)
    column_means = np.mean(cell_means, axis=-2)

    row_ss = n_columns * n_cases * np.sum((row_means - np.expand_dims(grand_mean, -1)) ** 2, axis=-1)
    column_ss = n_rows * n_cases * np.sum((column_means - np.expand_dims(grand_mean, -1)) ** 2, axis=-1)
    interaction_residual = (
        cell_means
        - np.expand_dims(row_means, -1)
        - np.expand_dims(column_means, -2)
        + np.expand_dims(np.expand_dims(grand_mean, -1), -1)
    )
    interaction_ss = n_cases * np.sum(interaction_residual**2, axis=(-2, -1))

    df_error = n_rows * n_columns * (n_cases - 1)
    df_row_effect = n_rows - 1
    df_column_effect = n_columns - 1
    df_interaction_effect = (n_rows - 1) * (n_columns - 1)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_square_error = error_ss / df_error
        row_f = (row_ss / df_row_effect) / mean_square_error
        column_f = (column_ss / df_column_effect) / mean_square_error
        interaction_f = (interaction_ss / df_interaction_effect) / mean_square_error
    return TwoWayAnovaResult(
        rows=row_f,
        columns=column_f,
        interaction=interaction_f,
        df_rows=(df_row_effect, df_error),
        df_columns=(df_column_effect, df_error),
        df_interaction=(df_interaction_effect, df_error),
    )


def anova2rm_cell(data: Any, *, axis: int = -1) -> TwoWayAnovaResult:
    """Compute two-way repeated-measures ANOVA F-statistics.

    Each effect is tested against its own effect-by-subject error term.

    Args:
        data: Two-dimensional grid of identically shaped condition arrays.
        axis: Axis in each condition array that stores the repeated cases.

    Returns:
        Row, column, and interaction F-values, each paired with its
        ``(effect, error)`` degrees of freedom.

    Raises:
        ValueError: If the grid has fewer than two rows or columns, or fewer
            than two repeated cases.
    """
    cube = two_way_stack(data, axis=axis, name="anova2rm_cell")
    n_rows, n_columns, n_subjects = cube.shape[-3], cube.shape[-2], cube.shape[-1]
    if n_rows < 2 or n_columns < 2:
        raise ValueError("anova2rm_cell requires at least two rows and two columns")
    if n_subjects < 2:
        raise ValueError("anova2rm_cell requires at least two repeated cases")

    scores = anova_values(cube)

    # Marginal totals over every combination of the (row, column, subject) axes.
    cell_totals = np.sum(scores, axis=-1)
    row_subject_totals = np.sum(scores, axis=-2)
    column_subject_totals = np.sum(scores, axis=-3)
    row_totals = np.sum(cell_totals, axis=-1)
    column_totals = np.sum(cell_totals, axis=-2)
    subject_totals = np.sum(row_subject_totals, axis=-2)
    grand_total = np.sum(row_totals, axis=-1)

    # Bracket terms of the computational sum-of-squares formulas.
    term_rows = np.sum(row_totals**2, axis=-1) / (n_columns * n_subjects)
    term_columns = np.sum(column_totals**2, axis=-1) / (n_rows * n_subjects)
    term_cells = np.sum(cell_totals**2, axis=(-2, -1)) / n_subjects
    term_subjects = np.sum(subject_totals**2, axis=-1) / (n_rows * n_columns)
    term_row_subject = np.sum(row_subject_totals**2, axis=(-2, -1)) / n_columns
    term_column_subject = np.sum(column_subject_totals**2, axis=(-2, -1)) / n_rows
    term_scores = np.sum(scores**2, axis=(-3, -2, -1))
    term_total = grand_total**2 / (n_rows * n_columns * n_subjects)

    # The between-subject sum of squares is not needed by any F ratio, so it is not formed.
    ss_rows = term_rows - term_total
    ss_columns = term_columns - term_total
    ss_interaction = term_cells - term_rows - term_columns + term_total
    ss_row_subject = term_row_subject - term_rows - term_subjects + term_total
    ss_column_subject = term_column_subject - term_columns - term_subjects + term_total
    ss_interaction_subject = (
        term_scores
        - term_cells
        - term_row_subject
        - term_column_subject
        + term_rows
        + term_columns
        + term_subjects
        - term_total
    )

    df_row_effect = n_rows - 1
    df_column_effect = n_columns - 1
    df_interaction_effect = (n_rows - 1) * (n_columns - 1)
    df_row_error = (n_rows - 1) * (n_subjects - 1)
    df_column_error = (n_columns - 1) * (n_subjects - 1)
    df_interaction_error = (n_rows - 1) * (n_columns - 1) * (n_subjects - 1)

    with np.errstate(divide="ignore", invalid="ignore"):
        row_f = (ss_rows / df_row_effect) / (ss_row_subject / df_row_error)
        column_f = (ss_columns / df_column_effect) / (ss_column_subject / df_column_error)
        interaction_f = (ss_interaction / df_interaction_effect) / (ss_interaction_subject / df_interaction_error)
    return TwoWayAnovaResult(
        rows=row_f,
        columns=column_f,
        interaction=interaction_f,
        df_rows=(df_row_effect, df_row_error),
        df_columns=(df_column_effect, df_column_error),
        df_interaction=(df_interaction_effect, df_interaction_error),
    )
@dataclass(frozen=True)
class ConcatenatedData:
    """Data concatenated across condition case axes.

    Iterating an instance yields ``data``, ``lengths``, and ``grid_shape`` in
    that order.
    """

    data: np.ndarray
    lengths: np.ndarray
    grid_shape: tuple[int, int]

    def __iter__(self) -> Iterator[np.ndarray | tuple[int, int]]:
        yield from (self.data, self.lengths, self.grid_shape)


def concatdata(data: Any, *, axis: int = -1) -> ConcatenatedData:
    """Concatenate condition arrays along their case axis.

    Args:
        data: One- or two-dimensional sequence of condition arrays.
        axis: Axis in each condition array that stores cases.

    Returns:
        The concatenated array, the cumulative case offsets (starting at 0,
        one entry more than there are conditions), and the grid shape as
        ``(rows, columns)``.

    Raises:
        ValueError: If the conditions do not share one feature shape.
    """

    grid = condition_grid(data, axis=axis, min_cases=1)
    conditions = flatten_grid(grid)
    reference_shape = conditions[0].shape[:-1]
    for index, condition in enumerate(conditions):
        if condition.shape[:-1] != reference_shape:
            raise ValueError(f"condition {index} has feature shape {condition.shape[:-1]}, expected {reference_shape}")

    # Prepending 0 to the counts makes the running sum start at offset 0.
    case_counts = [condition.shape[-1] for condition in conditions]
    offsets = np.cumsum([0, *case_counts], dtype=int)
    merged = np.concatenate(conditions, axis=-1)
    grid_shape = (len(grid), len(grid[0]))
    return ConcatenatedData(merged, offsets, grid_shape)
def corrcoef_cell(a: Any, b: Any | None = None, *, axis: int = -1) -> np.ndarray:
    """Compute pairwise correlations along a case axis.

    Args:
        a: First array, or a two-condition grid when ``b`` is omitted.
        b: Second array, shaped like ``a``.
        axis: Axis that stores cases.

    Returns:
        The correlation of the two inputs for every feature.

    Raises:
        ValueError: If the shapes differ or fewer than two cases are given.
    """

    first, second = two_arrays(a, b, "corrcoef_cell", axis=axis)
    if first.shape != second.shape:
        raise ValueError("corrcoef_cell requires arrays with identical shapes")
    if first.shape[-1] < 2:
        raise ValueError("corrcoef_cell requires at least two cases")

    first_centered = first - np.mean(first, axis=-1, keepdims=True)
    second_centered = second - np.mean(second, axis=-1, keepdims=True)
    covariance = np.sum(first_centered * second_centered, axis=-1)
    first_power = np.sum(first_centered * first_centered, axis=-1)
    second_power = np.sum(second_centered * second_centered, axis=-1)
    # A constant input has zero power; let the ratio become nan without a warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        return covariance / np.sqrt(first_power * second_power)


@dataclass(frozen=True)
class FDRResult:
    """False discovery rate threshold and mask.

    Iterating an instance yields ``threshold`` and then ``mask``.
    """

    threshold: np.ndarray | float
    mask: np.ndarray

    def __iter__(self) -> Iterator[np.ndarray | float]:
        yield self.threshold
        yield self.mask


def _fdr_threshold(sorted_pvals: np.ndarray, q_value: float, correction: float) -> float:
    """Return the largest sorted p-value on or below the step-up line, or 0.0 if none is."""
    count = sorted_pvals.size
    ranks = np.arange(1, count + 1, dtype=float)
    accepted = sorted_pvals <= ranks / count * q_value / correction
    if not np.any(accepted):
        return 0.0
    return float(sorted_pvals[np.flatnonzero(accepted).max()])


def fdr(pvals: Any, q: float | None = None, fdr_type: str = "parametric") -> FDRResult:
    """Compute Benjamini-Hochberg or Benjamini-Yekutieli FDR thresholds.

    Non-finite p-values are ignored when the threshold is computed and are
    always False in the mask.

    Args:
        pvals: Numeric p-value array with values in the closed interval [0, 1].
        q: Desired false discovery rate. If omitted, an EEGLAB-style array of
            corrected thresholds is returned.
        fdr_type: ``"parametric"`` for Benjamini-Hochberg or
            ``"nonparametric"``/``"nonParametric"`` for Benjamini-Yekutieli.

    Returns:
        Threshold and boolean mask with the same shape as ``pvals``.

    Raises:
        TypeError: If ``pvals`` is not numeric.
        ValueError: If a finite p-value or ``q`` lies outside [0, 1], or
            ``fdr_type`` is unknown.
    """

    values = np.asarray(pvals)
    if not np.issubdtype(values.dtype, np.number):
        raise TypeError("pvals must be numeric")
    if values.size == 0:
        return FDRResult(np.array([], dtype=float), np.array([], dtype=bool))
    finite_mask = np.isfinite(values)
    finite_values = values[finite_mask]
    if np.any((finite_values < 0) | (finite_values > 1)):
        raise ValueError("pvals must contain probabilities between 0 and 1")

    if q is not None:
        q_value = float(q)
        if not 0 <= q_value <= 1:
            raise ValueError("q must be between 0 and 1")

    fdr_type_name = fdr_type.lower()
    if fdr_type_name not in {"parametric", "nonparametric"}:
        raise ValueError("fdr_type must be 'parametric' or 'nonparametric'")

    # Sort once; the q=None sweep below reuses it for every candidate rate.
    sorted_pvals = np.sort(finite_values.reshape(-1))
    count = sorted_pvals.size
    if count:
        ranks = np.arange(1, count + 1, dtype=float)
        correction = 1.0 if fdr_type_name == "parametric" else float(np.sum(1.0 / ranks))

    if q is None:
        threshold = np.ones(values.shape, dtype=float)
        if count:
            # Sweep from lenient to strict so each entry ends at the strictest
            # rate at which it is still declared significant.
            for current in np.exp(np.linspace(np.log(0.1), np.log(0.000001), 1000)):
                level = _fdr_threshold(sorted_pvals, float(current), correction)
                threshold[finite_mask & (values <= level)] = current
        return FDRResult(threshold, finite_mask & (values <= threshold))

    if count == 0:
        return FDRResult(0.0, np.zeros(values.shape, dtype=bool))

    threshold_value = _fdr_threshold(sorted_pvals, q_value, correction)
    return FDRResult(threshold_value, finite_mask & (values <= threshold_value))
def stat_surrogate_ci(distribution: Any, alpha: float = 0.05, tail: str = "both") -> np.ndarray:
    """Compute surrogate confidence intervals along the last axis.

    Args:
        distribution: Surrogate statistic array whose last axis stores
            surrogate replications.
        alpha: Type-I error rate. It is split between both ends only for
            ``tail="both"``.
        tail: ``"upper"``, ``"lower"``, ``"one"``, or ``"both"`` (any letter
            case). ``"upper"`` replaces the lower bound with the surrogate
            mean, and ``"lower"`` replaces the upper bound with it.

    Returns:
        Array whose first axis has length two: lower bounds, then upper bounds.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1], ``tail`` is unknown, or
            the surrogate axis is empty.
    """

    # as_numeric_array already rejects zero-dimensional input, so a surrogate
    # axis is guaranteed to exist from here on.
    values = as_numeric_array(distribution, "distribution")
    alpha_value = float(alpha)
    if not 0 <= alpha_value <= 1:
        raise ValueError("alpha must be between 0 and 1")
    tail_name = tail.lower()
    if tail_name not in {"upper", "lower", "one", "both"}:
        raise ValueError("tail must be 'upper', 'lower', 'one', or 'both'")

    n_samples = values.shape[-1]
    if n_samples == 0:
        # Without this guard the index clamps below collapse to -1 and the
        # lookup fails with an opaque IndexError.
        raise ValueError("distribution must contain at least one surrogate sample")

    sorted_values = np.sort(values, axis=-1)
    ci_alpha = alpha_value / 2 if tail_name == "both" else alpha_value
    low = int(np.floor(ci_alpha * n_samples + 0.5))
    high = n_samples - low
    # Clamp both ranks into the valid index range of the surrogate axis.
    low_index = min(max(low, 0), n_samples - 1)
    high_index = min(max(high - 1, 0), n_samples - 1)

    lower = sorted_values[..., low_index]
    upper = sorted_values[..., high_index]
    if tail_name == "upper":
        lower = np.mean(sorted_values, axis=-1)
    elif tail_name == "lower":
        upper = np.mean(sorted_values, axis=-1)

    return np.stack((lower, upper), axis=0)
b/src/eegprep/functions/statistics/stat_surrogate_pvals.py index c8c2c664..ab8a56f3 100644 --- a/src/eegprep/functions/statistics/stat_surrogate_pvals.py +++ b/src/eegprep/functions/statistics/stat_surrogate_pvals.py @@ -1,5 +1,45 @@ """Surrogate empirical p-value helper.""" -from eegprep.functions.statistics._core import stat_surrogate_pvals +from __future__ import annotations + +from typing import Any + +import numpy as np + +from eegprep.functions.statistics._shared import as_numeric_array + + +def stat_surrogate_pvals(distribution: Any, observed: Any, tail: str = "both") -> np.ndarray: + """Compute empirical p-values against a surrogate distribution. + + Args: + distribution: Surrogate statistic array whose last axis stores + surrogate replications. + observed: Observed statistic array matching ``distribution.shape[:-1]``. + tail: ``"right"``/``"upper"``/``"one"``, ``"left"``/``"lower"``, or + ``"both"``. + """ + + surrogates = as_numeric_array(distribution, "distribution") + observed_values = as_numeric_array(observed, "observed", require_axis=False) + if surrogates.ndim < 1: + raise ValueError("distribution must have a surrogate axis") + if observed_values.shape != surrogates.shape[:-1]: + raise ValueError("observed shape must match distribution without its last axis") + + n_samples = surrogates.shape[-1] + expanded = np.expand_dims(observed_values, axis=-1) + p_right = np.sum(surrogates >= expanded, axis=-1) / n_samples + tail_name = tail.lower() + if tail_name in {"right", "upper", "one"}: + return p_right + + p_left = 1 - p_right + np.sum(surrogates == expanded, axis=-1) / n_samples + if tail_name in {"left", "lower"}: + return p_left + if tail_name == "both": + return np.minimum(2 * np.minimum(p_right, p_left), 1.0) + raise ValueError("tail must be 'right', 'upper', 'left', 'lower', 'one', or 'both'") + __all__ = ["stat_surrogate_pvals"] diff --git a/src/eegprep/functions/statistics/statcond.py b/src/eegprep/functions/statistics/statcond.py index 
57649b2e..0a818926 100644 --- a/src/eegprep/functions/statistics/statcond.py +++ b/src/eegprep/functions/statistics/statcond.py @@ -1,5 +1,262 @@ """Condition-level statistical comparison helper.""" -from eegprep.functions.statistics._core import StatcondResult, TwoWayEffects, statcond +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from typing import Any + +import numpy as np +from scipy import stats as scipy_stats + +from eegprep.functions.statistics._shared import ( + TwoWayEffects, + condition_grid, + effect_map, + normalize_method, + paired_flag, +) +from eegprep.functions.statistics.anova1_cell import anova1_cell +from eegprep.functions.statistics.anova1rm_cell import anova1rm_cell +from eegprep.functions.statistics.anova2_cell import anova2_cell +from eegprep.functions.statistics.anova2rm_cell import anova2rm_cell +from eegprep.functions.statistics.stat_surrogate_ci import stat_surrogate_ci +from eegprep.functions.statistics.stat_surrogate_pvals import stat_surrogate_pvals +from eegprep.functions.statistics.surrogdistrib import SurrogateDistribution, surrogdistrib +from eegprep.functions.statistics.ttest2_cell import ttest2_cell +from eegprep.functions.statistics.ttest_cell import ttest_cell + + +@dataclass(frozen=True) +class StatcondResult: + """Result returned by :func:`statcond`.""" + + stat: Any + df: Any + pvalue: Any + surrogate: Any + method: str + paired: bool + ci: Any = None + mask: Any = None + + def __iter__(self) -> Iterator[Any]: + yield self.stat + yield self.df + yield self.pvalue + yield self.surrogate + + +def statcond( + data: Any, + *, + paired: str | bool = "auto", + method: str = "param", + mode: str | None = None, + naccu: int = 200, + variance: str = "homogenous", + forceanova: bool = False, + tail: str = "both", + axis: int = -1, + rng: np.random.Generator | int | None = None, + alpha: float | None = None, + surrog: Any = None, + stats: Any = None, + 
return_resampling_array: bool = False, +) -> StatcondResult | SurrogateDistribution: + """Compare condition arrays using EEGLAB-style t-tests or ANOVAs. + + Args: + data: One- or two-dimensional sequence of condition arrays. The case + dimension is the last axis by default. + paired: ``"auto"``, ``"on"``/``True``, or ``"off"``/``False``. + method: ``"param"``, ``"perm"``, or ``"bootstrap"``. + mode: Legacy alias for ``method``. + naccu: Number of surrogate samples for nonparametric methods. + variance: ``"homogenous"`` or ``"inhomogenous"`` for unpaired t-tests. + forceanova: Use one-way ANOVA instead of a two-condition t-test. + tail: Empirical-tail mode for supplied or computed surrogates. + axis: Axis in each condition array that stores cases. + rng: Optional NumPy generator or seed for deterministic resampling. + alpha: Optional threshold for confidence intervals and masks; requires + a nonparametric method or supplied surrogate statistics. + surrog: Precomputed surrogate statistic array. + stats: Observed statistic to pair with ``surrog``. + return_resampling_array: Return surrogate condition grids instead of + computing statistics. 
+ """ + + method_name = normalize_method(mode or method) + grid = condition_grid(data, axis=axis, min_cases=2) + paired_flag_value = paired_flag(grid, paired) + if return_resampling_array: + if method_name == "param": + raise ValueError("return_resampling_array requires 'perm' or 'bootstrap'") + return surrogdistrib( + grid, + method=method_name, + pairing="on" if paired_flag_value else "off", + naccu=naccu, + rng=rng, + ) + + if surrog is not None: + if stats is None: + raise ValueError("stats must be supplied when surrog is supplied") + observed_stat = stats + observed_df = None + surrogate_stat = surrog + pvalue = _surrogate_pvalues(surrogate_stat, observed_stat, tail) + ci = None + mask = None + if alpha is not None: + ci = _surrogate_ci(surrogate_stat, alpha, _ci_tail(tail)) + mask = effect_map(pvalue, lambda value: value < alpha) + return StatcondResult( + observed_stat, observed_df, pvalue, surrogate_stat, method_name, paired_flag_value, ci=ci, mask=mask + ) + + observed_stat, observed_df, statistic_kind = _compute_statistic( + grid, + paired=paired_flag_value, + variance=variance, + forceanova=forceanova, + ) + surrogate_stat = None + if method_name == "param": + pvalue = _parametric_pvalues(observed_stat, observed_df, statistic_kind) + else: + surrogate_stat = _compute_surrogate_statistics( + grid, + paired=paired_flag_value, + method=method_name, + naccu=naccu, + variance=variance, + forceanova=forceanova, + rng=rng, + ) + empirical_tail = "one" if statistic_kind.startswith("f") else tail + pvalue = _surrogate_pvalues(surrogate_stat, observed_stat, empirical_tail) + + ci = None + mask = None + if alpha is not None: + if surrogate_stat is None: + raise ValueError("alpha confidence intervals require a nonparametric method or supplied surrogates") + empirical_tail = "one" if statistic_kind.startswith("f") else tail + ci = _surrogate_ci(surrogate_stat, alpha, _ci_tail(empirical_tail)) + mask = effect_map(pvalue, lambda value: value < alpha) + + return 
StatcondResult( + observed_stat, observed_df, pvalue, surrogate_stat, method_name, paired_flag_value, ci=ci, mask=mask + ) + + +def _compute_statistic( + grid: tuple[tuple[np.ndarray, ...], ...], + *, + paired: bool, + variance: str, + forceanova: bool, +) -> tuple[Any, Any, str]: + rows = len(grid) + columns = len(grid[0]) + if rows == 1: + if columns == 2 and not forceanova: + if paired: + stat, df = ttest_cell(grid[0][0], grid[0][1]) + else: + stat, df = ttest2_cell(grid[0][0], grid[0][1], variance=variance) + return stat, df, "t" + if paired: + stat, df = anova1rm_cell(grid[0]) + else: + stat, df = anova1_cell(grid[0]) + return stat, df, "f_one_way" + + anova = anova2rm_cell(grid) if paired else anova2_cell(grid) + return anova.as_effects(), anova.df_effects(), "f_two_way" + + +def _parametric_pvalues(stat: Any, df: Any, statistic_kind: str) -> Any: + if statistic_kind == "t": + return 2 * scipy_stats.t.sf(np.abs(stat), df) + if isinstance(stat, TwoWayEffects): + return TwoWayEffects( + scipy_stats.f.sf(stat.rows, df.rows[0], df.rows[1]), + scipy_stats.f.sf(stat.columns, df.columns[0], df.columns[1]), + scipy_stats.f.sf(stat.interaction, df.interaction[0], df.interaction[1]), + ) + return scipy_stats.f.sf(stat, df[0], df[1]) + + +def _compute_surrogate_statistics( + grid: tuple[tuple[np.ndarray, ...], ...], + *, + paired: bool, + method: str, + naccu: int, + variance: str, + forceanova: bool, + rng: np.random.Generator | int | None, +) -> Any: + distribution = surrogdistrib( + grid, + method=method, + pairing="on" if paired else "off", + naccu=naccu, + rng=rng, + ) + stats = [] + for sample in distribution: + sample_stat, _sample_df, _kind = _compute_statistic( + sample, + paired=paired, + variance=variance, + forceanova=forceanova, + ) + stats.append(sample_stat) + return _stack_effects(stats) + + +def _stack_effects(values: Sequence[Any]) -> Any: + first = values[0] + if isinstance(first, TwoWayEffects): + return TwoWayEffects( + np.stack([value.rows for 
value in values], axis=-1), + np.stack([value.columns for value in values], axis=-1), + np.stack([value.interaction for value in values], axis=-1), + ) + return np.stack(values, axis=-1) + + +def _surrogate_pvalues(surrogate: Any, observed: Any, tail: str) -> Any: + if isinstance(surrogate, TwoWayEffects): + return TwoWayEffects( + stat_surrogate_pvals(surrogate.rows, observed.rows, tail), + stat_surrogate_pvals(surrogate.columns, observed.columns, tail), + stat_surrogate_pvals(surrogate.interaction, observed.interaction, tail), + ) + return stat_surrogate_pvals(surrogate, observed, tail) + + +def _surrogate_ci(surrogate: Any, alpha: float, tail: str) -> Any: + if isinstance(surrogate, TwoWayEffects): + return TwoWayEffects( + stat_surrogate_ci(surrogate.rows, alpha, tail), + stat_surrogate_ci(surrogate.columns, alpha, tail), + stat_surrogate_ci(surrogate.interaction, alpha, tail), + ) + return stat_surrogate_ci(surrogate, alpha, tail) + + +def _ci_tail(tail: str) -> str: + tail_name = tail.lower() + if tail_name == "right": + return "upper" + if tail_name == "left": + return "lower" + return tail_name + __all__ = ["StatcondResult", "TwoWayEffects", "statcond"] diff --git a/src/eegprep/functions/statistics/surrogdistrib.py b/src/eegprep/functions/statistics/surrogdistrib.py index 9a652156..93e317ac 100644 --- a/src/eegprep/functions/statistics/surrogdistrib.py +++ b/src/eegprep/functions/statistics/surrogdistrib.py @@ -1,5 +1,133 @@ """Surrogate resampling helper.""" -from eegprep.functions.statistics._core import SurrogateDistribution, surrogdistrib +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from eegprep.functions.statistics._shared import condition_grid, flatten_grid, normalize_method, rng_from_seed + + +@dataclass(frozen=True) +class SurrogateDistribution: + """Surrogate condition grids produced by permutation or bootstrap.""" + + 
samples: tuple[tuple[tuple[np.ndarray, ...], ...], ...] + + def __iter__(self) -> Iterator[tuple[tuple[np.ndarray, ...], ...]]: + return iter(self.samples) + + def __len__(self) -> int: + return len(self.samples) + + +def surrogdistrib( + data: Any, + *, + method: str = "perm", + pairing: str = "on", + naccu: int = 1, + axis: int = -1, + rng: np.random.Generator | int | None = None, +) -> SurrogateDistribution: + """Build bootstrap or permutation surrogate condition grids. + + Args: + data: One- or two-dimensional sequence of condition arrays. + method: ``"perm"``/``"permutation"`` or ``"bootstrap"``. + pairing: ``"on"`` to preserve case identity across conditions or + ``"off"`` to resample from the pooled case axis. + naccu: Number of surrogate grids to generate. + axis: Axis in each condition array that stores cases. + rng: Optional NumPy generator or seed for deterministic resampling. + """ + + method_name = normalize_method(method) + if method_name == "param": + raise ValueError("surrogdistrib only supports permutation or bootstrap methods") + pairing_name = pairing.lower() + if pairing_name not in {"on", "off"}: + raise ValueError("pairing must be 'on' or 'off'") + count = int(naccu) + if count < 1: + raise ValueError("naccu must be at least 1") + + generator = rng_from_seed(rng) + grid = condition_grid(data, axis=axis, min_cases=1) + samples = tuple( + _resampled_grid(grid, bootstrap=method_name == "bootstrap", paired=pairing_name == "on", rng=generator) + for _ in range(count) + ) + return SurrogateDistribution(samples) + + +def _resampled_grid( + grid: tuple[tuple[np.ndarray, ...], ...], + *, + bootstrap: bool, + paired: bool, + rng: np.random.Generator, +) -> tuple[tuple[np.ndarray, ...], ...]: + arrays = flatten_grid(grid) + feature_shape = arrays[0].shape[:-1] + for index, array in enumerate(arrays): + if array.shape[:-1] != feature_shape: + raise ValueError(f"condition {index} has feature shape {array.shape[:-1]}, expected {feature_shape}") + counts = 
[array.shape[-1] for array in arrays] + + if paired: + if len(set(counts)) != 1: + raise ValueError("paired surrogate resampling requires equal case counts") + resampled = _paired_resample(arrays, bootstrap=bootstrap, rng=rng) + else: + resampled = _unpaired_resample(arrays, counts, bootstrap=bootstrap, rng=rng) + + iterator = iter(resampled) + return tuple(tuple(next(iterator) for _column in row) for row in grid) + + +def _paired_resample( + arrays: Sequence[np.ndarray], + *, + bootstrap: bool, + rng: np.random.Generator, +) -> list[np.ndarray]: + n_conditions = len(arrays) + n_cases = arrays[0].shape[-1] + output = [np.empty_like(array) for array in arrays] + for case_index in range(n_cases): + source_conditions = ( + rng.integers(0, n_conditions, size=n_conditions) if bootstrap else rng.permutation(n_conditions) + ) + for target_condition, source_condition in enumerate(source_conditions): + output[target_condition][..., case_index] = arrays[int(source_condition)][..., case_index] + return output + + +def _unpaired_resample( + arrays: Sequence[np.ndarray], + counts: Sequence[int], + *, + bootstrap: bool, + rng: np.random.Generator, +) -> list[np.ndarray]: + pooled = np.concatenate(arrays, axis=-1) + total_cases = pooled.shape[-1] + if bootstrap: + indices = rng.integers(0, total_cases, size=total_cases) + else: + indices = rng.permutation(total_cases) + + output = [] + start = 0 + for count in counts: + stop = start + count + output.append(np.take(pooled, indices[start:stop], axis=-1)) + start = stop + return output + __all__ = ["SurrogateDistribution", "surrogdistrib"] diff --git a/src/eegprep/functions/statistics/teststat.py b/src/eegprep/functions/statistics/teststat.py index 1a012104..bb429f58 100644 --- a/src/eegprep/functions/statistics/teststat.py +++ b/src/eegprep/functions/statistics/teststat.py @@ -1,5 +1,53 @@ """Statistics package smoke-test helper.""" -from eegprep.functions.statistics._core import teststat +from __future__ import annotations + 
+import numpy as np + +from eegprep.functions.statistics._shared import TwoWayEffects +from eegprep.functions.statistics.anova1_cell import anova1_cell +from eegprep.functions.statistics.statcond import StatcondResult, statcond +from eegprep.functions.statistics.ttest_cell import ttest_cell + + +def teststat(seed: int = 0) -> dict[str, float]: + """Run deterministic smoke checks for the EEGPrep statistics package.""" + + rng = np.random.default_rng(seed) + first = rng.normal(size=(3, 12)) + second = first + rng.normal(loc=0.25, scale=0.2, size=(3, 12)) + paired_result = statcond([first, second], paired="on", method="param") + if not isinstance(paired_result, StatcondResult): + raise AssertionError("paired statcond unexpectedly returned surrogate grids") + t_values, df = ttest_cell(first, second) + np.testing.assert_allclose(paired_result.stat, t_values) + if paired_result.df != df: + raise AssertionError("paired t-test degrees of freedom changed") + + groups = [rng.normal(loc=offset, size=(3, 10)) for offset in (0.0, 0.2, 0.5)] + one_way = statcond(groups, paired="off", method="param") + if not isinstance(one_way, StatcondResult): + raise AssertionError("one-way statcond unexpectedly returned surrogate grids") + direct_one_way, one_way_df = anova1_cell(groups) + np.testing.assert_allclose(one_way.stat, direct_one_way) + if one_way.df != one_way_df: + raise AssertionError("one-way ANOVA degrees of freedom changed") + + grid = ( + (rng.normal(size=(2, 9)), rng.normal(loc=0.1, size=(2, 9))), + (rng.normal(loc=0.2, size=(2, 9)), rng.normal(loc=0.4, size=(2, 9))), + ) + two_way = statcond(grid, paired="on", method="param") + if not isinstance(two_way, StatcondResult): + raise AssertionError("two-way statcond unexpectedly returned surrogate grids") + if not isinstance(two_way.stat, TwoWayEffects): + raise AssertionError("two-way statcond did not return factor effects") + + return { + "paired_t_mean": float(np.mean(paired_result.stat)), + "one_way_f_mean": 
float(np.mean(one_way.stat)), + "two_way_interaction_mean": float(np.mean(two_way.stat.interaction)), + } + __all__ = ["teststat"] diff --git a/src/eegprep/functions/statistics/ttest2_cell.py b/src/eegprep/functions/statistics/ttest2_cell.py index a84856e0..53e36c1f 100644 --- a/src/eegprep/functions/statistics/ttest2_cell.py +++ b/src/eegprep/functions/statistics/ttest2_cell.py @@ -1,5 +1,57 @@ """Unpaired t-test helper.""" -from eegprep.functions.statistics._core import ttest2_cell +from __future__ import annotations + +from typing import Any + +import numpy as np + +from eegprep.functions.statistics._shared import stat_mean, stat_std, two_arrays + + +def ttest2_cell( + a: Any, + b: Any | None = None, + variance: str = "homogenous", + *, + axis: int = -1, +) -> tuple[np.ndarray, np.ndarray | int]: + """Compute unpaired t-statistics across the case axis.""" + + if isinstance(b, str): + variance = b + b = None + first, second = two_arrays(a, b, "ttest2_cell", axis=axis) + if first.shape[:-1] != second.shape[:-1]: + raise ValueError("ttest2_cell requires matching feature shapes before the case axis") + if first.shape[-1] < 2 or second.shape[-1] < 2: + raise ValueError("ttest2_cell requires at least two cases in each group") + + variance_name = variance.lower() + if variance_name not in {"homogenous", "inhomogenous"}: + raise ValueError("variance must be 'homogenous' or 'inhomogenous'") + + first_n = first.shape[-1] + second_n = second.shape[-1] + first_mean = stat_mean(first, axis=-1) + second_mean = stat_mean(second, axis=-1) + if variance_name == "inhomogenous": + first_scaled = np.var(first, axis=-1, ddof=1) / first_n + second_scaled = np.var(second, axis=-1, ddof=1) / second_n + standard_error = np.sqrt(first_scaled + second_scaled) + with np.errstate(divide="ignore", invalid="ignore"): + t_values = (first_mean - second_mean) / standard_error + df = (first_scaled + second_scaled) ** 2 / ( + first_scaled**2 / (first_n - 1) + second_scaled**2 / (second_n - 1) + ) 
+ return t_values, df + + first_sd = stat_std(first, axis=-1) + second_sd = stat_std(second, axis=-1) + pooled_sd = np.sqrt(((first_n - 1) * first_sd**2 + (second_n - 1) * second_sd**2) / (first_n + second_n - 2)) + with np.errstate(divide="ignore", invalid="ignore"): + t_values = (first_mean - second_mean) / pooled_sd / np.sqrt(1 / first_n + 1 / second_n) + return t_values, first_n + second_n - 2 + __all__ = ["ttest2_cell"] diff --git a/src/eegprep/functions/statistics/ttest_cell.py b/src/eegprep/functions/statistics/ttest_cell.py index 86467d5f..c68539cf 100644 --- a/src/eegprep/functions/statistics/ttest_cell.py +++ b/src/eegprep/functions/statistics/ttest_cell.py @@ -1,5 +1,30 @@ """Paired t-test helper.""" -from eegprep.functions.statistics._core import ttest_cell +from __future__ import annotations + +from typing import Any + +import numpy as np + +from eegprep.functions.statistics._shared import stat_mean, stat_std, two_arrays + + +def ttest_cell(a: Any, b: Any | None = None, *, axis: int = -1) -> tuple[np.ndarray, int]: + """Compute paired t-statistics across the case axis.""" + + first, second = two_arrays(a, b, "ttest_cell", axis=axis) + if first.shape != second.shape: + raise ValueError("ttest_cell requires paired arrays with identical shapes") + n_cases = first.shape[-1] + if n_cases < 2: + raise ValueError("ttest_cell requires at least two paired cases") + + difference = first - second + mean_difference = stat_mean(difference, axis=-1) + sd_difference = stat_std(difference, axis=-1) + with np.errstate(divide="ignore", invalid="ignore"): + t_values = mean_difference / sd_difference * np.sqrt(n_cases) + return t_values, n_cases - 1 + __all__ = ["ttest_cell"] diff --git a/src/eegprep/functions/studyfunc/_std_measureplot.py b/src/eegprep/functions/studyfunc/_std_measureplot.py index 7c48e13c..8ef19302 100644 --- a/src/eegprep/functions/studyfunc/_std_measureplot.py +++ b/src/eegprep/functions/studyfunc/_std_measureplot.py @@ -8,8 +8,8 @@ import numpy as np 
from eegprep.functions.popfunc._pop_utils import is_on, parse_key_value_args -from eegprep.functions.studyfunc._study_utils import build_python_call -from eegprep.functions.studyfunc.std_readdata import MEASURE_DATA_FIELDS, std_readdata +from eegprep.functions.studyfunc._study_utils import MEASURE_DATA_FIELDS, build_python_call, range_mask +from eegprep.functions.studyfunc.std_readdata import std_readdata LINE_MEASURES = {"erp", "spec"} @@ -68,7 +68,7 @@ def std_measureplot( design=design, ) data, x_axis, y_axis = _apply_ranges(data, datatype, x_axis, y_axis, timerange, freqrange) - figure = None if is_on(noplot) or plotmode == "none" else _plot_measure(data, datatype, x_axis, y_axis) + figure = None if is_on(noplot) or plotmode == "none" else plot_measure_data(data, datatype, x_axis, y_axis) command = _history_command( datatype, channels=channels, @@ -132,53 +132,79 @@ def _axis_subset(axis: np.ndarray, bounds: Any) -> np.ndarray: def _axis_mask(axis: np.ndarray, bounds: Any) -> np.ndarray | None: - if bounds is None or (isinstance(bounds, str) and bounds == ""): - return None - values = np.asarray(bounds, dtype=float).ravel() - if values.size == 0: + mask = range_mask( + axis, + bounds, + name="range options", + empty_message="range option does not include any measure samples", + ) + if mask.size == 0 or np.all(mask): return None - if values.size != 2: - raise ValueError("range options must contain [min max]") - mask = (axis >= values[0]) & (axis <= values[1]) - if not np.any(mask): - raise ValueError("range option does not include any measure samples") return mask -def _plot_measure(data: list[np.ndarray], datatype: str, x_axis: np.ndarray, y_axis: np.ndarray) -> Any: +def plot_measure_data( + data: list[np.ndarray], + datatype: str, + x_axis: np.ndarray, + y_axis: np.ndarray, + *, + title: str | None = None, + line_labels: list[str] | None = None, +) -> Any: + """Plot cached STUDY measure arrays read by ``std_readdata``.""" if datatype in LINE_MEASURES: - 
return _plot_lines(data, datatype, x_axis) - return _plot_image(data, datatype, x_axis, y_axis) + return _plot_lines(data, datatype, x_axis, title=title, line_labels=line_labels) + return _plot_image(data, datatype, x_axis, y_axis, title=title) -def _plot_lines(data: list[np.ndarray], datatype: str, x_axis: np.ndarray) -> Any: +def _plot_lines( + data: list[np.ndarray], + datatype: str, + x_axis: np.ndarray, + *, + title: str | None, + line_labels: list[str] | None, +) -> Any: fig, ax = plt.subplots(figsize=(8, 4.5)) - for label, values in _line_series(data): + for label, values in _line_series(data, line_labels): ax.plot(x_axis, values, label=label) ax.set_xlabel("Time (ms)" if datatype == "erp" else "Frequency (Hz)") ax.set_ylabel("uV" if datatype == "erp" else "Power 10*log10(uV^2/Hz)") - ax.set_title(f"STUDY {datatype.upper()}") + ax.set_title(title or f"STUDY {datatype.upper()}") ax.grid(True, alpha=0.25) ax.legend(fontsize=8) fig.tight_layout() return fig -def _line_series(data: list[np.ndarray]) -> list[tuple[str, np.ndarray]]: +def _line_series(data: list[np.ndarray], line_labels: list[str] | None) -> list[tuple[str, np.ndarray]]: series = [] + labels = iter(line_labels or []) for group_index, values in enumerate(data, start=1): array = np.asarray(values, dtype=float) if array.ndim == 2: - series.append((f"Group {group_index}", np.nanmean(array, axis=0))) + series.append((_next_label(labels, f"Group {group_index}"), np.nanmean(array, axis=0))) elif array.ndim == 3: for component_index, component_values in enumerate(np.nanmean(array, axis=0), start=1): - series.append((f"IC {component_index}", component_values)) + series.append((_next_label(labels, f"IC {component_index}"), component_values)) else: raise ValueError("line measure data must be 2-D or 3-D") return series -def _plot_image(data: list[np.ndarray], datatype: str, x_axis: np.ndarray, y_axis: np.ndarray) -> Any: +def _next_label(labels: Any, fallback: str) -> str: + return next(labels, fallback) + + 
+def _plot_image( + data: list[np.ndarray], + datatype: str, + x_axis: np.ndarray, + y_axis: np.ndarray, + *, + title: str | None, +) -> Any: images = [] for values in data: array = np.asarray(values, dtype=float) @@ -198,7 +224,7 @@ def _plot_image(data: list[np.ndarray], datatype: str, x_axis: np.ndarray, y_axi ) ax.set_xlabel("Time (ms)") ax.set_ylabel("Frequency (Hz)") - ax.set_title(f"STUDY {datatype.upper()}") + ax.set_title(title or f"STUDY {datatype.upper()}") fig.colorbar(mesh, ax=ax) fig.tight_layout() return fig diff --git a/src/eegprep/functions/studyfunc/_study_utils.py b/src/eegprep/functions/studyfunc/_study_utils.py index 2fd4653b..778db4b8 100644 --- a/src/eegprep/functions/studyfunc/_study_utils.py +++ b/src/eegprep/functions/studyfunc/_study_utils.py @@ -62,6 +62,9 @@ "topopol", "dipoles", ) +MEASURE_DATA_FIELDS = {"erp": "erpdata", "spec": "specdata", "ersp": "erspdata", "itc": "itcdata"} +MEASURE_X_AXIS_FIELDS = {"erp": "erptimes", "spec": "specfreqs", "ersp": "ersptimes", "itc": "itctimes"} +MEASURE_Y_AXIS_FIELDS = {"ersp": "erspfreqs", "itc": "itcfreqs"} def as_alleeg_list(ALLEEG: Any) -> list[dict[str, Any]]: @@ -361,17 +364,103 @@ def unique_preserving_order(values: list[int]) -> list[int]: def trialinfo_rows(value: Any) -> list[dict[str, Any]]: """Normalize STUDY trialinfo into a list of row dictionaries.""" + if isinstance(value, dict) and "trialinfo" in value: + value = value.get("trialinfo") if value is None: return [] if isinstance(value, np.ndarray): value = value.tolist() if isinstance(value, dict): - return [value] + lengths = [len(item) for item in value.values() if isinstance(item, (list, tuple, np.ndarray))] + if not lengths: + return [value] + rows = [] + count = max(lengths) + for index in range(count): + row = {} + for key, column in value.items(): + if isinstance(column, np.ndarray): + column = column.tolist() + if isinstance(column, (list, tuple)): + if index < len(column): + row[key] = column[index] + elif not 
_empty_value(column): + row[key] = deepcopy(column) + rows.append(row) + return rows if not isinstance(value, list): return [] return [row for row in value if isinstance(row, dict)] +def cached_measure_axis( + group: dict[str, Any], measureinfo_key: str, fallback_key: str, count: int, *, fallback_start: int = 1 +) -> np.ndarray: + """Return cached STUDY measure-axis IDs with EEGLAB-facing fallbacks.""" + raw_measureinfo = group.get("measureinfo") + measureinfo: dict[str, Any] = raw_measureinfo if isinstance(raw_measureinfo, dict) else {} + values = _as_int_vector(measureinfo.get(measureinfo_key)) + if values.size == count: + return values.astype(int) + values = _as_int_vector(group.get(fallback_key)) + unique_values = np.asarray(unique_preserving_order(values.tolist()), dtype=int) + if unique_values.size == count: + return unique_values + return np.arange(fallback_start, fallback_start + count, dtype=int) + + +def component_measure_axis(group: dict[str, Any], count: int) -> np.ndarray: + """Return cached component IDs for a STUDY component-measure axis.""" + return cached_measure_axis(group, "components", "comps", count, fallback_start=1) + + +def component_dataset_axis(group: dict[str, Any], count: int) -> np.ndarray: + """Return cached STUDY dataset IDs for a component-measure axis.""" + return cached_measure_axis(group, "datasets", "sets", count, fallback_start=1) + + +def axis_position(axis: np.ndarray, value: int, label: str, *, context: str = "Component measure cache") -> int: + """Return the zero-based position of an EEGLAB-facing cached-axis value.""" + matches = np.where(np.asarray(axis) == value)[0] + if matches.size == 0: + raise ValueError(f"{context} is missing {label}") + return int(matches[0]) + + +def range_mask(axis: Any, bounds: Any, *, name: str = "range filter", empty_message: str | None = None) -> np.ndarray: + """Return an inclusive mask for a cached measure axis.""" + axis_values = np.asarray(axis, dtype=float).ravel() + if axis_values.size 
== 0: + return np.asarray([], dtype=bool) + values = parse_numeric_sequence(bounds, dtype=float) + if len(values) == 0: + return np.ones(axis_values.size, dtype=bool) + if len(values) != 2: + raise ValueError(f"{name} must contain [min max]") + mask = (axis_values >= values[0]) & (axis_values <= values[1]) + if not np.any(mask): + raise ValueError(empty_message or f"{name} does not include any cached measure samples") + return mask + + +def value_matches(value: Any, expected: Any) -> bool: + """Return whether a trialinfo value matches a categorical factor level.""" + if isinstance(expected, np.ndarray): + expected = expected.tolist() + if isinstance(expected, (list, tuple, set)): + return any(value_matches(value, item) for item in expected) + return equal_value(value, expected) + + +def equal_value(left: Any, right: Any) -> bool: + """Compare MATLAB-loaded scalar/list values without numpy ambiguity.""" + if isinstance(left, np.ndarray): + left = left.tolist() + if isinstance(right, np.ndarray): + right = right.tolist() + return left == right + + def clear_study_data_fields(study: dict[str, Any]) -> dict[str, Any]: """Remove precomputed measure arrays from STUDY channel/component groups.""" study = deepcopy(study) @@ -424,6 +513,15 @@ def _normalize_optional_number(value: Any) -> Any: return value +def _as_int_vector(value: Any) -> np.ndarray: + if value is None: + return np.asarray([], dtype=int) + array = np.asarray(value, dtype=int) + if array.size == 0: + return np.asarray([], dtype=int) + return array.ravel() + + def _unique_values(values: Any) -> list[Any]: unique: list[Any] = [] seen: set[str] = set() diff --git a/src/eegprep/functions/studyfunc/pop_chanplot.py b/src/eegprep/functions/studyfunc/pop_chanplot.py index a069627c..5e60e275 100644 --- a/src/eegprep/functions/studyfunc/pop_chanplot.py +++ b/src/eegprep/functions/studyfunc/pop_chanplot.py @@ -18,6 +18,8 @@ numeric_vector, python_literal, ) +from eegprep.functions.studyfunc._std_measureplot import 
plot_measure_data +from eegprep.functions.studyfunc._study_utils import MEASURE_DATA_FIELDS from eegprep.functions.studyfunc.std_readdata import component_measure_axis, component_measure_selection, std_readdata @@ -194,10 +196,16 @@ def _plot_channel_measure( ) -> tuple[Any, list[int]]: if _has_cached_channels(study, measure): groups = _cached_channel_groups(study, channels) - if measure in {"erp", "spec"}: - fig = _plot_cached_lines(groups, measure, title=str(study.get("name") or "STUDY channel measures")) - else: - fig = _plot_cached_image(groups, measure, title=str(study.get("name") or "STUDY channel measures")) + _study, data, x_axis, y_axis = std_readdata(study, datasets, datatype=measure, channels=channels) + labels = [str(group.get("name") or "channel") for group in groups] + fig = plot_measure_data( + data, + measure, + x_axis, + y_axis, + title=str(study.get("name") or "STUDY channel measures"), + line_labels=labels, + ) return fig, [_group_channel_index(group, index) for index, group in enumerate(groups, start=1)] if measure != "erp": raise ValueError(f"{measure.upper()} channel measures have not been precomputed") @@ -207,26 +215,20 @@ def _plot_channel_measure( def _plot_component_measure(study: dict[str, Any], components: Any, measure: str) -> tuple[Any, list[int]]: parent = (study.get("cluster") or [{}])[0] - raw_data = np.asarray(parent.get(_field(measure), []), dtype=float) + raw_data = np.asarray(parent.get(MEASURE_DATA_FIELDS[measure], []), dtype=float) component_axis = component_measure_axis(parent, raw_data.shape[1] if raw_data.ndim >= 2 else 0) positions = component_measure_selection(components, component_axis) selected = component_axis[positions].astype(int).tolist() _study, data, x_axis, y_axis = std_readdata(study, datatype=measure, clusters=1, components=components) - values = data[0] - if measure in {"erp", "spec"}: - fig, ax = plt.subplots(figsize=(8, 4.5)) - y_values = np.nanmean(values, axis=0) - for index, component_values in 
zip(selected, y_values): - ax.plot(x_axis, component_values, label=f"IC {index}") - ax.set_xlabel("Time (ms)" if measure == "erp" else "Frequency (Hz)") - ax.set_ylabel("uV" if measure == "erp" else "Power 10*log10(uV^2/Hz)") - ax.set_title(str(study.get("name") or f"STUDY component {measure.upper()}")) - ax.grid(True, alpha=0.25) - ax.legend(fontsize=8) - fig.tight_layout() - return fig, selected - image = np.nanmean(values, axis=(0, 1)) - return _plot_image(image, x_axis, y_axis, str(study.get("name") or f"STUDY component {measure.upper()}")), selected + fig = plot_measure_data( + data, + measure, + x_axis, + y_axis, + title=str(study.get("name") or f"STUDY component {measure.upper()}"), + line_labels=[f"IC {index}" for index in selected], + ) + return fig, selected def _plot_button(label: str, measure: str, tag: str, *, enabled: bool = True) -> ControlSpec: @@ -280,51 +282,16 @@ def _plot_loaded_erp(study: dict[str, Any], datasets: list[dict[str, Any]], chan return fig, (selected + 1).tolist() -def _plot_cached_lines(groups: list[dict[str, Any]], measure: str, *, title: str) -> Any: - fig, ax = plt.subplots(figsize=(8, 4.5)) - x_axis = _axis(groups[0], measure) - for group in groups: - data = _data(group, measure) - ax.plot(x_axis, np.nanmean(data, axis=0), label=str(group.get("name") or "channel")) - ax.set_xlabel("Time (ms)" if measure == "erp" else "Frequency (Hz)") - ax.set_ylabel("uV" if measure == "erp" else "Power 10*log10(uV^2/Hz)") - ax.set_title(title) - ax.grid(True, alpha=0.25) - ax.legend(fontsize=8) - fig.tight_layout() - return fig - - -def _plot_cached_image(groups: list[dict[str, Any]], measure: str, *, title: str) -> Any: - images = [_data(group, measure) for group in groups] - image = np.nanmean(np.stack(images, axis=0), axis=(0, 1)) - return _plot_image(image, _axis(groups[0], measure), _freq_axis(groups[0], measure), title) - - -def _plot_image(image: np.ndarray, x_axis: np.ndarray, y_axis: np.ndarray, title: str) -> Any: - fig, ax = 
plt.subplots(figsize=(7, 4.8)) - mesh = ax.imshow( - image, - aspect="auto", - origin="lower", - extent=[float(x_axis[0]), float(x_axis[-1]), float(y_axis[0]), float(y_axis[-1])], - ) - ax.set_xlabel("Time (ms)") - ax.set_ylabel("Frequency (Hz)") - ax.set_title(title) - fig.colorbar(mesh, ax=ax) - fig.tight_layout() - return fig - - def _has_cached_channels(study: dict[str, Any], measure: str) -> bool: - return any(isinstance(group, dict) and _field(measure) in group for group in study.get("changrp") or []) + return any( + isinstance(group, dict) and MEASURE_DATA_FIELDS[measure] in group for group in study.get("changrp") or [] + ) def _has_cached_components(study: dict[str, Any], measure: str) -> bool: clusters = study.get("cluster") or [] parent = clusters[0] if clusters and isinstance(clusters[0], dict) else {} - return _field(measure) in parent + return MEASURE_DATA_FIELDS[measure] in parent def _cached_channel_groups(study: dict[str, Any], channels: Any) -> list[dict[str, Any]]: @@ -361,26 +328,6 @@ def _all_cached_channels_requested(channels: Any) -> bool: return channels is None or (isinstance(channels, str) and channels in {"", "channels"}) -def _data(group: dict[str, Any], measure: str) -> np.ndarray: - if _field(measure) not in group: - raise ValueError(f"{measure.upper()} channel measures have not been precomputed") - return np.asarray(group[_field(measure)], dtype=float) - - -def _axis(group: dict[str, Any], measure: str) -> np.ndarray: - field = {"erp": "erptimes", "spec": "specfreqs", "ersp": "ersptimes", "itc": "itctimes"}[measure] - return np.asarray(group[field], dtype=float) - - -def _freq_axis(group: dict[str, Any], measure: str) -> np.ndarray: - field = {"ersp": "erspfreqs", "itc": "itcfreqs"}[measure] - return np.asarray(group[field], dtype=float) - - -def _field(measure: str) -> str: - return {"erp": "erpdata", "spec": "specdata", "ersp": "erspdata", "itc": "itcdata"}[measure] - - def _group_channel_index(group: dict[str, Any], fallback: int) 
-> int: inds = group.get("inds") or [fallback] return int(inds[0]) diff --git a/src/eegprep/functions/studyfunc/pop_clust.py b/src/eegprep/functions/studyfunc/pop_clust.py index f9305d91..23c1275f 100644 --- a/src/eegprep/functions/studyfunc/pop_clust.py +++ b/src/eegprep/functions/studyfunc/pop_clust.py @@ -11,6 +11,7 @@ from eegprep.functions.popfunc._plot_utils import numeric_vector from eegprep.functions.studyfunc._cluster_kmeans import kmeans_labels from eegprep.functions.studyfunc._cluster_utils import checked_study_and_datasets, cluster_command +from eegprep.functions.studyfunc.robust_kmeans import robust_kmeans from eegprep.functions.studyfunc.std_createclust import std_createclust @@ -64,16 +65,26 @@ def pop_clust( if clus_num > data.shape[0]: raise ValueError("Number of clusters cannot exceed the number of preclustered components") - labels, centers = kmeans_labels(data, clus_num, random_state) + algorithm_provenance = ["Kmeans", clus_num] if np.isfinite(outliers): if outliers <= 0: raise ValueError("Outlier threshold must be greater than 0") - labels = _mark_outliers(data, labels, centers, outliers) + labels, _centers, _sumd, _distances, _outlier_rows = robust_kmeans( + data, + clus_num, + STD=outliers, + MAXiter=5, + method=algorithm, + random_state=random_state, + ) + algorithm_provenance = ["robust_kmeans", clus_num] + else: + labels, _centers = kmeans_labels(data, clus_num, random_state) study = std_createclust( study, datasets, clusterind=labels, - algorithm=["Kmeans", clus_num], + algorithm=algorithm_provenance, name="Cls", ) command = cluster_command( @@ -130,27 +141,6 @@ def pop_clust_dialog_spec(STUDY: dict[str, Any]) -> DialogSpec: ) -def _mark_outliers(data: np.ndarray, labels: np.ndarray, centers: np.ndarray, threshold: float) -> np.ndarray: - output = labels.copy() - cluster_distances = [] - for label in sorted(set(labels.tolist())): - rows = np.flatnonzero(labels == label) - if rows.size: - cluster_distances.append(np.linalg.norm(data[rows] 
- centers[label - 1], axis=1)) - if not cluster_distances: - return output - reference_distance = float(np.mean([np.mean(distances) for distances in cluster_distances if distances.size])) - for label, distances in zip(sorted(set(labels.tolist())), cluster_distances): - rows = np.flatnonzero(labels == label) - distances = np.linalg.norm(data[rows] - centers[label - 1], axis=1) - spread = float(np.std(distances)) - if spread == 0: - continue - outlier_mask = (distances > spread * threshold) & (distances > reference_distance * threshold) - output[rows[outlier_mask]] = 0 - return output - - def _algorithm_from_gui(value: Any) -> str: try: index = int(value) - 1 diff --git a/src/eegprep/functions/studyfunc/std_builddesignmat.py b/src/eegprep/functions/studyfunc/std_builddesignmat.py index 7eb5b45a..64886ff7 100644 --- a/src/eegprep/functions/studyfunc/std_builddesignmat.py +++ b/src/eegprep/functions/studyfunc/std_builddesignmat.py @@ -6,6 +6,8 @@ import numpy as np +from eegprep.functions.studyfunc._study_utils import trialinfo_rows, value_matches + def std_builddesignmat( design: dict[str, Any], trialinfo: list[dict[str, Any]] | dict[str, Any], expanding: int | bool = False @@ -16,7 +18,7 @@ def std_builddesignmat( 1-based level numbers, continuous factors keep their numeric values, and a constant column is appended at the end. 
""" - rows = _trial_rows(trialinfo) + rows = trialinfo_rows(trialinfo) variables = [variable for variable in design.get("variable") or [] if isinstance(variable, dict)] variables = [variable for variable in variables if str(variable.get("label") or "") != "group"] matrix = np.full((len(rows), len(variables)), np.nan, dtype=float) @@ -51,25 +53,6 @@ def std_builddesignmat( return matrix, labels, categorical -def _trial_rows(trialinfo: list[dict[str, Any]] | dict[str, Any]) -> list[dict[str, Any]]: - if isinstance(trialinfo, list): - return [row if isinstance(row, dict) else {} for row in trialinfo] - if not isinstance(trialinfo, dict): - return [] - lengths = [len(value) for value in trialinfo.values() if isinstance(value, (list, tuple, np.ndarray))] - count = max(lengths) if lengths else 0 - rows = [] - for index in range(count): - row = {} - for key, values in trialinfo.items(): - if isinstance(values, np.ndarray): - values = values.tolist() - if isinstance(values, (list, tuple)) and index < len(values): - row[key] = values[index] - rows.append(row) - return rows - - def _levels(variable: dict[str, Any]) -> list[Any]: values = variable.get("value") if isinstance(values, np.ndarray): @@ -89,11 +72,7 @@ def _matching_level(value: Any, levels: list[Any]) -> int | None: def _level_contains(level: Any, value: Any) -> bool: - if isinstance(level, np.ndarray): - level = level.tolist() - if isinstance(level, (list, tuple, set)): - return any(_level_contains(item, value) for item in level) - return level == value + return value_matches(value, level) def _expand_categorical( diff --git a/src/eegprep/functions/studyfunc/std_clustplot.py b/src/eegprep/functions/studyfunc/std_clustplot.py index e96d5f72..f3667c70 100644 --- a/src/eegprep/functions/studyfunc/std_clustplot.py +++ b/src/eegprep/functions/studyfunc/std_clustplot.py @@ -8,6 +8,7 @@ import numpy as np from eegprep.functions.studyfunc._cluster_utils import checked_study_and_datasets, cluster_at, cluster_list +from 
eegprep.functions.studyfunc._study_utils import build_python_call def std_clustplot( @@ -40,7 +41,7 @@ def std_clustplot( ax.set_title(f"{study.get('name') or 'STUDY'} cluster {measure} summary") ax.grid(axis="y", alpha=0.25) fig.tight_layout() - command = f"fig = std_clustplot(STUDY, ALLEEG, clusters={cluster_indices}, measure={measure!r})" + command = build_python_call(("FIGURE",), "std_clustplot", "STUDY", "ALLEEG", clusters=clusters, measure=measure) return (study, command, fig) if return_com else fig diff --git a/src/eegprep/functions/studyfunc/std_getindvar.py b/src/eegprep/functions/studyfunc/std_getindvar.py index 231befe6..79b7c246 100644 --- a/src/eegprep/functions/studyfunc/std_getindvar.py +++ b/src/eegprep/functions/studyfunc/std_getindvar.py @@ -4,9 +4,13 @@ from typing import Any -import numpy as np - -from eegprep.functions.studyfunc._study_utils import RESERVED_VARIABLE_FIELDS, _empty_value, ensure_study +from eegprep.functions.studyfunc._study_utils import ( + RESERVED_VARIABLE_FIELDS, + _empty_value, + ensure_study, + equal_value, + trialinfo_rows, +) DATASETINFO_EXCLUDE = RESERVED_VARIABLE_FIELDS | {"ncomps", "subject"} @@ -71,13 +75,13 @@ def _append_trial_factors( ) -> None: labels = [] for info in infos: - for row in _trial_rows(info.get("trialinfo")): + for row in trialinfo_rows(info.get("trialinfo")): for label, value in row.items(): if not _empty_value(value) and label not in labels: labels.append(label) for label in labels: values = _unique_values( - row.get(label) for info in infos for row in _trial_rows(info.get("trialinfo")) if label in row + row.get(label) for info in infos for row in trialinfo_rows(info.get("trialinfo")) if label in row ) if len(values) <= 1 or label in factors: continue @@ -87,7 +91,7 @@ def _append_trial_factors( _unique_values( info.get("subject") for info in infos - if any(_equal_value(row.get(label), value) for row in _trial_rows(info.get("trialinfo"))) + if any(_equal_value(row.get(label), value) for row in 
trialinfo_rows(info.get("trialinfo"))) ) for value in values ] @@ -114,18 +118,6 @@ def _append_design_values(study: dict[str, Any], factors: list[str], factorvals: factorvals[index].append(value) -def _trial_rows(value: Any) -> list[dict[str, Any]]: - if value is None: - return [] - if isinstance(value, np.ndarray): - value = value.tolist() - if isinstance(value, dict): - return [value] - if not isinstance(value, list): - return [] - return [row for row in value if isinstance(row, dict)] - - def _unique_values(values: Any) -> list[Any]: output: list[Any] = [] for value in values: @@ -150,11 +142,7 @@ def _paired_subject_sets(value_subjects: list[list[Any]]) -> bool: def _equal_value(left: Any, right: Any) -> bool: - if isinstance(left, np.ndarray): - left = left.tolist() - if isinstance(right, np.ndarray): - right = right.tolist() - return left == right + return equal_value(left, right) __all__ = ["std_getindvar"] diff --git a/src/eegprep/functions/studyfunc/std_gettrialsind.py b/src/eegprep/functions/studyfunc/std_gettrialsind.py index bf600489..266e2bac 100644 --- a/src/eegprep/functions/studyfunc/std_gettrialsind.py +++ b/src/eegprep/functions/studyfunc/std_gettrialsind.py @@ -8,6 +8,7 @@ import numpy as np from eegprep.functions.popfunc._pop_utils import parse_key_value_args +from eegprep.functions.studyfunc._study_utils import equal_value, trialinfo_rows def std_gettrialsind( @@ -24,7 +25,7 @@ def std_gettrialsind( """ if isinstance(trialinfo, (str, Path)): raise ValueError("std_gettrialsind expects loaded trialinfo rows, not a MATLAB filename") - rows = _trial_rows(trialinfo) + rows = trialinfo_rows(trialinfo) queries = parse_key_value_args(args, kwargs, lowercase_kwargs=False) if not queries: indices = list(range(1, len(rows) + 1)) @@ -84,18 +85,6 @@ def _numeric_or_none(value: Any) -> float | None: return float(value) -def _trial_rows(value: Any) -> list[dict[str, Any]]: - if isinstance(value, dict) and "trialinfo" in value: - value = 
value.get("trialinfo") - if isinstance(value, np.ndarray): - value = value.tolist() - if isinstance(value, dict): - return [value] - if not isinstance(value, list): - return [] - return [row for row in value if isinstance(row, dict)] - - def _as_values(value: Any) -> list[Any]: if isinstance(value, np.ndarray): value = value.tolist() @@ -105,11 +94,7 @@ def _as_values(value: Any) -> list[Any]: def _equal(left: Any, right: Any) -> bool: - if isinstance(left, np.ndarray): - left = left.tolist() - if isinstance(right, np.ndarray): - right = right.tolist() - return left == right + return equal_value(left, right) __all__ = ["std_gettrialsind"] diff --git a/src/eegprep/functions/studyfunc/std_limodesign.py b/src/eegprep/functions/studyfunc/std_limodesign.py index 06650130..7caa4d64 100644 --- a/src/eegprep/functions/studyfunc/std_limodesign.py +++ b/src/eegprep/functions/studyfunc/std_limodesign.py @@ -9,7 +9,7 @@ import numpy as np from eegprep.functions.popfunc._pop_utils import is_on, parse_key_value_args -from eegprep.functions.studyfunc._study_utils import build_python_call +from eegprep.functions.studyfunc._study_utils import build_python_call, equal_value, trialinfo_rows, value_matches from eegprep.functions.studyfunc.pop_listfactors import pop_listfactors @@ -33,7 +33,7 @@ def std_limodesign( if options: raise ValueError(f"Unknown std_limodesign option(s): {', '.join(sorted(options))}") - rows = _trial_rows(trialinfo) + rows = trialinfo_rows(trialinfo) factor_rows = _factor_rows(factors) categorical = _categorical_options(factor_rows) continuous = _continuous_options(factor_rows) @@ -75,30 +75,6 @@ def _factor_rows(factors: Any) -> list[dict[str, Any]]: return [factor for factor in factors if isinstance(factor, dict) and str(factor.get("label") or "") != "constant"] -def _trial_rows(value: Any) -> list[dict[str, Any]]: - if isinstance(value, dict) and "trialinfo" in value: - value = value["trialinfo"] - if isinstance(value, np.ndarray): - value = value.tolist() - 
if isinstance(value, dict): - lengths = [len(item) for item in value.values() if isinstance(item, (list, tuple, np.ndarray))] - if not lengths: - return [value] - rows = [] - for index in range(max(lengths)): - row = {} - for key, column in value.items(): - if isinstance(column, np.ndarray): - column = column.tolist() - if isinstance(column, (list, tuple)) and index < len(column): - row[key] = column[index] - rows.append(row) - return rows - if not isinstance(value, list): - return [] - return [row if isinstance(row, dict) else {} for row in value] - - def _categorical_options(factors: list[dict[str, Any]]) -> list[dict[str, Any]]: output: list[dict[str, Any]] = [] for factor in factors: @@ -209,19 +185,11 @@ def _matching_rows(rows: list[dict[str, Any]], option: list[tuple[str, Any]]) -> def _value_matches(value: Any, expected: Any) -> bool: - if isinstance(expected, np.ndarray): - expected = expected.tolist() - if isinstance(expected, (list, tuple, set)): - return any(_value_matches(value, item) for item in expected) - return _equal(value, expected) + return value_matches(value, expected) def _equal(left: Any, right: Any) -> bool: - if isinstance(left, np.ndarray): - left = left.tolist() - if isinstance(right, np.ndarray): - right = right.tolist() - return left == right + return equal_value(left, right) def _save_design_files(path: Path, cat_matrix: np.ndarray, cont_matrix: np.ndarray) -> None: diff --git a/src/eegprep/functions/studyfunc/std_pac.py b/src/eegprep/functions/studyfunc/std_pac.py index 8f6e7bcf..54a7a5d5 100644 --- a/src/eegprep/functions/studyfunc/std_pac.py +++ b/src/eegprep/functions/studyfunc/std_pac.py @@ -15,7 +15,7 @@ numeric_vector, ) from eegprep.functions.popfunc._pop_utils import is_on, parse_key_value_args -from eegprep.functions.studyfunc._study_utils import as_alleeg_list, build_python_call, ensure_study +from eegprep.functions.studyfunc._study_utils import as_alleeg_list, build_python_call, ensure_study, range_mask from 
eegprep.functions.studyfunc.std_checkset import std_checkset from eegprep.functions.studyfunc.std_savedat import std_savedat from eegprep.functions.timefreqfunc.pac import pac @@ -428,16 +428,7 @@ def _apply_ranges( def _range_mask(axis: np.ndarray, bounds: Any, name: str) -> np.ndarray: - values = numeric_vector(bounds, dtype=float) - axis = np.asarray(axis, dtype=float).ravel() - if values.size == 0: - return np.ones(axis.size, dtype=bool) - if values.size != 2: - raise ValueError(f"{name} must contain [min max]") - mask = (axis >= values[0]) & (axis <= values[1]) - if not np.any(mask): - raise ValueError(f"{name} does not include any PAC samples") - return mask + return range_mask(axis, bounds, name=name, empty_message=f"{name} does not include any PAC samples") def _check_axis(existing: Any, candidate: np.ndarray, name: str) -> np.ndarray: diff --git a/src/eegprep/functions/studyfunc/std_preclust.py b/src/eegprep/functions/studyfunc/std_preclust.py index cf1ed965..e9ab4027 100644 --- a/src/eegprep/functions/studyfunc/std_preclust.py +++ b/src/eegprep/functions/studyfunc/std_preclust.py @@ -8,6 +8,12 @@ from eegprep.functions.popfunc._plot_utils import component_maps, python_literal from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence +from eegprep.functions.studyfunc._study_utils import ( + axis_position, + component_dataset_axis, + component_measure_axis, + range_mask, +) from eegprep.functions.studyfunc._cluster_utils import ( checked_study_and_datasets, cluster_command, @@ -183,8 +189,8 @@ def _component_measure_values( data = np.asarray(parent[data_field], dtype=float) if data.ndim < 3: raise ValueError(f"Component measure '{measure}' requires dataset x component data") - dataset_axis = _measure_axis(parent, "datasets", data.shape[0], fallback_start=1) - component_axis = _measure_axis(parent, "components", data.shape[1], fallback_start=1) + dataset_axis = component_dataset_axis(parent, data.shape[0]) + component_axis = 
component_measure_axis(parent, data.shape[1]) dataset_index = _axis_position(dataset_axis, study_set, f"dataset {study_set}", measure) component_index = _axis_position(component_axis, component, f"component {component}", measure) values = np.asarray(data[dataset_index, component_index], dtype=float) @@ -194,44 +200,8 @@ def _component_measure_values( return _slice_measure_axes(values, source, measure, spec) -def _measure_axis(parent: dict[str, Any], key: str, count: int, *, fallback_start: int) -> np.ndarray: - raw_measureinfo = parent.get("measureinfo") - measureinfo: dict[str, Any] = raw_measureinfo if isinstance(raw_measureinfo, dict) else {} - values = _as_int_vector(measureinfo.get(key)) - if values.size == count: - return values - if key == "datasets": - values = _as_int_vector(parent.get("sets")) - else: - values = _as_int_vector(parent.get("comps")) - unique_values = np.asarray(_unique_preserving_order(values.tolist()), dtype=int) - if unique_values.size == count: - return unique_values - return np.arange(fallback_start, fallback_start + count, dtype=int) - - -def _as_int_vector(value: Any) -> np.ndarray: - if value is None: - return np.asarray([], dtype=int) - array = np.asarray(value, dtype=int) - if array.size == 0: - return np.asarray([], dtype=int) - return array.ravel() - - def _axis_position(axis: np.ndarray, value: int, label: str, measure: str) -> int: - matches = np.where(axis == value)[0] - if matches.size == 0: - raise ValueError(f"Component measure '{measure}' is missing {label}") - return int(matches[0]) - - -def _unique_preserving_order(values: list[int]) -> list[int]: - output = [] - for value in values: - if value not in output: - output.append(value) - return output + return axis_position(axis, value, label, context=f"Component measure '{measure}'") def _slice_measure_axes(values: np.ndarray, source: dict[str, Any], measure: str, spec: dict[str, Any]) -> np.ndarray: @@ -259,10 +229,11 @@ def _slice_one_axis(values: np.ndarray, axis_values: 
Any, bounds: Any) -> np.nda flat = values.ravel() if parsed.size == 0 or axis.size != flat.size: return flat - mask = (axis >= parsed[0]) & (axis <= parsed[1]) - if not np.any(mask): - raise ValueError("preclust measure range does not contain any samples") - return flat[mask] + return flat[ + range_mask( + axis, parsed, name="preclust ranges", empty_message="preclust measure range does not contain any samples" + ) + ] def _bounds(value: Any) -> np.ndarray: diff --git a/src/eegprep/functions/studyfunc/std_readdata.py b/src/eegprep/functions/studyfunc/std_readdata.py index 726183d0..753bcb33 100644 --- a/src/eegprep/functions/studyfunc/std_readdata.py +++ b/src/eegprep/functions/studyfunc/std_readdata.py @@ -8,10 +8,16 @@ from eegprep.functions.popfunc._plot_utils import numeric_vector from eegprep.functions.studyfunc._cluster_utils import cluster_list, sets_array -from eegprep.functions.studyfunc._study_utils import ensure_study, unique_preserving_order - - -MEASURE_DATA_FIELDS = {"erp": "erpdata", "spec": "specdata", "ersp": "erspdata", "itc": "itcdata"} +from eegprep.functions.studyfunc._study_utils import ( + MEASURE_DATA_FIELDS, + MEASURE_X_AXIS_FIELDS, + MEASURE_Y_AXIS_FIELDS, + axis_position, + component_dataset_axis as _shared_component_dataset_axis, + component_measure_axis as _shared_component_measure_axis, + ensure_study, + range_mask, +) def std_readdata( @@ -225,27 +231,12 @@ def _component_cluster_index(study: dict[str, Any], clusters: Any) -> int: def component_measure_axis(group: dict[str, Any], count: int) -> np.ndarray: """Return cached component IDs for a STUDY component-measure axis.""" - return _cached_axis_values(group, "components", "comps", count, fallback_start=1) + return _shared_component_measure_axis(group, count) def component_dataset_axis(group: dict[str, Any], count: int) -> np.ndarray: """Return cached STUDY dataset IDs for a component-measure axis.""" - return _cached_axis_values(group, "datasets", "sets", count, fallback_start=1) - - 
-def _cached_axis_values( - group: dict[str, Any], measureinfo_key: str, fallback_key: str, count: int, *, fallback_start: int -) -> np.ndarray: - raw_measureinfo = group.get("measureinfo") - measureinfo: dict[str, Any] = raw_measureinfo if isinstance(raw_measureinfo, dict) else {} - values = numeric_vector(measureinfo.get(measureinfo_key), dtype=int) - if values.size == count: - return values.astype(int) - values = numeric_vector(group.get(fallback_key), dtype=int) - unique_values = np.asarray(unique_preserving_order(values.tolist()), dtype=int) - if unique_values.size == count: - return unique_values - return np.arange(fallback_start, fallback_start + count, dtype=int) + return _shared_component_dataset_axis(group, count) def component_measure_selection(components: Any, axis: np.ndarray) -> np.ndarray: @@ -301,14 +292,19 @@ def _cluster_component_data( def _axis_position(axis: np.ndarray, value: int, label: str) -> int: - matches = np.where(axis == value)[0] - if matches.size == 0: - raise ValueError(f"Component measure cache is missing {label}") - return int(matches[0]) + return axis_position(axis, value, label) def _all_channels_requested(channels: Any) -> bool: - return channels is None or (isinstance(channels, str) and channels in {"", "channels"}) + if channels is None: + return True + if isinstance(channels, str): + return channels in {"", "channels"} + if isinstance(channels, np.ndarray): + return channels.size == 0 + if isinstance(channels, (list, tuple)): + return len(channels) == 0 + return False def _parent_cluster_requested(clusters: Any) -> bool: @@ -383,18 +379,12 @@ def _pac_secondary_positions(group: dict[str, Any], secondary: Any) -> np.ndarra def _range_mask(axis: np.ndarray, bounds: Any) -> np.ndarray: - axis = np.asarray(axis, dtype=float).ravel() - if axis.size == 0: - return np.asarray([], dtype=bool) - values = numeric_vector(bounds, dtype=float) - if values.size == 0: - return np.ones(axis.size, dtype=bool) - if values.size != 2: - raise 
ValueError("range filters must contain [min max]") - mask = (axis >= values[0]) & (axis <= values[1]) - if not np.any(mask): - raise ValueError("range filter does not include any cached measure samples") - return mask + return range_mask( + axis, + bounds, + name="range filters", + empty_message="range filter does not include any cached measure samples", + ) def _subject_filter(data: np.ndarray, study: dict[str, Any], subject: Any) -> np.ndarray: @@ -423,12 +413,11 @@ def _subject_values(subject: Any) -> set[str]: def _x_axis(group: dict[str, Any], measure: str) -> np.ndarray: - field = {"erp": "erptimes", "spec": "specfreqs", "ersp": "ersptimes", "itc": "itctimes"}[measure] - return np.asarray(group.get(field, []), dtype=float) + return np.asarray(group.get(MEASURE_X_AXIS_FIELDS[measure], []), dtype=float) def _y_axis(group: dict[str, Any], measure: str) -> np.ndarray: - field = {"ersp": "erspfreqs", "itc": "itcfreqs"}.get(measure) + field = MEASURE_Y_AXIS_FIELDS.get(measure) if field is None: return np.asarray([], dtype=float) return np.asarray(group.get(field, []), dtype=float) diff --git a/src/eegprep/functions/studyfunc/std_selectdataset.py b/src/eegprep/functions/studyfunc/std_selectdataset.py index a9189e24..a2fbb088 100644 --- a/src/eegprep/functions/studyfunc/std_selectdataset.py +++ b/src/eegprep/functions/studyfunc/std_selectdataset.py @@ -6,7 +6,13 @@ import numpy as np -from eegprep.functions.studyfunc._study_utils import _empty_value, as_alleeg_list, ensure_study, sync_datasetinfo +from eegprep.functions.studyfunc._study_utils import ( + _empty_value, + as_alleeg_list, + ensure_study, + sync_datasetinfo, + trialinfo_rows, +) from eegprep.functions.studyfunc.std_indvarmatch import std_indvarmatch @@ -44,7 +50,7 @@ def std_selectdataset( selected_trials: list[list[int]] = [[] for _info in infos] found_field = False for index, info in enumerate(infos): - rows = _trial_rows(info.get("trialinfo")) + rows = trialinfo_rows(info.get("trialinfo")) if not rows: 
continue if any(label in row for row in rows): @@ -80,7 +86,7 @@ def _dataset_field_present(infos: list[dict[str, Any]], label: str) -> bool: def _all_trials(infos: list[dict[str, Any]], datasets: list[dict[str, Any]]) -> list[list[int]]: selections: list[list[int]] = [] for index, info in enumerate(infos): - rows = _trial_rows(info.get("trialinfo")) + rows = trialinfo_rows(info.get("trialinfo")) if rows: selections.append(list(range(1, len(rows) + 1))) continue @@ -91,18 +97,6 @@ def _all_trials(infos: list[dict[str, Any]], datasets: list[dict[str, Any]]) -> return selections -def _trial_rows(value: Any) -> list[dict[str, Any]]: - if value is None: - return [] - if isinstance(value, np.ndarray): - value = value.tolist() - if isinstance(value, dict): - return [value] - if not isinstance(value, list): - return [] - return [row for row in value if isinstance(row, dict)] - - def _unique(values: list[int]) -> list[int]: output: list[int] = [] for value in values: diff --git a/src/eegprep/functions/timefreqfunc/_bootstrap.py b/src/eegprep/functions/timefreqfunc/_bootstrap.py new file mode 100644 index 00000000..1fbbda93 --- /dev/null +++ b/src/eegprep/functions/timefreqfunc/_bootstrap.py @@ -0,0 +1,125 @@ +"""Shared bootstrap helpers for EEGLAB-style time-frequency functions.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + +from eegprep.functions.miscfunc.value_parsing import parse_numeric_sequence +from eegprep.functions.timefreqfunc.newtimeftrialbaseln import baseline_indices + + +def bootstrap_threshold(surrogates: Any, *, alpha: float = 0.05, bootside: str = "both") -> np.ndarray: + """Return lower/upper or upper-only thresholds from accumulated surrogates.""" + values = np.asarray(surrogates) + if values.ndim < 1: + raise ValueError("surrogates must contain an accumulation axis") + if np.iscomplexobj(values): + values = np.abs(values) + sorted_values = np.sort(values, axis=0) + tail_count = max(1, 
int(round(sorted_values.shape[0] * float(alpha)))) + upper = np.nanmean(sorted_values[-tail_count:, ...], axis=0) + if str(bootside).lower() == "upper": + return np.squeeze(upper) + lower = np.nanmean(sorted_values[:tail_count, ...], axis=0) + return np.stack([lower, upper], axis=-1).squeeze() + + +def thresholds_by_frequency(values: np.ndarray, *, alpha: float, bootside: str) -> np.ndarray: + """Pool surrogate baseline/time samples per frequency before thresholding.""" + nfreq = values.shape[1] + pooled = values.transpose(0, 2, 1).reshape(-1, nfreq) + thresholds = np.asarray(bootstrap_threshold(pooled, alpha=alpha, bootside=bootside)) + if str(bootside).lower() == "both": + return thresholds.reshape(nfreq, 2) + return thresholds.reshape(nfreq) + + +def threshold_vector(thresholds: Any, target_shape: tuple[int, ...]) -> np.ndarray: + """Broadcast scalar or per-frequency thresholds toward a time-frequency result.""" + values = np.asarray(thresholds, dtype=float).squeeze() + if values.ndim == 0: + return np.full(target_shape, float(values)) + if values.ndim == 1: + return values[:, np.newaxis] + return values + + +def bootstrap_indices( + times: np.ndarray, + *, + baseline: Any = None, + baseboot: Any = 1, + baseln: np.ndarray | None = None, + limit_to_baseboot: bool = False, +) -> np.ndarray: + """Return zero-based bootstrap time indices for newtimef/newcrossf paths.""" + values = np.asarray(parse_numeric_sequence(baseboot, dtype=float), dtype=float) + if values.size == 0: + if baseln is not None: + return np.asarray(baseln, dtype=int) + return np.nonzero(np.asarray(times) <= 0)[0] + if values.size == 1: + if values[0] == 0: + return np.asarray([], dtype=int) + if not limit_to_baseboot: + baseline_values = np.asarray(parse_numeric_sequence(baseline, dtype=float), dtype=float) + if baseline_values.size and not np.isnan(baseline_values[0]): + return np.asarray([] if baseln is None else baseln, dtype=int) + upper_bound = values[0] if limit_to_baseboot else 0 + indices 
= np.nonzero(np.asarray(times) <= upper_bound)[0] + return indices if indices.size else np.arange(np.asarray(times).size, dtype=int) + return baseline_indices(times, values) + + +def resample_trials( + values: np.ndarray, + generator: np.random.Generator, + boottype: str, + *, + complex_phase: bool = False, +) -> np.ndarray: + """Resample trial-axis time-frequency arrays for bootstrap/permutation.""" + mode = str(boottype).lower() + sample = np.asarray(values).copy() + if mode in {"shuffle", "shufftrials"}: + trial_indices = generator.integers(0, sample.shape[2], size=sample.shape[2]) + return sample[:, :, trial_indices] + if mode in {"rand", "randall"}: + if complex_phase or np.iscomplexobj(sample): + return sample * np.exp(1j * generator.uniform(0.0, 2.0 * np.pi, size=sample.shape)) + signs = generator.choice(np.asarray([-1.0, 1.0]), size=sample.shape) + return sample * signs + raise ValueError("boottype must be 'shuffle', 'shufftrials', 'rand', or 'randall'") + + +def resample_pair( + first: np.ndarray, second: np.ndarray, generator: np.random.Generator, *, boottype: str +) -> tuple[np.ndarray, np.ndarray]: + """Resample paired time-frequency arrays for cross-frequency bootstrap.""" + mode = str(boottype).lower() + if mode in {"shuffle", "shufftrials"}: + indices = generator.permutation(second.shape[2]) + return first, second[:, :, indices] + if mode in {"rand", "randall"}: + phases = np.exp(1j * generator.uniform(0.0, 2.0 * np.pi, size=first.shape)) + return first * phases, second + raise ValueError("boottype must be 'shuffle', 'shufftrials', 'rand', or 'randall'") + + +def resample_array(array: np.ndarray, rng: np.random.Generator, *, boottype: str, shuffledim: list[int]) -> np.ndarray: + """Resample a generic bootstrap array for bootstat.""" + values = np.asarray(array).copy() + mode = str(boottype).lower() + if mode == "rand": + if np.iscomplexobj(values): + phases = rng.uniform(0.0, 2.0 * np.pi, size=values.shape) + return values * np.exp(1j * phases) + 
signs = rng.choice(np.asarray([-1.0, 1.0]), size=values.shape) + return values * signs + if mode != "shuffle": + raise ValueError("boottype must be 'shuffle' or 'rand'") + for axis in shuffledim: + values = np.take(values, rng.permutation(values.shape[axis]), axis=axis) + return values diff --git a/src/eegprep/functions/timefreqfunc/_pac_support.py b/src/eegprep/functions/timefreqfunc/_pac_support.py index bb450f00..6c4f3cdb 100644 --- a/src/eegprep/functions/timefreqfunc/_pac_support.py +++ b/src/eegprep/functions/timefreqfunc/_pac_support.py @@ -9,7 +9,7 @@ import numpy as np from scipy import signal, stats -from eegprep.functions.popfunc._pop_utils import is_on, parse_numeric_sequence +from eegprep.functions.miscfunc.value_parsing import is_on, parse_numeric_sequence from eegprep.functions.statistics.fdr import fdr from eegprep.functions.timefreqfunc.timefreq import timefreq diff --git a/src/eegprep/functions/timefreqfunc/bootstat.py b/src/eegprep/functions/timefreqfunc/bootstat.py index 4eefb50a..4536ce06 100644 --- a/src/eegprep/functions/timefreqfunc/bootstat.py +++ b/src/eegprep/functions/timefreqfunc/bootstat.py @@ -7,6 +7,8 @@ import numpy as np +from eegprep.functions.timefreqfunc._bootstrap import bootstrap_threshold, resample_array + Statistic = Callable[..., np.ndarray] @@ -43,29 +45,13 @@ def bootstat( dims = _shuffle_dims(shuffledim, selected[0].ndim, boottype=boottype) surrogates = [] for _ in range(int(naccu)): - boot_args = [_resample_array(array, rng, boottype=boottype, shuffledim=dims) for array in selected] + boot_args = [resample_array(array, rng, boottype=boottype, shuffledim=dims) for array in selected] surrogates.append(np.asarray(statistic(*boot_args))) accumulated = np.stack(surrogates, axis=0) thresholds = bootstrap_threshold(accumulated, alpha=alpha, bootside=bootside) return BootstrapResult(thresholds=thresholds, surrogates=accumulated) -def bootstrap_threshold(surrogates: Any, *, alpha: float = 0.05, bootside: str = "both") -> 
np.ndarray: - """Return lower/upper or upper-only thresholds from accumulated surrogates.""" - values = np.asarray(surrogates) - if values.ndim < 1: - raise ValueError("surrogates must contain an accumulation axis") - if np.iscomplexobj(values): - values = np.abs(values) - sorted_values = np.sort(values, axis=0) - tail_count = max(1, int(round(sorted_values.shape[0] * float(alpha)))) - upper = np.nanmean(sorted_values[-tail_count:, ...], axis=0) - if str(bootside).lower() == "upper": - return np.squeeze(upper) - lower = np.nanmean(sorted_values[:tail_count, ...], axis=0) - return np.stack([lower, upper], axis=-1).squeeze() - - def exact_p_values(observed: Any, surrogates: Any, *, center: Any = None) -> np.ndarray: """Return two-sided empirical p-values for observed values.""" observed_values = np.asarray(observed) @@ -109,20 +95,4 @@ def _shuffle_dims(shuffledim: Any, ndim: int, *, boottype: str) -> list[int]: return dims -def _resample_array(array: np.ndarray, rng: np.random.Generator, *, boottype: str, shuffledim: list[int]) -> np.ndarray: - values = np.asarray(array).copy() - mode = str(boottype).lower() - if mode == "rand": - if np.iscomplexobj(values): - phases = rng.uniform(0.0, 2.0 * np.pi, size=values.shape) - return values * np.exp(1j * phases) - signs = rng.choice(np.asarray([-1.0, 1.0]), size=values.shape) - return values * signs - if mode != "shuffle": - raise ValueError("boottype must be 'shuffle' or 'rand'") - for axis in shuffledim: - values = np.take(values, rng.permutation(values.shape[axis]), axis=axis) - return values - - __all__ = ["BootstrapResult", "bootstat", "bootstrap_threshold", "exact_p_values"] diff --git a/src/eegprep/functions/timefreqfunc/newcrossf.py b/src/eegprep/functions/timefreqfunc/newcrossf.py index 1a1f1790..e49809fc 100644 --- a/src/eegprep/functions/timefreqfunc/newcrossf.py +++ b/src/eegprep/functions/timefreqfunc/newcrossf.py @@ -9,12 +9,17 @@ import numpy as np from scipy import stats -from 
eegprep.functions.popfunc._pop_utils import is_on as _is_on -from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence -from eegprep.functions.timefreqfunc.bootstat import bootstrap_threshold, exact_p_values -from eegprep.functions.timefreqfunc.newtimef import _threshold_vector, compute_time_frequency +from eegprep.functions.miscfunc.value_parsing import is_on as _is_on +from eegprep.functions.miscfunc.value_parsing import parse_numeric_sequence +from eegprep.functions.timefreqfunc._bootstrap import ( + bootstrap_indices as shared_bootstrap_indices, + resample_pair, + threshold_vector as _threshold_vector, + thresholds_by_frequency, +) +from eegprep.functions.timefreqfunc.bootstat import exact_p_values +from eegprep.functions.timefreqfunc.newtimef import compute_time_frequency from eegprep.functions.timefreqfunc.newtimefitc import newtimefitc -from eegprep.functions.timefreqfunc.newtimeftrialbaseln import baseline_indices @dataclass(frozen=True) @@ -218,34 +223,18 @@ def _bootstrap_coherence( source_y = tf_y[:, base_indices, :] if base_indices.size else tf_y threshold_source = np.empty((int(naccu), tf_x.shape[0], max(1, source_x.shape[1])), dtype=float) for index in range(int(naccu)): - sample_x, sample_y = _resample_pair(tf_x, tf_y, generator, boottype=boottype) + sample_x, sample_y = resample_pair(tf_x, tf_y, generator, boottype=boottype) coher, _allcoher, _lagmap = _coherence(sample_x, sample_y, mode=mode, amplag=np.asarray([0]), alpha=None) surrogates[index] = np.abs(coher) - boot_x, boot_y = _resample_pair(source_x, source_y, generator, boottype=boottype) + boot_x, boot_y = resample_pair(source_x, source_y, generator, boottype=boottype) boot_coher, _allcoher, _lagmap = _coherence(boot_x, boot_y, mode=mode, amplag=np.asarray([0]), alpha=None) threshold_source[index] = np.abs(boot_coher) thresholds = _upper_thresholds_by_frequency(threshold_source, alpha=alpha) return thresholds, surrogates -def _resample_pair( - tf_x: np.ndarray, tf_y: np.ndarray, 
generator: np.random.Generator, *, boottype: str -) -> tuple[np.ndarray, np.ndarray]: - mode = str(boottype).lower() - if mode in {"shuffle", "shufftrials"}: - indices = generator.permutation(tf_y.shape[2]) - return tf_x, tf_y[:, :, indices] - if mode in {"rand", "randall"}: - phases = np.exp(1j * generator.uniform(0.0, 2.0 * np.pi, size=tf_x.shape)) - return tf_x * phases, tf_y - raise ValueError("boottype must be 'shuffle', 'shufftrials', 'rand', or 'randall'") - - def _upper_thresholds_by_frequency(values: np.ndarray, *, alpha: float) -> np.ndarray: - nfreq = values.shape[1] - pooled = values.transpose(0, 2, 1).reshape(-1, nfreq) - thresholds = np.asarray(bootstrap_threshold(pooled, alpha=alpha, bootside="upper")) - return thresholds.reshape(nfreq) + return thresholds_by_frequency(values, alpha=alpha, bootside="upper") def _shuffle_trials(tf_y: np.ndarray, count: int, rng: Any) -> np.ndarray: @@ -263,15 +252,7 @@ def _remove_itc(tfdata: np.ndarray) -> np.ndarray: def _bootstrap_indices(times: np.ndarray, baseboot: Any) -> np.ndarray: - values = _numeric_vector(baseboot) - if values.size == 0: - return np.nonzero(times <= 0)[0] - if values.size == 1: - if values[0] == 0: - return np.asarray([], dtype=int) - indices = np.nonzero(times <= values[0])[0] - return indices if indices.size else np.arange(times.size, dtype=int) - return baseline_indices(times, values) + return shared_bootstrap_indices(times, baseboot=baseboot, limit_to_baseboot=True) def _plot_cross_frequency( diff --git a/src/eegprep/functions/timefreqfunc/newtimef.py b/src/eegprep/functions/timefreqfunc/newtimef.py index 6e90c2a9..479aae0e 100644 --- a/src/eegprep/functions/timefreqfunc/newtimef.py +++ b/src/eegprep/functions/timefreqfunc/newtimef.py @@ -8,13 +8,20 @@ import matplotlib.pyplot as plt import numpy as np -from eegprep.functions.popfunc._pop_utils import is_on as _is_on -from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence +from eegprep.functions.miscfunc.value_parsing 
import is_empty_value as _is_empty_value +from eegprep.functions.miscfunc.value_parsing import is_on as _is_on +from eegprep.functions.miscfunc.value_parsing import parse_numeric_sequence from eegprep.functions.statistics.fdr import fdr -from eegprep.functions.timefreqfunc.bootstat import bootstrap_threshold, exact_p_values +from eegprep.functions.timefreqfunc._bootstrap import ( + bootstrap_indices as shared_bootstrap_indices, + resample_trials, + threshold_vector as _threshold_vector, + thresholds_by_frequency, +) +from eegprep.functions.timefreqfunc.bootstat import exact_p_values from eegprep.functions.timefreqfunc.newtimefbaseln import newtimefbaseln from eegprep.functions.timefreqfunc.newtimefitc import newtimefitc -from eegprep.functions.timefreqfunc.newtimeftrialbaseln import baseline_indices, newtimeftrialbaseln +from eegprep.functions.timefreqfunc.newtimeftrialbaseln import newtimeftrialbaseln from eegprep.functions.timefreqfunc.timefreq import timefreq @@ -376,14 +383,6 @@ def _validate_vertical_markers(markers: np.ndarray, tlimits: Any) -> None: raise ValueError("vertical line ('vert') latency outside of epoch boundaries") -def _is_empty_value(value: Any) -> bool: - if value is None: - return True - if isinstance(value, str): - return not value.strip() or value.strip() in {"[]", "{}"} - return np.asarray(value).size == 0 - - def _split_timesout(timesout: Any) -> tuple[Any, Any]: values = _numeric_vector(timesout) if values.size == 0: @@ -428,18 +427,7 @@ def _boot_array(value: Any) -> np.ndarray | None: def _bootstrap_indices(times: np.ndarray, baseline: Any, baseboot: Any, baseln: np.ndarray) -> np.ndarray: - values = _numeric_vector(baseboot) - if values.size == 0: - return baseln - if values.size == 1: - if values[0] == 0: - return np.asarray([], dtype=int) - baseline_values = _numeric_vector(baseline) - if baseline_values.size and not np.isnan(baseline_values[0]): - return baseln - indices = np.nonzero(times <= 0)[0] - return indices if indices.size 
else np.arange(times.size, dtype=int) - return baseline_indices(times, values) + return shared_bootstrap_indices(times, baseline=baseline, baseboot=baseboot, baseln=baseln) def _bootstrap_power( @@ -457,9 +445,9 @@ def _bootstrap_power( boot_source = power[:, base_indices, :] if base_indices.size else power threshold_source = np.empty((int(naccu), power.shape[0], max(1, boot_source.shape[1])), dtype=float) for index in range(int(naccu)): - sample = _resample_trials(power, generator, boottype) + sample = resample_trials(power, generator, boottype) surrogates[index] = _power_to_output(np.nanmean(sample, axis=2), scale) - threshold_sample = _resample_trials(boot_source, generator, boottype) + threshold_sample = resample_trials(boot_source, generator, boottype) threshold_source[index] = _power_to_output(np.nanmean(threshold_sample, axis=2), scale) thresholds = _thresholds_by_frequency(threshold_source, alpha=alpha, both=True) return thresholds, surrogates @@ -480,40 +468,16 @@ def _bootstrap_itc( boot_source = tfdata[:, base_indices, :] if base_indices.size else tfdata threshold_source = np.empty((int(naccu), tfdata.shape[0], max(1, boot_source.shape[1])), dtype=float) for index in range(int(naccu)): - sample = _resample_trials(tfdata, generator, boottype, complex_phase=True) + sample = resample_trials(tfdata, generator, boottype, complex_phase=True) surrogates[index] = np.abs(newtimefitc(sample, itctype)) - threshold_sample = _resample_trials(boot_source, generator, boottype, complex_phase=True) + threshold_sample = resample_trials(boot_source, generator, boottype, complex_phase=True) threshold_source[index] = np.abs(newtimefitc(threshold_sample, itctype)) thresholds = _thresholds_by_frequency(threshold_source, alpha=alpha, both=False) return thresholds, surrogates -def _resample_trials( - values: np.ndarray, - generator: np.random.Generator, - boottype: str, - *, - complex_phase: bool = False, -) -> np.ndarray: - mode = str(boottype).lower() - sample = 
np.asarray(values).copy() - if mode in {"shuffle", "shufftrials"}: - trial_indices = generator.integers(0, sample.shape[2], size=sample.shape[2]) - return sample[:, :, trial_indices] - if mode in {"rand", "randall"}: - if complex_phase or np.iscomplexobj(sample): - return sample * np.exp(1j * generator.uniform(0.0, 2.0 * np.pi, size=sample.shape)) - signs = generator.choice(np.asarray([-1.0, 1.0]), size=sample.shape) - return sample * signs - raise ValueError("boottype must be 'shuffle', 'shufftrials', 'rand', or 'randall'") - - def _thresholds_by_frequency(values: np.ndarray, *, alpha: float, both: bool) -> np.ndarray: - nfreq = values.shape[1] - pooled = values.transpose(0, 2, 1).reshape(-1, nfreq) - bootside = "both" if both else "upper" - thresholds = np.asarray(bootstrap_threshold(pooled, alpha=alpha, bootside=bootside)) - return thresholds.reshape(nfreq, 2) if both else thresholds.reshape(nfreq) + return thresholds_by_frequency(values, alpha=alpha, bootside="both" if both else "upper") def _significance_mask(pvalues: np.ndarray, alpha: float, correction: str) -> np.ndarray: @@ -537,15 +501,6 @@ def _threshold_mask(values: np.ndarray, thresholds: np.ndarray) -> np.ndarray: return (values <= lower) | (values >= upper) -def _threshold_vector(thresholds: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray: - values = np.asarray(thresholds, dtype=float).squeeze() - if values.ndim == 0: - return np.full(target_shape, float(values)) - if values.ndim == 1: - return values[:, np.newaxis] - return values - - def _plot_time_frequency( ersp: np.ndarray, itc: np.ndarray, diff --git a/src/eegprep/functions/timefreqfunc/tf_cycle_calc.py b/src/eegprep/functions/timefreqfunc/tf_cycle_calc.py index 0ec37608..2b4fb336 100644 --- a/src/eegprep/functions/timefreqfunc/tf_cycle_calc.py +++ b/src/eegprep/functions/timefreqfunc/tf_cycle_calc.py @@ -8,6 +8,8 @@ import matplotlib.pyplot as plt import numpy as np +from eegprep.functions.miscfunc.value_parsing import 
parse_numeric_sequence + SIGMA_TO_FWHM = 2.0 * np.sqrt(2.0 * np.log(2.0)) WIDTH_COLUMNS = ("freq", "cycles", "fwhm_f", "fwhm_t", "2_sigma_f", "2_sigma_t", "sigma_f", "sigma_t") WIDTH_UNITS = ("fwhm_t", "fwhm_f", "2_sigma_t", "2_sigma_f", "sigma_t", "sigma_f", "cycles") @@ -106,41 +108,9 @@ def _width_vector(value: Any, count: int, log_spaced: bool) -> np.ndarray: def _numeric_vector(value: Any) -> np.ndarray: if value is None: return np.asarray([], dtype=float) - if isinstance(value, np.ndarray): - return value.astype(float).ravel() - if isinstance(value, (int, float, np.integer, np.floating)): - return np.asarray([value], dtype=float) - if isinstance(value, str): - text = value.strip().strip("[]") - if not text: - return np.asarray([], dtype=float) - values: list[float] = [] - for token in text.replace(",", " ").split(): - if ":" in token: - values.extend(_colon_sequence(token)) - else: - values.append(float(token)) - return np.asarray(values, dtype=float) - if isinstance(value, (list, tuple)): - return np.asarray(value, dtype=float).ravel() - return np.asarray([value], dtype=float) - - -def _colon_sequence(token: str) -> list[float]: - pieces = token.split(":") - if len(pieces) not in {2, 3}: - raise ValueError(f"Invalid colon range: {token}") - start = float(pieces[0]) - if len(pieces) == 2: - stop = float(pieces[1]) - step = 1.0 if stop >= start else -1.0 - else: - step = float(pieces[1]) - stop = float(pieces[2]) - if step == 0 or (stop - start) * step < 0: - return [] - count = int(np.floor((stop - start) / step + 1e-9)) + 1 - return [float(start + index * step) for index in range(max(count, 0))] + if isinstance(value, str) and value.strip() == "": + return np.asarray([], dtype=float) + return np.asarray(parse_numeric_sequence(value, dtype=float), dtype=float).ravel() def _time_frequency_demo_transform( diff --git a/src/eegprep/functions/timefreqfunc/timefreq.py b/src/eegprep/functions/timefreqfunc/timefreq.py index 8f4dcbfd..ffb22f72 100644 --- 
a/src/eegprep/functions/timefreqfunc/timefreq.py +++ b/src/eegprep/functions/timefreqfunc/timefreq.py @@ -8,6 +8,7 @@ import numpy as np from scipy import signal +from eegprep.functions.miscfunc.value_parsing import parse_numeric_sequence from eegprep.functions.timefreqfunc.dftfilt2 import dftfilt2 from eegprep.functions.timefreqfunc.dftfilt3 import dftfilt3, symmetric_hanning from eegprep.functions.timefreqfunc.angtimewarp import angtimewarp @@ -475,18 +476,9 @@ def _subtract_itc(tfdata: np.ndarray, itcvals: np.ndarray | None) -> np.ndarray: def _numeric_vector(value: Any, *, dtype: Any = float) -> np.ndarray: if value is None: return np.asarray([], dtype=dtype) - if isinstance(value, np.ndarray): - return value.astype(dtype).ravel() - if isinstance(value, (int, float, np.integer, np.floating)): - return np.asarray([value], dtype=dtype) - if isinstance(value, str): - text = value.strip().strip("[]") - if not text: - return np.asarray([], dtype=dtype) - return np.asarray([float(token) for token in text.replace(",", " ").split()], dtype=dtype) - if isinstance(value, (list, tuple)): - return np.asarray(value, dtype=dtype).ravel() - return np.asarray([value], dtype=dtype) + if isinstance(value, str) and value.strip() == "": + return np.asarray([], dtype=dtype) + return np.asarray(parse_numeric_sequence(value, dtype=dtype), dtype=dtype).ravel() def _first_numeric(value: Any, default: float) -> float: diff --git a/src/eegprep/plugins/EEG_BIDS/montage.py b/src/eegprep/plugins/EEG_BIDS/montage.py new file mode 100644 index 00000000..0acd66ee --- /dev/null +++ b/src/eegprep/plugins/EEG_BIDS/montage.py @@ -0,0 +1,211 @@ +"""Montage inference helpers for EEG-BIDS imports.""" + +from __future__ import annotations + +from collections.abc import Callable, MutableMapping +from importlib.resources import files +import os +from typing import Any + +import numpy as np +from scipy.io.matlab import loadmat + +from eegprep.plugins.EEG_BIDS.coords import ( + clear_chanloc, + 
coords_ALS_to_angular, + coords_RAS_to_ALS, + coords_any_to_RAS, + coords_to_mm, +) + + +def apply_montage_inference( + EEG: MutableMapping[str, Any], + infer_locations: bool | str, + *, + numeric_null: Any, + report: MutableMapping[str, Any], + warning: Callable[[str], None], + error: Callable[[str], None], +) -> None: + """Infer channel locations from packaged or user-specified montage files.""" + if not infer_locations: + _assign_labelscheme_from_existing_labels(EEG) + return + + EEG["chaninfo"]["nosedir"] = "+X" + datalabels = _normalized_data_labels(EEG["chanlocs"]) + montage_path, filenames = _montage_candidates(infer_locations) + opt_score, best_data, best_cap, fractions = _select_best_montage(montage_path, filenames, datalabels) + best_fraction = opt_score[0] + + if best_data is None: + if isinstance(infer_locations, str): + raise RuntimeError( + f"The channel labels in your data do not match the specified montage file ({infer_locations})." + ) + raise RuntimeError("Channel labels do not match any known or specified montage.") + + skip_locations = _should_skip_or_warn( + best_fraction, + fractions, + datalabels, + warning=warning, + error=error, + ) + if skip_locations: + return + + report["ChanlocsFrom"] = os.path.basename(best_cap) + if "10-5" in best_cap: + labeling = "10-20" + else: + labeling, _ext = os.path.splitext(os.path.basename(best_cap)) + EEG["etc"]["labelscheme"] = labeling + _apply_montage_coordinates(EEG, best_data, datalabels, numeric_null) + + +def _strip_matching_quotes(name: str) -> str: + while len(name) >= 2 and name[0] == name[-1] and name[0] in ("'", '"'): + name = name[1:-1] + return name + + +def _normalized_data_labels(chanlocs: list[dict[str, Any]]) -> list[str]: + datalabels = [chanloc["labels"].lower() for chanloc in chanlocs] + for prefix in ["brainvision rda_", "rda_", "eeg ", "eeg-", "eeg"]: + datalabels = [label.replace(prefix, "") for label in datalabels] + datalabels = [label.split("-")[0] for label in datalabels] + 
return [_strip_matching_quotes(label) for label in datalabels] + + +def _montage_candidates(infer_locations: bool | str) -> tuple[str, list[str]]: + montage_path = str(files("eegprep").joinpath("resources").joinpath("montages")) + if not os.path.isdir(montage_path): + raise RuntimeError( + f"Could not find montages directory at {montage_path}. This may indicate a corrupted installation." + ) + + if isinstance(infer_locations, str): + if os.path.isabs(infer_locations): + return os.path.dirname(infer_locations), [os.path.basename(infer_locations)] + return montage_path, [infer_locations] + return montage_path, sorted(os.listdir(montage_path)) + + +def _select_best_montage( + montage_path: str, filenames: list[str], datalabels: list[str] +) -> tuple[tuple[float, int, float], Any | None, str, list[float]]: + opt_score: tuple[float, int, float] = (0, 0, 0) + best_data = None + best_cap = "(not set)" + fractions: list[float] = [] + for filename in filenames: + if not filename.endswith(".locs"): + continue + try: + data = loadmat(os.path.join(montage_path, filename), squeeze_me=True) + except Exception as exc: + raise ValueError( + f"Failed to load montage file {filename}. Make sure it is a valid .locs file (MATLAB v7 .mat format)." 
+ ) from exc + caplabels = [label.lower() for label in data["labels"]] + fraction_in_data = np.mean([label in caplabels for label in datalabels]) + fraction_in_locfile = np.mean([label in datalabels for label in caplabels]) + bonus1020 = 1 if {"c3", "cz", "fcz", "c4"}.issubset(caplabels) else 0 + score = (fraction_in_data, bonus1020, fraction_in_locfile) + if score > opt_score: + opt_score = score + best_data = data + best_cap = filename + fractions.append(fraction_in_data) + return opt_score, best_data, best_cap, sorted(fractions, reverse=True) + + +def _should_skip_or_warn( + best_fraction: float, + fractions: list[float], + datalabels: list[str], + *, + warning: Callable[[str], None], + error: Callable[[str], None], +) -> bool: + percent_found = int(100 * best_fraction) + if best_fraction < 0.25: + error( + "The given data has a very poor match to all " + "known montages (%s percent of channels found); " + "not assigning locations (got: %s)" % (percent_found, datalabels) + ) + return True + if best_fraction < 0.5: + if len(fractions) > 1 and best_fraction / 1.5 < fractions[1]: + warning( + "The given data has a poor match and multiple " + "montages are partially matching potentially " + "ambiguously (%s percent of channels found); " + "please double-check assigned locations." % percent_found + ) + else: + warning( + "The given data has a poor match to all known " + "montages (%s percent of channels found); please " + "double-check assigned locations." % percent_found + ) + elif best_fraction < 0.75 and len(fractions) > 1 and best_fraction / 1.5 < fractions[1]: + warning( + "The given data has a reasonable match to known " + "montages but multiple montages are potentially " + "matching (%s percent of channels found); " + "locations may be wrong." 
% percent_found + ) + elif best_fraction < 1.0: + warning( + "Not all channel locations could be matched to a " + "known montage; some channels may be non-EEG " + "channels ({} percent of channels found).".format(percent_found) + ) + return False + + +def _apply_montage_coordinates( + EEG: MutableMapping[str, Any], + best_data: MutableMapping[str, Any], + datalabels: list[str], + numeric_null: Any, +) -> None: + unit = best_data["meta"]["unit"][()] + x = best_data["meta"]["x"][()] + y = best_data["meta"]["y"][()] + z = best_data["meta"]["z"][()] + coords = best_data["coordinates"] + coords = coords_to_mm(coords, unit) + coords = coords_any_to_RAS(coords, x, y, z) + coords = coords_RAS_to_ALS(coords) + sph_theta, sph_phi, sph_radius, polar_theta, polar_radius = coords_ALS_to_angular(coords) + caplabels = [label.lower() for label in best_data["labels"]] + for data_index, data_label in enumerate(datalabels): + record = EEG["chanlocs"][data_index] + for cap_index, cap_label in enumerate(caplabels): + if data_label != cap_label: + continue + xyz = coords[cap_index, :] + record["X"] = xyz[0] + record["Y"] = xyz[1] + record["Z"] = xyz[2] + record["sph_radius"] = sph_radius[cap_index] + record["sph_theta"] = sph_theta[cap_index] + record["sph_phi"] = sph_phi[cap_index] + record["theta"] = polar_theta[cap_index] + record["radius"] = polar_radius[cap_index] + break + else: + clear_chanloc(record, numeric_null) + + +def _assign_labelscheme_from_existing_labels(EEG: MutableMapping[str, Any]) -> None: + candidate_locs = {"fp1", "fp2", "fz", "t3", "cz", "t4", "t5", "p3", "pz", "p4", "t6", "o1", "o2"} + if any(chanloc["labels"].lower() in candidate_locs for chanloc in EEG["chanlocs"]): + EEG["etc"]["labelscheme"] = "10-20" + else: + EEG["etc"]["labelscheme"] = "unknown" diff --git a/src/eegprep/plugins/EEG_BIDS/raw.py b/src/eegprep/plugins/EEG_BIDS/raw.py new file mode 100644 index 00000000..1a9f336c --- /dev/null +++ b/src/eegprep/plugins/EEG_BIDS/raw.py @@ -0,0 +1,349 @@ 
+"""Raw EEG file readers used by EEG-BIDS import workflows.""" + +from __future__ import annotations + +from collections.abc import Callable +import copy +import logging +import os +from typing import Any + +import numpy as np + +from eegprep.functions.miscfunc.misc import ToolError + +logger = logging.getLogger(__name__) + + +def load_raw_eeg_file( + filename: str, + *, + dtype: np.dtype, + numeric_null: Any, + warning: Callable[[str], None], + verbose: bool = True, +) -> tuple[dict[str, Any], float, np.ndarray, dict[str, Any]]: + """Load a supported BIDS raw EEG file into an EEG dictionary.""" + _path, ext = os.path.splitext(filename) + ext = ext.lower() + basename = os.path.basename(filename) + report: dict[str, Any] = {} + + if ext == ".set": + from eegprep.functions.popfunc.pop_loadset import pop_loadset + + eeg = pop_loadset(filename) + eeg["data"] = eeg["data"].astype(dtype) + report["ImporterUsed"] = "pop_loadset" + srate = eeg["srate"] + times_sec = eeg["times"] / 1000.0 + return eeg, srate, times_sec, report + + if ext in [".edf", ".bdf", ".vhdr"]: + eeg, srate, times_sec, raw_report = _load_neo_raw_file( + filename, + ext=ext, + basename=basename, + dtype=dtype, + numeric_null=numeric_null, + warning=warning, + verbose=verbose, + ) + report.update(raw_report) + return eeg, srate, times_sec, report + + if ext in [".fdt", ".vmrk", ".eeg"]: + raise ValueError( + f"pop_load_frombids should be called with the main data file, but was called on a sidecar file: {filename}." + ) + raise ValueError(f"Unsupported file format: {ext}. 
Supported formats are .set, .edf, .bdf, .vhdr.") + + +def _load_neo_raw_file( + filename: str, + *, + ext: str, + basename: str, + dtype: np.dtype, + numeric_null: Any, + warning: Callable[[str], None], + verbose: bool, +) -> tuple[dict[str, Any], float, np.ndarray, dict[str, Any]]: + from neo import NeoReadWriteError + + if ext == ".vhdr": + from neo.rawio.brainvisionrawio import BrainVisionRawIO as NeoIO + + importer_used = "neo.rawio.brainvisionrawio.BrainVisionRawIO" + elif ext in [".edf", ".bdf"]: + from neo.rawio.edfrawio import EDFRawIO as NeoIO + + importer_used = "neo.rawio.edfrawio.EDFRawIO" + else: + raise ValueError(f"Unexpected file format: {ext}. Please add support for this format if needed.") + + io = NeoIO(filename) + try: + io.parse_header() + except NeoReadWriteError as exc: + classname = io.__class__.__name__ + raise ToolError( + f"Encountered error with NEO {classname} importer on {filename!r}: {exc}. Skipping file." + ) from exc + + if (n_streams := io.signal_streams_count()) > 1: + warning(f"The raw data file {filename} appears to contain more than one stream; using only the first stream.") + elif not n_streams: + raise ValueError(f"The raw data file {filename} does not contain any data.") + if (n_blocks := io.block_count()) > 1: + warning( + f"The raw data file {filename} appears to contain " + f"more than one recording; this is not meaningful " + f"in a BIDS context; using only the first block." + ) + elif not n_blocks: + raise ValueError(f"The raw data file {filename} does not contain any data.") + if (n_segments := io.segment_count(0)) > 1: + raise NotImplementedError( + f"The raw data file {filename} appears to contain " + f"more than one segment; This importer currently " + f"only supports continuous EEG data." 
+ ) + elif not n_segments: + raise ValueError(f"The raw data file {filename} does not contain any data.") + + n_channels = io.signal_channels_count(0) + n_samples = io.get_signal_size(0, 0, 0) + channel_indexes = list(range(n_channels)) + report = { + "ImporterUsed": importer_used, + "NumStreams": n_streams, + "NumBlocks": n_blocks, + "NumSegments": n_segments, + } + + if verbose: + logger.info(" retrieving EEG data from file...") + data_t = io.get_analogsignal_chunk( + block_index=0, + seg_index=0, + channel_indexes=channel_indexes, + i_start=None, + i_stop=None, + ) + old_scale = np.std(data_t, axis=0) + data_t = io.rescale_signal_raw_to_float(data_t, dtype=dtype, channel_indexes=channel_indexes) + new_scale = np.std(data_t, axis=0) + scale_ratios = new_scale / old_scale + unique_ratios = np.unique(scale_ratios) + if len(unique_ratios) == 1: + report["ScaleApplied"] = unique_ratios.item() + else: + report["ScalesApplied"] = scale_ratios.tolist() + + srate = io.get_signal_sampling_rate(0) + t0 = io.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) + report["RawStartTime"] = t0 + time_offset = getattr(io, "_global_time", 0.0) + report["StartTimeOffset"] = time_offset + t0 += time_offset + report["CombinedStartTime"] = t0 + times_sec = t0 + np.arange(0, n_samples, dtype=float) / srate + + channels = io.header["signal_channels"] + try: + units = channels["units"].tolist() + except KeyError: + units = ["uV"] * n_channels + unique_units = np.unique(units) + if len(unique_units) == 1 and unique_units[0] not in ("uV", "microvolts"): + warning( + f"Your channel unit does not appear to be in microvolts (uV) " + f"but is documented instead as {unique_units[0]}. EEG scale might be incorrect. 
" + ) + + labels = channels["name"].tolist() + chanlocs = np.asarray([_empty_chanloc(label, numeric_null) for label in labels]) + _apply_neo_channel_coordinates(io, ext, filename, chanlocs, n_channels, warning=warning, verbose=verbose) + events = _read_neo_events(io, ext, times_sec, numeric_null, verbose=verbose) + + eeg = { + "setname": "", + "filename": basename, + "filepath": os.path.dirname(filename), + "subject": "", + "group": "", + "condition": "", + "session": numeric_null, + "comments": "", + "nbchan": n_channels, + "trials": 1, + "pnts": n_samples, + "srate": srate, + "xmin": times_sec[0], + "xmax": times_sec[-1], + "times": times_sec * 1000, + "data": data_t.T, + "icaact": numeric_null, + "icawinv": numeric_null, + "icasphere": numeric_null, + "icaweights": numeric_null, + "icachansind": numeric_null, + "chanlocs": chanlocs, + "urchanlocs": numeric_null, + "chaninfo": { + "plotrad": numeric_null, + "shrink": numeric_null, + "nosedir": "+X", + "nodatchans": numeric_null, + "icachansind": numeric_null, + }, + "ref": "unknown", + "event": events, + "urevent": copy.deepcopy(events), + "eventdescription": [], + "epoch": numeric_null, + "epochdescription": [], + "reject": {}, + "stats": {}, + "specdata": numeric_null, + "specicaact": numeric_null, + "splinefile": "", + "icasplinefile": "", + "dipfit": numeric_null, + "history": "", + "saved": "justloaded", + "etc": {}, + "run": numeric_null, + } + return eeg, srate, times_sec, report + + +def _empty_chanloc(label: str, numeric_null: Any) -> dict[str, Any]: + return { + "labels": label, + "sph_radius": numeric_null, + "sph_theta": numeric_null, + "sph_phi": numeric_null, + "theta": numeric_null, + "radius": numeric_null, + "X": numeric_null, + "Y": numeric_null, + "Z": numeric_null, + "type": "EEG", + "ref": numeric_null, + } + + +def _apply_neo_channel_coordinates( + io: Any, + ext: str, + filename: str, + chanlocs: np.ndarray, + n_channels: int, + *, + warning: Callable[[str], None], + verbose: bool, +) -> 
None: + if ext == ".vhdr": + if verbose: + logger.info(" parsing VHDR-specific channel locations...") + try: + annots = io.raw_annotations["blocks"][0]["segments"][0]["signals"][0]["__array_annotations__"] + sph_radius = annots["coordinates_0"] + theta = annots["coordinates_1"] + phi = annots["coordinates_2"] + valid = (sph_radius != 0) | (theta != 0) | (phi != 0) + sph_theta = phi - 90 * np.sign(theta) + sph_phi = -np.abs(theta) + 90 + except KeyError: + warning(f"Channel coordinates not found in {filename}. Using default values for channel locations.") + valid = np.zeros(n_channels, dtype=bool) + elif ext in [".edf", ".bdf"]: + valid = np.zeros(n_channels, dtype=bool) + else: + raise ValueError( + f"Unsupported file format for channel coordinates extraction: {ext}. " + f"Supported formats are .edf, .bdf, .vhdr." + ) + + if not np.any(valid): + return + + if verbose: + logger.info(" applying channel locations from EEG file...") + for loc, val, sph_r, sph_p, sph_t in zip(chanlocs, valid, sph_radius, sph_phi, sph_theta): + if not val: + continue + loc["sph_radius"] = sph_r + loc["sph_theta"] = sph_t + loc["sph_phi"] = sph_p + az = sph_p + horiz = sph_t + loc["theta"] = -horiz + loc["radius"] = 0.5 - az / 180 + az = np.deg2rad(sph_t) + elev = np.deg2rad(sph_p) + loc["Z"] = sph_r * np.sin(elev) + loc["X"] = sph_r * np.cos(elev) * np.cos(az) + loc["Y"] = sph_r * np.cos(elev) * np.sin(az) + + +def _read_neo_events(io: Any, ext: str, times_sec: np.ndarray, numeric_null: Any, *, verbose: bool) -> Any: + if (event_channels := io.event_channels_count()) <= 0: + return numeric_null + + if verbose: + logger.info(" reading in event data from EEG file...") + all_times = [] + all_durations = [] + all_channels = [] + all_data = [] + for event_channel_index in range(event_channels): + event_times, event_durations, event_labels = io.get_event_timestamps( + block_index=0, + seg_index=0, + event_channel_index=event_channel_index, + t_start=None, + t_stop=None, + ) + 
all_times.extend(io.rescale_event_timestamp(event_times)) + if event_durations is not None: + all_durations.extend(event_durations) + else: + all_durations.extend([1] * len(event_times)) + all_channels.extend(np.repeat(io.header["event_channels"][event_channel_index]["name"], len(event_times))) + all_data.extend(event_labels) + + if ext == ".vhdr": + event_types = all_data + event_codes = all_channels + elif ext in [".edf", ".bdf"]: + event_types = [str(value) for value in all_data] + event_codes = [str(channel) for channel in all_channels] + else: + raise ValueError( + f"Unsupported file format for event extraction: {ext}. Supported formats are .edf, .bdf, .vhdr." + ) + + event_latencies = np.searchsorted(times_sec, all_times) + event_durations = np.array(all_durations, dtype=float) + urevents = np.arange(len(all_times)) + return np.array( + [ + { + "duration": duration, + "latency": latency, + "type": event_type or ("boundary" if code == "New Segment" else ""), + "code": code, + "urevent": urevent, + } + for duration, latency, event_type, code, urevent in zip( + event_durations, + event_latencies, + event_types, + event_codes, + urevents, + ) + ] + ) diff --git a/src/eegprep/plugins/ICLabel/_prop_browser.py b/src/eegprep/plugins/ICLabel/_prop_browser.py new file mode 100644 index 00000000..36143b87 --- /dev/null +++ b/src/eegprep/plugins/ICLabel/_prop_browser.py @@ -0,0 +1,778 @@ +"""Matplotlib rendering for ICLabel extended property dashboards.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +from matplotlib.widgets import Button +import matplotlib.pyplot as plt +import numpy as np + +from eegprep.functions.guifunc.pophelp import pophelp +from eegprep.functions.popfunc._property_browser import property_activity_browser +from eegprep.functions.popfunc._rejection import ( + component_rejection_flags, + set_component_rejection_flag, +) +from 
eegprep.functions.sigprocfunc.topoplot import topoplot +from eegprep.plugins.ICLabel._prop_numerics import ( + DipfitData, + ExtendedPropertyData, + build_extended_property_data, + component_count, + component_rejection_status, +) +from eegprep.plugins.dipfit._mri import dipfit_mri_slices, load_standard_mri_volume + + +_DASHBOARD_SIZE = (12.0, 7.0) +_SCROLL_SECONDS = 5.0 +_EVENT_COLORS = ( + "#1f77b4", + "#2ca02c", + "#9467bd", + "#17becf", + "#ff7f0e", + "#8c564b", + "#e377c2", + "#7f7f7f", +) +_DIPFIT_COLORS = ("#00cc00", "#d336d3", "#e0c21a") +_REJECT_COLOR = "#ff9999" +_ACCEPT_COLOR = "#bfffbf" +_CONTROL_COLOR = "#e6e6e6" +_DISABLED_CONTROL_COLOR = "#d0d0d0" + + +@dataclass(frozen=True) +class ActivityTraceData: + """Inline scroll-plot trace using EEGLAB-compatible event coordinates.""" + + x_values: np.ndarray + y_values: np.ndarray + times_ms: np.ndarray + pnts: int + epoched: bool + + +def build_navigable_dashboard( + EEG: dict[str, Any], + typecomp: int, + indices: list[int], + winhandle: Any, + spec_opt: Any, + erp_opt: Any, + scroll_event: int | bool, + classifier_name: str, + *, + fig: Any, + show_activity: bool, + reject_callback: Any | None, +) -> Any: + """Build the navigable Matplotlib property dashboard.""" + figure = fig if fig is not None else plt.figure(figsize=_DASHBOARD_SIZE) + state = { + "EEG": EEG, + "typecomp": int(typecomp), + "indices": tuple(indices), + "position": 0, + "spec_opt": spec_opt, + "erp_opt": erp_opt, + "scroll_event": int(bool(scroll_event)), + "classifier_name": classifier_name, + "show_activity": bool(show_activity), + "winhandle": winhandle, + "reject_callback": reject_callback, + "rejection_pending": _initial_rejection_state(EEG, int(typecomp), indices), + } + figure.eegprep_dashboard_state = state + _render_dashboard(figure) + return figure + + +def _render_dashboard(figure: Any) -> None: + state = figure.eegprep_dashboard_state + indices = state["indices"] + position = int(state["position"]) + index = 
int(indices[position]) + dashboard = build_extended_property_data( + state["EEG"], + state["typecomp"], + index, + spec_opt=state["spec_opt"], + erp_opt=state["erp_opt"], + classifier_name=state["classifier_name"], + ) + figure.clf() + figure.set_size_inches(*_DASHBOARD_SIZE, forward=True) + figure.patch.set_facecolor((0.93, 0.96, 1.0)) + _set_window_title(figure, dashboard.figure_title) + has_rejection_controls = _has_rejection_controls(state) + bottom = ( + 0.17 + if has_rejection_controls and len(indices) > 1 + else 0.125 + if has_rejection_controls + else 0.12 + if len(indices) > 1 + else 0.075 + ) + if dashboard.dipfit is None: + grid = figure.add_gridspec( + 2, + 4, + left=0.055, + right=0.97, + top=0.88, + bottom=bottom, + wspace=0.65, + hspace=0.55, + width_ratios=(1.2, 0.95, 1.25, 1.25), + height_ratios=(0.9, 1.2), + ) + topo_ax = figure.add_subplot(grid[0, 0]) + if dashboard.classifier is None: + class_ax = None + activity_ax = figure.add_subplot(grid[0, 1:]) + else: + class_ax = figure.add_subplot(grid[0, 1]) + activity_ax = figure.add_subplot(grid[0, 2:]) + image_ax = figure.add_subplot(grid[1, :2]) + dipfit_axes = [] + spectrum_ax = figure.add_subplot(grid[1, 2:]) + else: + grid = figure.add_gridspec( + 2, + 5, + left=0.055, + right=0.97, + top=0.88, + bottom=bottom, + wspace=0.58, + hspace=0.55, + width_ratios=(1.15, 0.9, 0.85, 1.25, 1.25), + height_ratios=(0.9, 1.2), + ) + topo_ax = figure.add_subplot(grid[0, 0]) + if dashboard.classifier is None: + class_ax = None + activity_ax = figure.add_subplot(grid[0, 1:]) + else: + class_ax = figure.add_subplot(grid[0, 1]) + activity_ax = figure.add_subplot(grid[0, 2:]) + image_ax = figure.add_subplot(grid[1, :2]) + dipfit_grid = grid[1, 2].subgridspec(3, 1, hspace=0.04) + dipfit_axes = [figure.add_subplot(dipfit_grid[row, 0]) for row in range(3)] + spectrum_ax = figure.add_subplot(grid[1, 3:]) + + _plot_topography(topo_ax, dashboard) + if class_ax is not None: + _plot_classifier(class_ax, dashboard) + events 
= state["EEG"].get("event", []) if bool(state["scroll_event"]) else [] + _plot_activity(activity_ax, dashboard, events) + _plot_activity_image(image_ax, dashboard) + if dipfit_axes: + _plot_dipfit(dipfit_axes, dashboard.dipfit) + _plot_spectrum(spectrum_ax, dashboard) + figure.suptitle(dashboard.figure_title, fontsize=14, fontweight="bold") + figure.eegprep_dashboard_data = dashboard + figure.eegprep_activity_view = property_activity_browser( + state["EEG"], + dashboard.typecomp, + dashboard.index, + scroll_event=state["scroll_event"], + show=state["show_activity"], + ) + if len(indices) > 1: + _add_navigation_controls( + figure, + bottom=0.075 if has_rejection_controls else 0.025, + count_y=0.1 if has_rejection_controls else 0.092, + ) + else: + figure.eegprep_dashboard_navigation = {} + figure.eegprep_dashboard_navigation_buttons = () + if has_rejection_controls: + _add_rejection_controls(figure, dashboard) + else: + figure.eegprep_dashboard_rejection = {} + figure.eegprep_dashboard_rejection_buttons = {} + figure.eegprep_dashboard_rejection_button_list = () + buttons = [] + buttons.extend(getattr(figure, "eegprep_dashboard_navigation_buttons", ())) + buttons.extend(getattr(figure, "eegprep_dashboard_rejection_button_list", ())) + figure.eegprep_dashboard_buttons = tuple(buttons) + figure.canvas.draw_idle() + + +def _plot_topography(axis: Any, dashboard: ExtendedPropertyData) -> None: + if dashboard.typecomp: + topoplot( + dashboard.topography_values, + dashboard.topography_chanlocs, + axes=axis, + style="blank", + electrodes="off", + ) + else: + topoplot( + dashboard.topography_values, + dashboard.topography_chanlocs, + axes=axis, + electrodes="on", + colorbar=False, + ) + axis.set_title(dashboard.topography_title, fontsize=12, fontweight="normal") + if dashboard.pvaf is not None: + axis.text( + 0.5, + -0.13, + f"{{% scalp data var. 
accounted for}}: {dashboard.pvaf:.1f}%", + transform=axis.transAxes, + ha="center", + va="top", + fontsize=9, + ) + + +def _plot_classifier(axis: Any, dashboard: ExtendedPropertyData) -> None: + assert dashboard.classifier is not None + assert dashboard.class_probabilities is not None + labels = list(reversed(dashboard.classifier.classes)) + probabilities = np.asarray(dashboard.class_probabilities, dtype=float)[::-1] + y_values = np.arange(len(labels)) + axis.barh(y_values, probabilities, color="#4c78a8") + axis.set_yticks(y_values, labels) + axis.set_xlim(0.0, 1.0) + axis.set_xticks([0.0, 0.5, 1.0]) + axis.grid(axis="x", alpha=0.3) + axis.set_xlabel("Probability") + axis.set_title(dashboard.classifier.name, fontsize=12, fontweight="normal") + for y_value, probability in zip(y_values, probabilities): + axis.text(0.5, y_value, f"{probability * 100:.1f}%", ha="center", va="center", fontsize=8) + + +def _plot_activity(axis: Any, dashboard: ExtendedPropertyData, events: Any) -> None: + trace = _activity_trace(dashboard) + axis.plot(trace.x_values, trace.y_values, color="black", linewidth=0.85) + axis.axhline(0.0, color="0.75", linewidth=0.6) + _plot_epoch_markers(axis, trace) + _plot_event_markers(axis, trace, events) + axis.set_title(dashboard.activity_title, fontsize=12, fontweight="normal") + axis.set_xlabel("Time (ms)") + axis.set_ylabel("uV") + axis.grid(True, alpha=0.2) + _format_scrollplot_axis(axis, trace) + + +def _plot_activity_image(axis: Any, dashboard: ExtendedPropertyData) -> None: + image = np.asarray(dashboard.image_data, dtype=float) + handle = axis.imshow( + image, + aspect="auto", + origin="lower", + extent=dashboard.image_extent, + cmap="RdBu_r", + ) + axis.set_title(dashboard.image_title, fontsize=12, fontweight="normal") + axis.set_xlabel("Time (ms)" if dashboard.activity.shape[2] > 1 else "Data") + axis.set_ylabel("Epoch" if dashboard.activity.shape[2] > 1 else "Data") + plt.colorbar(handle, ax=axis, fraction=0.046, pad=0.035) + + +def 
_plot_spectrum(axis: Any, dashboard: ExtendedPropertyData) -> None: + axis.plot(dashboard.spectrum_freqs, dashboard.spectrum_power, color="black", linewidth=1.0) + axis.set_title(dashboard.spectrum_title, fontsize=12, fontweight="normal") + axis.set_xlabel("Frequency (Hz)") + axis.set_ylabel("Power 10*log10(uV^2/Hz)") + axis.grid(True, alpha=0.25) + finite = np.isfinite(dashboard.spectrum_power) + if np.any(finite): + finite_values = dashboard.spectrum_power[finite] + low = float(np.min(finite_values)) + high = float(np.max(finite_values)) + if high == low: + padding = 1.0 if low == 0.0 else abs(low) * 0.05 + low -= padding + high += padding + axis.set_ylim(low, high) + + +def _plot_dipfit(axes: list[Any], dipfit: DipfitData | None) -> None: + assert dipfit is not None + volume = load_standard_mri_volume() + for axis, mri_slice in zip(axes, dipfit_mri_slices(volume, dipfit.positions)): + axis.set_facecolor("black") + axis.imshow(mri_slice.image, cmap="gray", origin="lower", extent=mri_slice.extent, interpolation="nearest") + axis.set_aspect("equal", adjustable="box") + axis.set_xticks([]) + axis.set_yticks([]) + for spine in axis.spines.values(): + spine.set_visible(False) + _plot_dipfit_points(axis, dipfit, mri_slice.x_axis, mri_slice.y_axis) + axes[0].set_title("Dipole Position", fontsize=12, fontweight="normal", pad=7) + _plot_dipfit_values(axes[-1], dipfit) + + +def _plot_dipfit_points(axis: Any, dipfit: DipfitData, x_index: int, y_index: int) -> None: + for row, position in enumerate(dipfit.positions): + color = _DIPFIT_COLORS[row % len(_DIPFIT_COLORS)] + x_value = float(position[x_index]) + y_value = float(position[y_index]) + axis.plot( + x_value, + y_value, + marker="o", + markersize=5.5, + color=color, + markeredgecolor="white", + markeredgewidth=0.45, + ) + if dipfit.moments is not None and row < dipfit.moments.shape[0]: + _plot_dipfit_moment(axis, x_value, y_value, dipfit.moments[row], x_index, y_index, color) + + +def _plot_dipfit_moment( + axis: Any, + 
x_value: float, + y_value: float, + moment: np.ndarray, + x_index: int, + y_index: int, + color: str, +) -> None: + dx = float(moment[x_index]) + dy = float(moment[y_index]) + norm = float(np.hypot(dx, dy)) + if not np.isfinite(norm) or norm <= 0.0: + return + axis.arrow( + x_value, + y_value, + dx / norm * 18.0, + dy / norm * 18.0, + color=color, + width=0.8, + head_width=5.0, + length_includes_head=True, + alpha=0.9, + ) + + +def _plot_dipfit_values(axis: Any, dipfit: DipfitData) -> None: + lines = [] + if dipfit.rv_percent is not None: + lines.append(f"RV: {dipfit.rv_percent:.1f}%") + if dipfit.dmr is not None: + lines.append(f"DMR: {dipfit.dmr:.1f}") + if lines: + axis.text( + 0.5, + -0.02, + "\n".join(lines), + transform=axis.transAxes, + color="black", + fontsize=8, + ha="center", + va="top", + ) + + +def _add_navigation_controls(figure: Any, *, bottom: float, count_y: float) -> None: + previous_axis = figure.add_axes((0.37, bottom, 0.105, 0.05)) + next_axis = figure.add_axes((0.525, bottom, 0.105, 0.05)) + previous_button = Button(previous_axis, "Previous") + next_button = Button(next_axis, "Next") + + def previous(_event: Any = None) -> None: + _navigate_dashboard(figure, -1) + + def next_(_event: Any = None) -> None: + _navigate_dashboard(figure, 1) + + previous_button.on_clicked(previous) + next_button.on_clicked(next_) + figure.eegprep_dashboard_navigation_buttons = (previous_button, next_button) + figure.eegprep_dashboard_navigation = {"previous": previous, "next": next_} + state = figure.eegprep_dashboard_state + figure.text( + 0.5, + count_y, + f"{int(state['position']) + 1} / {len(state['indices'])}", + ha="center", + va="center", + fontsize=9, + ) + + +def _add_rejection_controls(figure: Any, dashboard: ExtendedPropertyData) -> None: + state = figure.eegprep_dashboard_state + index = int(dashboard.index) + rejected = _pending_rejection_status(state, index) + cancel_button = Button(figure.add_axes((0.2, 0.015, 0.1, 0.045)), "Cancel", 
color=_CONTROL_COLOR) + values_button = Button(figure.add_axes((0.325, 0.015, 0.1, 0.045)), "Values", color=_CONTROL_COLOR) + status_button = Button( + figure.add_axes((0.45, 0.015, 0.1, 0.045)), + _rejection_label(rejected), + color=_rejection_color(rejected), + hovercolor=_rejection_color(rejected), + ) + help_button = Button(figure.add_axes((0.575, 0.015, 0.1, 0.045)), "HELP", color=_CONTROL_COLOR) + ok_button = Button(figure.add_axes((0.7, 0.015, 0.1, 0.045)), "OK", color=_CONTROL_COLOR) + + if not _component_values_available(state["EEG"]): + values_button.set_active(False) + values_button.ax.set_facecolor(_DISABLED_CONTROL_COLOR) + + def cancel(_event: Any = None) -> None: + plt.close(figure) + + def values(_event: Any = None) -> None: + if _component_values_available(state["EEG"]): + _show_component_values(state["EEG"], index, _pending_rejection_status(state, index)) + + def toggle(_event: Any = None) -> None: + state["rejection_pending"][index] = not _pending_rejection_status(state, index) + _style_rejection_status_button(status_button, state["rejection_pending"][index]) + figure.canvas.draw_idle() + + def help_(_event: Any = None) -> None: + pophelp("pop_prop_extended") + + def ok(_event: Any = None) -> None: + _commit_rejection_state(figure) + plt.close(figure) + + cancel_button.on_clicked(cancel) + values_button.on_clicked(values) + status_button.on_clicked(toggle) + help_button.on_clicked(help_) + ok_button.on_clicked(ok) + figure.eegprep_dashboard_rejection_button_list = ( + cancel_button, + values_button, + status_button, + help_button, + ok_button, + ) + figure.eegprep_dashboard_rejection_buttons = { + "cancel": cancel_button, + "values": values_button, + "status": status_button, + "help": help_button, + "ok": ok_button, + } + figure.eegprep_dashboard_rejection = { + "cancel": cancel, + "values": values, + "toggle": toggle, + "help": help_, + "ok": ok, + "pending": state["rejection_pending"], + } + + +def _navigate_dashboard(figure: Any, step: int) -> 
None: + state = figure.eegprep_dashboard_state + state["position"] = (int(state["position"]) + int(step)) % len(state["indices"]) + _render_dashboard(figure) + + +def _initial_rejection_state(EEG: dict[str, Any], typecomp: int, indices: list[int]) -> dict[int, bool]: + if int(typecomp): + return {} + total = component_count(EEG) + flags = component_rejection_flags(EEG, total, create=False) + return {int(index): bool(flags[int(index) - 1]) for index in indices} + + +def _has_rejection_controls(state: dict[str, Any]) -> bool: + return int(state["typecomp"]) == 0 and bool(state["indices"]) + + +def _pending_rejection_status(state: dict[str, Any], component_index: int) -> bool: + pending = state["rejection_pending"] + index = int(component_index) + if index not in pending: + pending[index] = component_rejection_status(state["EEG"], index) + return bool(pending[index]) + + +def _rejection_label(rejected: bool) -> str: + return "REJECT" if rejected else "ACCEPT" + + +def _rejection_color(rejected: bool) -> str: + return _REJECT_COLOR if rejected else _ACCEPT_COLOR + + +def _style_rejection_status_button(button: Button, rejected: bool) -> None: + button.label.set_text(_rejection_label(rejected)) + button.ax.set_facecolor(_rejection_color(rejected)) + button.color = _rejection_color(rejected) + button.hovercolor = _rejection_color(rejected) + + +def _commit_rejection_state(figure: Any) -> None: + state = figure.eegprep_dashboard_state + if not _has_rejection_controls(state): + return + EEG = state["EEG"] + total = component_count(EEG) + committed: dict[int, bool] = {} + for index in sorted(state["rejection_pending"]): + rejected = bool(state["rejection_pending"][index]) + set_component_rejection_flag(EEG, index, rejected, total) + _style_rejection_winhandle(state["winhandle"], index, rejected) + committed[int(index)] = rejected + callback = state.get("reject_callback") + if callback is not None: + callback(EEG, dict(committed)) + + +def 
_style_rejection_winhandle(winhandle: Any, component_index: int, rejected: bool) -> None: + if _is_empty_winhandle(winhandle): + return + handle = winhandle + if isinstance(winhandle, Mapping): + handle = winhandle.get(int(component_index)) + if isinstance(handle, Button): + handle.ax.set_facecolor(_rejection_color(rejected)) + + +def _is_empty_winhandle(winhandle: Any) -> bool: + if winhandle is None: + return True + if isinstance(winhandle, (int, float, np.integer, np.floating)): + return bool(winhandle == 0) or bool(np.isnan(float(winhandle))) + return False + + +def _component_values_available(EEG: dict[str, Any]) -> bool: + stats = EEG.get("stats") + if not isinstance(stats, dict): + return False + return np.asarray(stats.get("compenta", [])).size > 0 + + +def _show_component_values(EEG: dict[str, Any], component_index: int, rejected: bool) -> None: + values = _component_value_lines(EEG, component_index, rejected) + figure = plt.figure(figsize=(3.4, 3.4)) + manager = getattr(figure.canvas, "manager", None) + if manager is not None: + manager.set_window_title("Statistics of the component") + axis = figure.add_subplot(1, 1, 1) + axis.axis("off") + axis.text(0.04, 0.96, "\n".join(values), va="top", ha="left", fontsize=9, family="monospace") + close_button = Button(figure.add_axes((0.375, 0.03, 0.25, 0.08)), "Close", color=_CONTROL_COLOR) + close_button.on_clicked(lambda _event=None: plt.close(figure)) + setattr(figure, "eegprep_component_values_button", close_button) + figure.tight_layout(rect=(0, 0.12, 1, 1)) + figure.canvas.draw_idle() + + +def _component_value_lines(EEG: dict[str, Any], component_index: int, rejected: bool) -> list[str]: + raw_stats = EEG.get("stats") + stats = raw_stats if isinstance(raw_stats, dict) else {} + raw_reject = EEG.get("reject") + reject = raw_reject if isinstance(raw_reject, dict) else {} + index = int(component_index) + return [ + "(", + f"Entropy of component activity {_indexed_stat(stats.get('compenta'), index):>8}", + f"> 
Rejection threshold {_scalar_stat(reject.get('threshentropy')):>8}", + "", + " AND ----", + "", + f"Kurtosis of component activity {_indexed_stat(stats.get('compkurta'), index):>8}", + f"> Rejection threshold {_scalar_stat(reject.get('threshkurtact')):>8}", + "", + ") OR ----", + "", + f"Kurtosis distribution {_indexed_stat(stats.get('compkurtdist'), index):>8}", + f"> Rejection threshold {_scalar_stat(reject.get('threshkurtdist')):>8}", + "", + f"Current thresholds suggest to {_rejection_label(rejected)} the component", + "", + "After manually accepting/rejecting the component, recalibrate", + "thresholds before applying automatic rejection to other datasets.", + ] + + +def _indexed_stat(values: Any, component_index: int) -> str: + vector = np.asarray(values, dtype=float).ravel() + index = int(component_index) - 1 + if index < 0 or index >= vector.size or not np.isfinite(vector[index]): + return "----" + return f"{float(vector[index]):2.2f}" + + +def _scalar_stat(value: Any) -> str: + vector = np.asarray(value, dtype=float).ravel() + if vector.size == 0 or not np.isfinite(vector[0]): + return "----" + return f"{float(vector[0]):2.2f}" + + +def _activity_trace(dashboard: ExtendedPropertyData) -> ActivityTraceData: + trace = np.asarray(dashboard.activity[0], dtype=float) + if trace.ndim == 1: + trace = trace[:, np.newaxis] + pnts = int(dashboard.times_ms.size) + srate = _srate_from_times(dashboard.times_ms) + window_samples = max(1, int(round(_SCROLL_SECONDS * srate))) + if trace.shape[1] == 1: + sample_count = min(pnts, window_samples) + return ActivityTraceData( + x_values=np.arange(1, sample_count + 1, dtype=float), + y_values=trace[:sample_count, 0], + times_ms=dashboard.times_ms, + pnts=pnts, + epoched=False, + ) + flat = trace.T.reshape(-1) + sample_count = min(flat.size, window_samples) + return ActivityTraceData( + x_values=np.arange(1, sample_count + 1, dtype=float), + y_values=flat[:sample_count], + times_ms=dashboard.times_ms, + pnts=pnts, + epoched=True, 
+ ) + + +def _plot_event_markers( + axis: Any, + trace: ActivityTraceData, + events: Any, +) -> None: + event_items = _event_items(events) + if not event_items: + return + first = float(np.nanmin(trace.x_values)) + last = float(np.nanmax(trace.x_values)) + colors = _event_color_map(event_items) + for event in event_items: + if "latency" not in event: + continue + try: + latency = float(_event_scalar(event["latency"])) + except (TypeError, ValueError): + continue + x_value = latency + if first <= x_value <= last: + label = str(_event_scalar(event.get("type", ""))) + axis.axvline(x_value, color=colors[label], linestyle="--", linewidth=0.75) + axis.text( + x_value, + 0.98, + label, + color=colors[label], + transform=axis.get_xaxis_transform(), + rotation=45, + fontsize=8, + ) + + +def _plot_epoch_markers(axis: Any, trace: ActivityTraceData) -> None: + if not trace.epoched: + return + first = float(np.nanmin(trace.x_values)) + last = float(np.nanmax(trace.x_values)) + start = int(np.ceil(first / trace.pnts) * trace.pnts) + for boundary in range(start, int(np.floor(last / trace.pnts) * trace.pnts) + 1, trace.pnts): + if boundary < first or boundary > last: + continue + axis.axvline(float(boundary), color="red", linestyle="-", linewidth=0.75) + axis.text( + float(boundary), + 0.98, + f"epoch {boundary // trace.pnts}", + color="red", + transform=axis.get_xaxis_transform(), + rotation=45, + fontsize=8, + ) + + +def _format_scrollplot_axis(axis: Any, trace: ActivityTraceData) -> None: + first = float(trace.x_values[0]) + last = float(trace.x_values[-1]) + axis.set_xlim(first, last) + tick_count = min(6, trace.x_values.size) + ticks = np.unique(np.linspace(first, last, tick_count).round().astype(int)) + labels = [] + for tick in ticks: + sample = (int(tick) - 1) % trace.pnts if trace.epoched else min(max(int(tick) - 1, 0), trace.pnts - 1) + labels.append(f"{trace.times_ms[sample]:g}") + axis.set_xticks(ticks) + axis.set_xticklabels(labels) + + +def _event_items(events: Any) 
-> list[dict[str, Any]]: + if events is None: + return [] + if isinstance(events, dict): + if "latency" not in events: + return [] + latencies = np.asarray(events["latency"], dtype=object).ravel() + types = np.asarray(events.get("type", [""] * latencies.size), dtype=object).ravel() + epochs = np.asarray(events.get("epoch", [None] * latencies.size), dtype=object).ravel() + event_items = [] + for index, latency in enumerate(latencies): + event = { + "type": _event_scalar(types[min(index, types.size - 1)]), + "latency": _event_scalar(latency), + } + epoch = _event_scalar(epochs[min(index, epochs.size - 1)]) + if epoch is not None: + event["epoch"] = epoch + event_items.append(event) + return event_items + if isinstance(events, np.ndarray): + events = events.tolist() + try: + event_values = list(events) + except TypeError: + return [] + return [event for event in event_values if isinstance(event, dict)] + + +def _event_scalar(value: Any) -> Any: + array = np.asarray(value) + if array.shape == (): + return array.item() + if array.size == 1: + item = array.ravel()[0] + return item.item() if hasattr(item, "item") else item + return value + + +def _event_color_map(event_items: list[dict[str, Any]]) -> dict[str, str]: + labels: list[str] = [] + for event in event_items: + label = str(_event_scalar(event.get("type", ""))) + if label not in labels: + labels.append(label) + return {label: _EVENT_COLORS[index % len(_EVENT_COLORS)] for index, label in enumerate(labels)} + + +def _srate_from_times(times_ms: np.ndarray) -> float: + if times_ms.size < 2: + return 1.0 + interval_ms = float(np.nanmedian(np.diff(times_ms))) + if not np.isfinite(interval_ms) or interval_ms <= 0: + return 1.0 + return 1000.0 / interval_ms + + +def _set_window_title(figure: Any, title: str) -> None: + figure.set_label(title) + manager = getattr(getattr(figure, "canvas", None), "manager", None) + if manager is not None and hasattr(manager, "set_window_title"): + manager.set_window_title(title) + + 
+__all__ = ["ActivityTraceData", "build_navigable_dashboard"] diff --git a/src/eegprep/plugins/ICLabel/_prop_numerics.py b/src/eegprep/plugins/ICLabel/_prop_numerics.py new file mode 100644 index 00000000..2c27ae71 --- /dev/null +++ b/src/eegprep/plugins/ICLabel/_prop_numerics.py @@ -0,0 +1,473 @@ +"""Data assembly and numerical helpers for ICLabel property dashboards.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import numpy as np + +from eegprep.functions.popfunc._plot_utils import ( + channel_labels, + component_activations, + component_channel_indices, + component_map_data, + eeg_epoch_data, + eeg_times_ms, + numeric_vector, + parse_plot_options_text, +) +from eegprep.functions.popfunc._rejection import component_rejection_flags, one_based_indices +from eegprep.functions.sigprocfunc.spectopo import compute_spectra +from eegprep.plugins.dipfit._utils import normalize_model_list + + +DEFAULT_ICLABEL_CLASSES = ("Brain", "Muscle", "Eye", "Heart", "Line Noise", "Channel Noise", "Other") + + +@dataclass(frozen=True) +class ClassifierData: + """Normalized component-classifier output from ``EEG.etc.ic_classification``.""" + + name: str + classes: tuple[str, ...] 
+ probabilities: np.ndarray + + +@dataclass(frozen=True) +class DipfitData: + """Normalized localized DIPFIT model for one ICA component.""" + + positions: np.ndarray + moments: np.ndarray | None + rv_percent: float | None + dmr: float | None + coordformat: str + + +@dataclass(frozen=True) +class ExtendedPropertyData: + """Data assembled for one extended channel/component property dashboard.""" + + typecomp: int + index: int + label: str + figure_title: str + topography_title: str + topography_values: Any + topography_chanlocs: list[dict[str, Any]] + activity: np.ndarray + times_ms: np.ndarray + activity_title: str + image_data: np.ndarray + image_extent: tuple[float, float, float, float] + image_title: str + spectrum_freqs: np.ndarray + spectrum_power: np.ndarray + spectrum_title: str + classifier: ClassifierData | None + class_probabilities: np.ndarray | None + pvaf: float | None + dipfit: DipfitData | None + rejected: bool | None + + +def classifier_names(EEG: dict[str, Any]) -> list[str]: + """Return classifier field names available under ``EEG.etc.ic_classification``.""" + etc = EEG.get("etc") or {} + if not isinstance(etc, dict): + return [] + classifications = etc.get("ic_classification") or {} + if not isinstance(classifications, dict): + return [] + return [str(name) for name in classifications if str(name)] + + +def classifier_default_index(classifiers: list[str]) -> int: + """Return EEGLAB popup index for the default component classifier.""" + for index, name in enumerate(classifiers, start=1): + if name.lower() == "iclabel": + return index + return 1 + + +def classifier_name_from_gui(EEG: dict[str, Any], value: Any) -> str: + """Resolve a GUI popup value to a classifier field name.""" + classifiers = classifier_names(EEG) + if not classifiers: + return "" + if isinstance(value, str): + for classifier in classifiers: + if classifier.lower() == value.lower(): + return classifier + return classifiers[classifier_default_index(classifiers) - 1] + try: + 
index = int(value) - 1 + except (TypeError, ValueError): + index = classifier_default_index(classifiers) - 1 + if 0 <= index < len(classifiers): + return classifiers[index] + return classifiers[classifier_default_index(classifiers) - 1] + + +def resolve_classifier_data( + EEG: dict[str, Any], + classifier_name: str = "", + *, + component_total: int | None = None, + require: bool = False, +) -> ClassifierData | None: + """Return normalized classifier data or ``None`` when no classifier is available.""" + classifiers = classifier_names(EEG) + if not classifiers: + if require: + raise ValueError("No component classifier data found in EEG.etc.ic_classification") + return None + resolved_name = _resolve_classifier_name(classifiers, classifier_name) + record = (EEG.get("etc") or {})["ic_classification"][resolved_name] + if not isinstance(record, dict): + raise ValueError(f"Classifier {resolved_name!r} must be stored as a dictionary") + probabilities = np.asarray(record.get("classifications", []), dtype=float) + if probabilities.ndim != 2 or probabilities.size == 0: + raise ValueError(f"Classifier {resolved_name!r} is missing a 2-D classifications matrix") + if component_total is not None and probabilities.shape[0] != int(component_total): + raise ValueError( + f"Classifier {resolved_name!r} has {probabilities.shape[0]} rows for {component_total} ICA components" + ) + classes = _classifier_classes(record, resolved_name, probabilities.shape[1]) + return ClassifierData(resolved_name, classes, probabilities) + + +def resolve_dipfit_data(EEG: dict[str, Any], component_index: int) -> DipfitData | None: + """Return normalized DIPFIT model data for a 1-based component index.""" + models = normalize_model_list(EEG) + index = int(component_index) + if index < 1: + raise ValueError("component index must be 1-based") + if index > len(models): + return None + model = models[index - 1] + raw_positions = np.asarray(model.get("posxyz", []), dtype=float) + if raw_positions.size == 0: + 
return None + positions = _dipfit_matrix(raw_positions, "posxyz", index) + if not np.all(np.isfinite(positions)): + raise ValueError(f"DIPFIT model for component {index} contains non-finite posxyz values") + moments = _dipfit_moments(model.get("momxyz", []), positions.shape[0], index) + rv = _finite_float(model.get("rv")) + coordformat = "" + dipfit = EEG.get("dipfit") + if isinstance(dipfit, dict): + coordformat = str(dipfit.get("coordformat") or "") + return DipfitData( + positions=positions, + moments=moments, + rv_percent=None if rv is None else rv * 100.0, + dmr=_dipole_moment_ratio(moments), + coordformat=coordformat, + ) + + +def selected_property_indices( + EEG: dict[str, Any], + typecomp: int | bool, + values: Any, + *, + default_all: bool = True, +) -> list[int]: + """Normalize EEGLAB-facing channel/component selections to 1-based indices.""" + limit = int(EEG.get("nbchan", 0) or 0) if int(bool(typecomp)) else component_count(EEG) + return one_based_indices(values, limit=limit, default_all=default_all) + + +def component_count(EEG: dict[str, Any]) -> int: + """Return the number of available ICA components.""" + icaact = EEG.get("icaact") + if icaact is not None and np.asarray(icaact).size: + values = np.asarray(icaact) + if values.ndim >= 2: + return int(values.shape[0]) + weights = np.asarray(EEG.get("icaweights", [])) + if weights.ndim == 2 and weights.size: + return int(weights.shape[0]) + winv = np.asarray(EEG.get("icawinv", [])) + if winv.ndim == 2 and winv.size: + return int(winv.shape[1]) + return 0 + + +def has_component_classifier(EEG: dict[str, Any], classifier_name: str = "") -> bool: + """Return whether usable classifier data are available for the current ICA.""" + try: + return ( + resolve_classifier_data(EEG, classifier_name, component_total=component_count(EEG), require=False) + is not None + ) + except ValueError as exc: + if "missing a 2-D classifications matrix" in str(exc): + return False + raise + + +def component_rejection_status( + 
EEG: dict[str, Any], + component_index: int, + *, + component_total: int | None = None, +) -> bool: + """Return the current ``EEG.reject.gcompreject`` status for one component.""" + total = component_count(EEG) if component_total is None else int(component_total) + index = int(component_index) + if index < 1 or index > total: + raise ValueError("component index is outside available ICA components") + return bool(component_rejection_flags(EEG, total, create=False)[index - 1]) + + +def build_extended_property_data( + EEG: dict[str, Any], + typecomp: int | bool, + index: int, + *, + spec_opt: Any = None, + erp_opt: Any = None, + classifier_name: str = "", +) -> ExtendedPropertyData: + """Assemble the dashboard data for one EEGLAB-facing channel/component index.""" + del erp_opt + typecomp = int(bool(typecomp)) + data = eeg_epoch_data(EEG) + times_ms = eeg_times_ms(EEG) + if typecomp: + return _channel_dashboard_data(EEG, data, times_ms, int(index), spec_opt) + return _component_dashboard_data(EEG, data, times_ms, int(index), spec_opt, classifier_name) + + +def _channel_dashboard_data( + EEG: dict[str, Any], + data: np.ndarray, + times_ms: np.ndarray, + index: int, + spec_opt: Any, +) -> ExtendedPropertyData: + labels = channel_labels(EEG) + if index < 1 or index > data.shape[0]: + raise ValueError("channel index is outside available channels") + label = labels[index - 1] if index - 1 < len(labels) else str(index) + activity = np.array(data[index - 1 : index], dtype=float, copy=True) + spectrum_freqs, spectrum_power = _spectrum(activity, EEG, spec_opt) + image_data, image_extent, image_title = _activity_image(activity, times_ms, f"Epoched Channel {label} Activity") + return ExtendedPropertyData( + typecomp=1, + index=index, + label=label, + figure_title=f"Channel {label} - pop_prop_extended()", + topography_title=f"Channel {label}", + topography_values=index, + topography_chanlocs=list(EEG.get("chanlocs", []) or []), + activity=activity, + times_ms=times_ms, + 
activity_title="Channel Time Series", + image_data=image_data, + image_extent=image_extent, + image_title=image_title, + spectrum_freqs=spectrum_freqs, + spectrum_power=spectrum_power, + spectrum_title="Channel Activity Power Spectrum", + classifier=None, + class_probabilities=None, + pvaf=None, + dipfit=None, + rejected=None, + ) + + +def _component_dashboard_data( + EEG: dict[str, Any], + data: np.ndarray, + times_ms: np.ndarray, + index: int, + spec_opt: Any, + classifier_name: str, +) -> ExtendedPropertyData: + activity_all = component_activations(EEG) + maps, map_chanlocs = component_map_data(EEG) + if index < 1 or index > activity_all.shape[0]: + raise ValueError("component index is outside available ICA components") + activity = np.array(activity_all[index - 1 : index], dtype=float, copy=True) + classifier = resolve_classifier_data(EEG, classifier_name, component_total=activity_all.shape[0], require=False) + probabilities = ( + None if classifier is None else np.array(classifier.probabilities[index - 1], dtype=float, copy=True) + ) + spectrum_freqs, spectrum_power = _spectrum(activity, EEG, spec_opt) + image_data, image_extent, image_title = _activity_image(activity, times_ms, f"Epoched IC{index} Activity") + return ExtendedPropertyData( + typecomp=0, + index=index, + label=f"IC{index}", + figure_title=f"IC{index} - pop_prop_extended()", + topography_title=f"IC{index}", + topography_values=maps[:, index - 1], + topography_chanlocs=map_chanlocs, + activity=activity, + times_ms=times_ms, + activity_title=f"Scrolling IC{index} Activity", + image_data=image_data, + image_extent=image_extent, + image_title=image_title, + spectrum_freqs=spectrum_freqs, + spectrum_power=spectrum_power, + spectrum_title=f"IC{index} Activity Power Spectrum", + classifier=classifier, + class_probabilities=probabilities, + pvaf=_component_pvaf(EEG, data, maps, activity, index), + dipfit=resolve_dipfit_data(EEG, index), + rejected=component_rejection_status(EEG, index, 
component_total=activity_all.shape[0]), + ) + + +def _spectrum(activity: np.ndarray, EEG: dict[str, Any], spec_opt: Any) -> tuple[np.ndarray, np.ndarray]: + options = parse_plot_options_text(spec_opt) + flat = np.asarray(activity, dtype=float).reshape(1, -1) + spectra, freqs, _std = compute_spectra( + flat, + int(EEG.get("pnts", flat.shape[1]) or flat.shape[1]), + float(EEG.get("srate", 1.0) or 1.0), + winsize=_first_int(options.get("winsize")), + overlap=_first_int(options.get("overlap")) or 0, + nfft=_first_int(options.get("nfft")), + ) + return freqs, spectra[0] + + +def _activity_image( + activity: np.ndarray, + times_ms: np.ndarray, + epoched_title: str, +) -> tuple[np.ndarray, tuple[float, float, float, float], str]: + trace = np.asarray(activity[0], dtype=float) + trace = trace - float(np.nanmean(trace)) + if trace.ndim == 1: + trace = trace[:, np.newaxis] + if trace.shape[1] > 1: + image = trace.T + extent = (float(times_ms[0]), float(times_ms[-1]), 1.0, float(trace.shape[1])) + return image, extent, epoched_title + flat = trace[:, 0] + line_count = min(200, max(1, int(np.floor(np.sqrt(flat.size))))) + frame_count = max(1, flat.size // line_count) + image = flat[: line_count * frame_count].reshape(line_count, frame_count) + extent = (0.0, float(frame_count - 1), 1.0, float(line_count)) + return image, extent, "Continuous Data" + + +def _dipfit_matrix(values: np.ndarray, field_name: str, component_index: int) -> np.ndarray: + matrix = values + if matrix.ndim == 1: + matrix = matrix.reshape(1, -1) + if matrix.ndim != 2 or matrix.shape[1] < 3: + raise ValueError( + f"DIPFIT model for component {component_index} must contain {field_name} rows with 3 coordinates" + ) + return np.array(matrix[:, :3], dtype=float, copy=True) + + +def _dipfit_moments(values: Any, position_count: int, component_index: int) -> np.ndarray | None: + raw_moments = np.asarray(values, dtype=float) + if raw_moments.size == 0: + return None + moments = _dipfit_matrix(raw_moments, "momxyz", 
component_index) + if moments.shape[0] != position_count: + raise ValueError(f"DIPFIT model for component {component_index} must have matching posxyz and momxyz rows") + if not np.all(np.isfinite(moments)): + raise ValueError(f"DIPFIT model for component {component_index} contains non-finite momxyz values") + return moments + + +def _dipole_moment_ratio(moments: np.ndarray | None) -> float | None: + if moments is None or moments.shape[0] != 2: + return None + norms = np.linalg.norm(moments, axis=1) + if not np.all(np.isfinite(norms)) or np.any(norms <= 0.0): + return None + ratio = float(norms[0] / norms[1]) + return ratio if ratio >= 1.0 else 1.0 / ratio + + +def _finite_float(value: Any) -> float | None: + try: + numeric = float(np.asarray(value).reshape(())) + except (TypeError, ValueError): + return None + return numeric if np.isfinite(numeric) else None + + +def _component_pvaf( + EEG: dict[str, Any], + data: np.ndarray, + maps: np.ndarray, + activity: np.ndarray, + index: int, +) -> float | None: + if maps.shape[0] == 0: + return None + icachansind = component_channel_indices(EEG, data.shape[0]) + if maps.shape[0] != icachansind.size: + return None + flat_data = data[icachansind, :, :].reshape(icachansind.size, -1) + component_trace = activity.reshape(1, -1) + projection = maps[:, index - 1 : index] @ component_trace + datavar = float(np.nanmean(np.nanvar(flat_data, axis=1))) + if not np.isfinite(datavar) or datavar <= 0: + return None + projvar = float(np.nanmean(np.nanvar(flat_data - projection, axis=1))) + if not np.isfinite(projvar): + return None + return 100.0 * (1.0 - projvar / datavar) + + +def _resolve_classifier_name(classifiers: list[str], classifier_name: str) -> str: + if classifier_name: + for name in classifiers: + if name.lower() == str(classifier_name).lower(): + return name + raise ValueError(f"Classifier {classifier_name!r} was not found in EEG.etc.ic_classification") + return classifiers[classifier_default_index(classifiers) - 1] + + +def 
_classifier_classes(record: dict[str, Any], classifier_name: str, class_count: int) -> tuple[str, ...]: + raw_classes = record.get("classes", []) + classes = [str(item) for item in np.asarray(raw_classes, dtype=object).ravel().tolist() if str(item)] + if not classes: + if classifier_name.lower() == "iclabel" and class_count == len(DEFAULT_ICLABEL_CLASSES): + return DEFAULT_ICLABEL_CLASSES + return tuple(f"Class {index}" for index in range(1, class_count + 1)) + if len(classes) != class_count: + raise ValueError( + f"Classifier {classifier_name!r} has {class_count} probability columns but {len(classes)} class names" + ) + return tuple(classes) + + +def _first_int(value: Any) -> int | None: + vector = numeric_vector(value) + if vector.size == 0: + return None + return int(vector[0]) + + +__all__ = [ + "DEFAULT_ICLABEL_CLASSES", + "ClassifierData", + "DipfitData", + "ExtendedPropertyData", + "build_extended_property_data", + "classifier_default_index", + "classifier_name_from_gui", + "classifier_names", + "component_count", + "component_rejection_status", + "has_component_classifier", + "resolve_classifier_data", + "resolve_dipfit_data", + "selected_property_indices", +] diff --git a/src/eegprep/plugins/ICLabel/pop_prop_extended.py b/src/eegprep/plugins/ICLabel/pop_prop_extended.py index 4a23aac1..42a16a2a 100644 --- a/src/eegprep/plugins/ICLabel/pop_prop_extended.py +++ b/src/eegprep/plugins/ICLabel/pop_prop_extended.py @@ -2,116 +2,28 @@ from __future__ import annotations -from collections.abc import Mapping -from dataclasses import dataclass from typing import Any -from matplotlib.widgets import Button -import matplotlib.pyplot as plt -import numpy as np - from eegprep.functions.guifunc.inputgui import inputgui -from eegprep.functions.guifunc.pophelp import pophelp from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._plot_utils import ( - channel_labels, - component_activations, - component_channel_indices, - 
component_map_data, - eeg_epoch_data, - eeg_times_ms, - history_command, - numeric_vector, - parse_plot_options_text, -) -from eegprep.functions.popfunc._property_browser import property_activity_browser -from eegprep.functions.popfunc._rejection import ( - component_rejection_flags, - one_based_indices, - set_component_rejection_flag, +from eegprep.functions.popfunc._plot_utils import history_command +from eegprep.plugins.ICLabel._prop_browser import build_navigable_dashboard +from eegprep.plugins.ICLabel._prop_numerics import ( + DEFAULT_ICLABEL_CLASSES, + ClassifierData, + DipfitData, + ExtendedPropertyData, + build_extended_property_data, + classifier_default_index, + classifier_name_from_gui, + classifier_names, + component_count, + component_rejection_status, + has_component_classifier, + resolve_classifier_data, + resolve_dipfit_data, + selected_property_indices, ) -from eegprep.functions.sigprocfunc.spectopo import compute_spectra -from eegprep.functions.sigprocfunc.topoplot import topoplot -from eegprep.plugins.dipfit._mri import dipfit_mri_slices, load_standard_mri_volume -from eegprep.plugins.dipfit._utils import normalize_model_list - - -DEFAULT_ICLABEL_CLASSES = ("Brain", "Muscle", "Eye", "Heart", "Line Noise", "Channel Noise", "Other") -_DASHBOARD_SIZE = (12.0, 7.0) -_SCROLL_SECONDS = 5.0 -_EVENT_COLORS = ( - "#1f77b4", - "#2ca02c", - "#9467bd", - "#17becf", - "#ff7f0e", - "#8c564b", - "#e377c2", - "#7f7f7f", -) -_DIPFIT_COLORS = ("#00cc00", "#d336d3", "#e0c21a") -_REJECT_COLOR = "#ff9999" -_ACCEPT_COLOR = "#bfffbf" -_CONTROL_COLOR = "#e6e6e6" -_DISABLED_CONTROL_COLOR = "#d0d0d0" - - -@dataclass(frozen=True) -class ClassifierData: - """Normalized component-classifier output from ``EEG.etc.ic_classification``.""" - - name: str - classes: tuple[str, ...] 
- probabilities: np.ndarray - - -@dataclass(frozen=True) -class DipfitData: - """Normalized localized DIPFIT model for one ICA component.""" - - positions: np.ndarray - moments: np.ndarray | None - rv_percent: float | None - dmr: float | None - coordformat: str - - -@dataclass(frozen=True) -class ExtendedPropertyData: - """Data assembled for one extended channel/component property dashboard.""" - - typecomp: int - index: int - label: str - figure_title: str - topography_title: str - topography_values: Any - topography_chanlocs: list[dict[str, Any]] - activity: np.ndarray - times_ms: np.ndarray - activity_title: str - image_data: np.ndarray - image_extent: tuple[float, float, float, float] - image_title: str - spectrum_freqs: np.ndarray - spectrum_power: np.ndarray - spectrum_title: str - classifier: ClassifierData | None - class_probabilities: np.ndarray | None - pvaf: float | None - dipfit: DipfitData | None - rejected: bool | None - - -@dataclass(frozen=True) -class ActivityTraceData: - """Inline scroll-plot trace using EEGLAB-compatible event coordinates.""" - - x_values: np.ndarray - y_values: np.ndarray - times_ms: np.ndarray - pnts: int - epoched: bool def pop_prop_extended( @@ -176,7 +88,7 @@ def pop_prop_extended( command = _history_command(typecomp, indices, winhandle, spec_opt, erp_opt, scroll_event, classifier_name) figure = None if plot: - figure = _build_navigable_dashboard( + figure = build_navigable_dashboard( EEG, typecomp, indices, @@ -236,1106 +148,6 @@ def pop_prop_extended_dialog_spec(EEG: dict[str, Any], typecomp: int | bool = 1) ) -def classifier_names(EEG: dict[str, Any]) -> list[str]: - """Return classifier field names available under ``EEG.etc.ic_classification``.""" - etc = EEG.get("etc") or {} - if not isinstance(etc, dict): - return [] - classifications = etc.get("ic_classification") or {} - if not isinstance(classifications, dict): - return [] - return [str(name) for name in classifications if str(name)] - - -def 
classifier_default_index(classifiers: list[str]) -> int: - """Return EEGLAB popup index for the default component classifier.""" - for index, name in enumerate(classifiers, start=1): - if name.lower() == "iclabel": - return index - return 1 - - -def classifier_name_from_gui(EEG: dict[str, Any], value: Any) -> str: - """Resolve a GUI popup value to a classifier field name.""" - classifiers = classifier_names(EEG) - if not classifiers: - return "" - if isinstance(value, str): - for classifier in classifiers: - if classifier.lower() == value.lower(): - return classifier - return classifiers[classifier_default_index(classifiers) - 1] - try: - index = int(value) - 1 - except (TypeError, ValueError): - index = classifier_default_index(classifiers) - 1 - if 0 <= index < len(classifiers): - return classifiers[index] - return classifiers[classifier_default_index(classifiers) - 1] - - -def resolve_classifier_data( - EEG: dict[str, Any], - classifier_name: str = "", - *, - component_total: int | None = None, - require: bool = False, -) -> ClassifierData | None: - """Return normalized classifier data or ``None`` when no classifier is available.""" - classifiers = classifier_names(EEG) - if not classifiers: - if require: - raise ValueError("No component classifier data found in EEG.etc.ic_classification") - return None - resolved_name = _resolve_classifier_name(classifiers, classifier_name) - record = (EEG.get("etc") or {})["ic_classification"][resolved_name] - if not isinstance(record, dict): - raise ValueError(f"Classifier {resolved_name!r} must be stored as a dictionary") - probabilities = np.asarray(record.get("classifications", []), dtype=float) - if probabilities.ndim != 2 or probabilities.size == 0: - raise ValueError(f"Classifier {resolved_name!r} is missing a 2-D classifications matrix") - if component_total is not None and probabilities.shape[0] != int(component_total): - raise ValueError( - f"Classifier {resolved_name!r} has {probabilities.shape[0]} rows for 
{component_total} ICA components" - ) - classes = _classifier_classes(record, resolved_name, probabilities.shape[1]) - return ClassifierData(resolved_name, classes, probabilities) - - -def resolve_dipfit_data(EEG: dict[str, Any], component_index: int) -> DipfitData | None: - """Return normalized DIPFIT model data for a 1-based component index.""" - models = normalize_model_list(EEG) - index = int(component_index) - if index < 1: - raise ValueError("component index must be 1-based") - if index > len(models): - return None - model = models[index - 1] - raw_positions = np.asarray(model.get("posxyz", []), dtype=float) - if raw_positions.size == 0: - return None - positions = _dipfit_matrix(raw_positions, "posxyz", index) - if not np.all(np.isfinite(positions)): - raise ValueError(f"DIPFIT model for component {index} contains non-finite posxyz values") - moments = _dipfit_moments(model.get("momxyz", []), positions.shape[0], index) - rv = _finite_float(model.get("rv")) - coordformat = "" - dipfit = EEG.get("dipfit") - if isinstance(dipfit, dict): - coordformat = str(dipfit.get("coordformat") or "") - return DipfitData( - positions=positions, - moments=moments, - rv_percent=None if rv is None else rv * 100.0, - dmr=_dipole_moment_ratio(moments), - coordformat=coordformat, - ) - - -def selected_property_indices( - EEG: dict[str, Any], - typecomp: int | bool, - values: Any, - *, - default_all: bool = True, -) -> list[int]: - """Normalize EEGLAB-facing channel/component selections to 1-based indices.""" - limit = int(EEG.get("nbchan", 0) or 0) if int(bool(typecomp)) else component_count(EEG) - return one_based_indices(values, limit=limit, default_all=default_all) - - -def component_count(EEG: dict[str, Any]) -> int: - """Return the number of available ICA components.""" - icaact = EEG.get("icaact") - if icaact is not None and np.asarray(icaact).size: - values = np.asarray(icaact) - if values.ndim >= 2: - return int(values.shape[0]) - weights = 
np.asarray(EEG.get("icaweights", [])) - if weights.ndim == 2 and weights.size: - return int(weights.shape[0]) - winv = np.asarray(EEG.get("icawinv", [])) - if winv.ndim == 2 and winv.size: - return int(winv.shape[1]) - return 0 - - -def has_component_classifier(EEG: dict[str, Any], classifier_name: str = "") -> bool: - """Return whether usable classifier data are available for the current ICA.""" - try: - return ( - resolve_classifier_data(EEG, classifier_name, component_total=component_count(EEG), require=False) - is not None - ) - except ValueError as exc: - if "missing a 2-D classifications matrix" in str(exc): - return False - raise - - -def component_rejection_status( - EEG: dict[str, Any], - component_index: int, - *, - component_total: int | None = None, -) -> bool: - """Return the current ``EEG.reject.gcompreject`` status for one component.""" - total = component_count(EEG) if component_total is None else int(component_total) - index = int(component_index) - if index < 1 or index > total: - raise ValueError("component index is outside available ICA components") - return bool(component_rejection_flags(EEG, total, create=False)[index - 1]) - - -def build_extended_property_data( - EEG: dict[str, Any], - typecomp: int | bool, - index: int, - *, - spec_opt: Any = None, - erp_opt: Any = None, - classifier_name: str = "", -) -> ExtendedPropertyData: - """Assemble the dashboard data for one EEGLAB-facing channel/component index.""" - del erp_opt - typecomp = int(bool(typecomp)) - data = eeg_epoch_data(EEG) - times_ms = eeg_times_ms(EEG) - if typecomp: - return _channel_dashboard_data(EEG, data, times_ms, int(index), spec_opt) - return _component_dashboard_data(EEG, data, times_ms, int(index), spec_opt, classifier_name) - - -def _build_navigable_dashboard( - EEG: dict[str, Any], - typecomp: int, - indices: list[int], - winhandle: Any, - spec_opt: Any, - erp_opt: Any, - scroll_event: int | bool, - classifier_name: str, - *, - fig: Any, - show_activity: bool, - 
reject_callback: Any | None, -) -> Any: - figure = fig if fig is not None else plt.figure(figsize=_DASHBOARD_SIZE) - state = { - "EEG": EEG, - "typecomp": int(typecomp), - "indices": tuple(indices), - "position": 0, - "spec_opt": spec_opt, - "erp_opt": erp_opt, - "scroll_event": int(bool(scroll_event)), - "classifier_name": classifier_name, - "show_activity": bool(show_activity), - "winhandle": winhandle, - "reject_callback": reject_callback, - "rejection_pending": _initial_rejection_state(EEG, int(typecomp), indices), - } - figure.eegprep_dashboard_state = state - _render_dashboard(figure) - return figure - - -def _render_dashboard(figure: Any) -> None: - state = figure.eegprep_dashboard_state - indices = state["indices"] - position = int(state["position"]) - index = int(indices[position]) - dashboard = build_extended_property_data( - state["EEG"], - state["typecomp"], - index, - spec_opt=state["spec_opt"], - erp_opt=state["erp_opt"], - classifier_name=state["classifier_name"], - ) - figure.clf() - figure.set_size_inches(*_DASHBOARD_SIZE, forward=True) - figure.patch.set_facecolor((0.93, 0.96, 1.0)) - _set_window_title(figure, dashboard.figure_title) - has_rejection_controls = _has_rejection_controls(state) - bottom = ( - 0.17 - if has_rejection_controls and len(indices) > 1 - else 0.125 - if has_rejection_controls - else 0.12 - if len(indices) > 1 - else 0.075 - ) - if dashboard.dipfit is None: - grid = figure.add_gridspec( - 2, - 4, - left=0.055, - right=0.97, - top=0.88, - bottom=bottom, - wspace=0.65, - hspace=0.55, - width_ratios=(1.2, 0.95, 1.25, 1.25), - height_ratios=(0.9, 1.2), - ) - topo_ax = figure.add_subplot(grid[0, 0]) - if dashboard.classifier is None: - class_ax = None - activity_ax = figure.add_subplot(grid[0, 1:]) - else: - class_ax = figure.add_subplot(grid[0, 1]) - activity_ax = figure.add_subplot(grid[0, 2:]) - image_ax = figure.add_subplot(grid[1, :2]) - dipfit_axes = [] - spectrum_ax = figure.add_subplot(grid[1, 2:]) - else: - grid = 
figure.add_gridspec( - 2, - 5, - left=0.055, - right=0.97, - top=0.88, - bottom=bottom, - wspace=0.58, - hspace=0.55, - width_ratios=(1.15, 0.9, 0.85, 1.25, 1.25), - height_ratios=(0.9, 1.2), - ) - topo_ax = figure.add_subplot(grid[0, 0]) - if dashboard.classifier is None: - class_ax = None - activity_ax = figure.add_subplot(grid[0, 1:]) - else: - class_ax = figure.add_subplot(grid[0, 1]) - activity_ax = figure.add_subplot(grid[0, 2:]) - image_ax = figure.add_subplot(grid[1, :2]) - dipfit_grid = grid[1, 2].subgridspec(3, 1, hspace=0.04) - dipfit_axes = [figure.add_subplot(dipfit_grid[row, 0]) for row in range(3)] - spectrum_ax = figure.add_subplot(grid[1, 3:]) - - _plot_topography(topo_ax, dashboard) - if class_ax is not None: - _plot_classifier(class_ax, dashboard) - events = state["EEG"].get("event", []) if bool(state["scroll_event"]) else [] - _plot_activity(activity_ax, dashboard, events) - _plot_activity_image(image_ax, dashboard) - if dipfit_axes: - _plot_dipfit(dipfit_axes, dashboard.dipfit) - _plot_spectrum(spectrum_ax, dashboard) - figure.suptitle(dashboard.figure_title, fontsize=14, fontweight="bold") - figure.eegprep_dashboard_data = dashboard - figure.eegprep_activity_view = property_activity_browser( - state["EEG"], - dashboard.typecomp, - dashboard.index, - scroll_event=state["scroll_event"], - show=state["show_activity"], - ) - if len(indices) > 1: - _add_navigation_controls( - figure, - bottom=0.075 if has_rejection_controls else 0.025, - count_y=0.1 if has_rejection_controls else 0.092, - ) - else: - figure.eegprep_dashboard_navigation = {} - figure.eegprep_dashboard_navigation_buttons = () - if has_rejection_controls: - _add_rejection_controls(figure, dashboard) - else: - figure.eegprep_dashboard_rejection = {} - figure.eegprep_dashboard_rejection_buttons = {} - figure.eegprep_dashboard_rejection_button_list = () - buttons = [] - buttons.extend(getattr(figure, "eegprep_dashboard_navigation_buttons", ())) - buttons.extend(getattr(figure, 
"eegprep_dashboard_rejection_button_list", ())) - figure.eegprep_dashboard_buttons = tuple(buttons) - figure.canvas.draw_idle() - - -def _plot_topography(axis: Any, dashboard: ExtendedPropertyData) -> None: - if dashboard.typecomp: - topoplot( - dashboard.topography_values, - dashboard.topography_chanlocs, - axes=axis, - style="blank", - electrodes="off", - ) - else: - topoplot( - dashboard.topography_values, - dashboard.topography_chanlocs, - axes=axis, - electrodes="on", - colorbar=False, - ) - axis.set_title(dashboard.topography_title, fontsize=12, fontweight="normal") - if dashboard.pvaf is not None: - axis.text( - 0.5, - -0.13, - f"{{% scalp data var. accounted for}}: {dashboard.pvaf:.1f}%", - transform=axis.transAxes, - ha="center", - va="top", - fontsize=9, - ) - - -def _plot_classifier(axis: Any, dashboard: ExtendedPropertyData) -> None: - assert dashboard.classifier is not None - assert dashboard.class_probabilities is not None - labels = list(reversed(dashboard.classifier.classes)) - probabilities = np.asarray(dashboard.class_probabilities, dtype=float)[::-1] - y_values = np.arange(len(labels)) - axis.barh(y_values, probabilities, color="#4c78a8") - axis.set_yticks(y_values, labels) - axis.set_xlim(0.0, 1.0) - axis.set_xticks([0.0, 0.5, 1.0]) - axis.grid(axis="x", alpha=0.3) - axis.set_xlabel("Probability") - axis.set_title(dashboard.classifier.name, fontsize=12, fontweight="normal") - for y_value, probability in zip(y_values, probabilities): - axis.text(0.5, y_value, f"{probability * 100:.1f}%", ha="center", va="center", fontsize=8) - - -def _plot_activity(axis: Any, dashboard: ExtendedPropertyData, events: Any) -> None: - trace = _activity_trace(dashboard) - axis.plot(trace.x_values, trace.y_values, color="black", linewidth=0.85) - axis.axhline(0.0, color="0.75", linewidth=0.6) - _plot_epoch_markers(axis, trace) - _plot_event_markers(axis, trace, events) - axis.set_title(dashboard.activity_title, fontsize=12, fontweight="normal") - axis.set_xlabel("Time 
(ms)") - axis.set_ylabel("uV") - axis.grid(True, alpha=0.2) - _format_scrollplot_axis(axis, trace) - - -def _plot_activity_image(axis: Any, dashboard: ExtendedPropertyData) -> None: - image = np.asarray(dashboard.image_data, dtype=float) - handle = axis.imshow( - image, - aspect="auto", - origin="lower", - extent=dashboard.image_extent, - cmap="RdBu_r", - ) - axis.set_title(dashboard.image_title, fontsize=12, fontweight="normal") - axis.set_xlabel("Time (ms)" if dashboard.activity.shape[2] > 1 else "Data") - axis.set_ylabel("Epoch" if dashboard.activity.shape[2] > 1 else "Data") - plt.colorbar(handle, ax=axis, fraction=0.046, pad=0.035) - - -def _plot_spectrum(axis: Any, dashboard: ExtendedPropertyData) -> None: - axis.plot(dashboard.spectrum_freqs, dashboard.spectrum_power, color="black", linewidth=1.0) - axis.set_title(dashboard.spectrum_title, fontsize=12, fontweight="normal") - axis.set_xlabel("Frequency (Hz)") - axis.set_ylabel("Power 10*log10(uV^2/Hz)") - axis.grid(True, alpha=0.25) - finite = np.isfinite(dashboard.spectrum_power) - if np.any(finite): - finite_values = dashboard.spectrum_power[finite] - low = float(np.min(finite_values)) - high = float(np.max(finite_values)) - if high == low: - padding = 1.0 if low == 0.0 else abs(low) * 0.05 - low -= padding - high += padding - axis.set_ylim(low, high) - - -def _plot_dipfit(axes: list[Any], dipfit: DipfitData | None) -> None: - assert dipfit is not None - volume = load_standard_mri_volume() - for axis, mri_slice in zip(axes, dipfit_mri_slices(volume, dipfit.positions)): - axis.set_facecolor("black") - axis.imshow(mri_slice.image, cmap="gray", origin="lower", extent=mri_slice.extent, interpolation="nearest") - axis.set_aspect("equal", adjustable="box") - axis.set_xticks([]) - axis.set_yticks([]) - for spine in axis.spines.values(): - spine.set_visible(False) - _plot_dipfit_points(axis, dipfit, mri_slice.x_axis, mri_slice.y_axis) - axes[0].set_title("Dipole Position", fontsize=12, fontweight="normal", pad=7) - 
_plot_dipfit_values(axes[-1], dipfit) - - -def _plot_dipfit_points(axis: Any, dipfit: DipfitData, x_index: int, y_index: int) -> None: - for row, position in enumerate(dipfit.positions): - color = _DIPFIT_COLORS[row % len(_DIPFIT_COLORS)] - x_value = float(position[x_index]) - y_value = float(position[y_index]) - axis.plot( - x_value, - y_value, - marker="o", - markersize=5.5, - color=color, - markeredgecolor="white", - markeredgewidth=0.45, - ) - if dipfit.moments is not None and row < dipfit.moments.shape[0]: - _plot_dipfit_moment(axis, x_value, y_value, dipfit.moments[row], x_index, y_index, color) - - -def _plot_dipfit_moment( - axis: Any, - x_value: float, - y_value: float, - moment: np.ndarray, - x_index: int, - y_index: int, - color: str, -) -> None: - dx = float(moment[x_index]) - dy = float(moment[y_index]) - norm = float(np.hypot(dx, dy)) - if not np.isfinite(norm) or norm <= 0.0: - return - axis.arrow( - x_value, - y_value, - dx / norm * 18.0, - dy / norm * 18.0, - color=color, - width=0.8, - head_width=5.0, - length_includes_head=True, - alpha=0.9, - ) - - -def _plot_dipfit_values(axis: Any, dipfit: DipfitData) -> None: - lines = [] - if dipfit.rv_percent is not None: - lines.append(f"RV: {dipfit.rv_percent:.1f}%") - if dipfit.dmr is not None: - lines.append(f"DMR: {dipfit.dmr:.1f}") - if lines: - axis.text( - 0.5, - -0.02, - "\n".join(lines), - transform=axis.transAxes, - color="black", - fontsize=8, - ha="center", - va="top", - ) - - -def _add_navigation_controls(figure: Any, *, bottom: float, count_y: float) -> None: - previous_axis = figure.add_axes((0.37, bottom, 0.105, 0.05)) - next_axis = figure.add_axes((0.525, bottom, 0.105, 0.05)) - previous_button = Button(previous_axis, "Previous") - next_button = Button(next_axis, "Next") - - def previous(_event: Any = None) -> None: - _navigate_dashboard(figure, -1) - - def next_(_event: Any = None) -> None: - _navigate_dashboard(figure, 1) - - previous_button.on_clicked(previous) - 
next_button.on_clicked(next_) - figure.eegprep_dashboard_navigation_buttons = (previous_button, next_button) - figure.eegprep_dashboard_navigation = {"previous": previous, "next": next_} - state = figure.eegprep_dashboard_state - figure.text( - 0.5, - count_y, - f"{int(state['position']) + 1} / {len(state['indices'])}", - ha="center", - va="center", - fontsize=9, - ) - - -def _add_rejection_controls(figure: Any, dashboard: ExtendedPropertyData) -> None: - state = figure.eegprep_dashboard_state - index = int(dashboard.index) - rejected = _pending_rejection_status(state, index) - cancel_button = Button(figure.add_axes((0.2, 0.015, 0.1, 0.045)), "Cancel", color=_CONTROL_COLOR) - values_button = Button(figure.add_axes((0.325, 0.015, 0.1, 0.045)), "Values", color=_CONTROL_COLOR) - status_button = Button( - figure.add_axes((0.45, 0.015, 0.1, 0.045)), - _rejection_label(rejected), - color=_rejection_color(rejected), - hovercolor=_rejection_color(rejected), - ) - help_button = Button(figure.add_axes((0.575, 0.015, 0.1, 0.045)), "HELP", color=_CONTROL_COLOR) - ok_button = Button(figure.add_axes((0.7, 0.015, 0.1, 0.045)), "OK", color=_CONTROL_COLOR) - - if not _component_values_available(state["EEG"]): - values_button.set_active(False) - values_button.ax.set_facecolor(_DISABLED_CONTROL_COLOR) - - def cancel(_event: Any = None) -> None: - plt.close(figure) - - def values(_event: Any = None) -> None: - if _component_values_available(state["EEG"]): - _show_component_values(state["EEG"], index, _pending_rejection_status(state, index)) - - def toggle(_event: Any = None) -> None: - state["rejection_pending"][index] = not _pending_rejection_status(state, index) - _style_rejection_status_button(status_button, state["rejection_pending"][index]) - figure.canvas.draw_idle() - - def help_(_event: Any = None) -> None: - pophelp("pop_prop_extended") - - def ok(_event: Any = None) -> None: - _commit_rejection_state(figure) - plt.close(figure) - - cancel_button.on_clicked(cancel) - 
values_button.on_clicked(values) - status_button.on_clicked(toggle) - help_button.on_clicked(help_) - ok_button.on_clicked(ok) - figure.eegprep_dashboard_rejection_button_list = ( - cancel_button, - values_button, - status_button, - help_button, - ok_button, - ) - figure.eegprep_dashboard_rejection_buttons = { - "cancel": cancel_button, - "values": values_button, - "status": status_button, - "help": help_button, - "ok": ok_button, - } - figure.eegprep_dashboard_rejection = { - "cancel": cancel, - "values": values, - "toggle": toggle, - "help": help_, - "ok": ok, - "pending": state["rejection_pending"], - } - - -def _navigate_dashboard(figure: Any, step: int) -> None: - state = figure.eegprep_dashboard_state - state["position"] = (int(state["position"]) + int(step)) % len(state["indices"]) - _render_dashboard(figure) - - -def _initial_rejection_state(EEG: dict[str, Any], typecomp: int, indices: list[int]) -> dict[int, bool]: - if int(typecomp): - return {} - total = component_count(EEG) - flags = component_rejection_flags(EEG, total, create=False) - return {int(index): bool(flags[int(index) - 1]) for index in indices} - - -def _has_rejection_controls(state: dict[str, Any]) -> bool: - return int(state["typecomp"]) == 0 and bool(state["indices"]) - - -def _pending_rejection_status(state: dict[str, Any], component_index: int) -> bool: - pending = state["rejection_pending"] - index = int(component_index) - if index not in pending: - pending[index] = component_rejection_status(state["EEG"], index) - return bool(pending[index]) - - -def _rejection_label(rejected: bool) -> str: - return "REJECT" if rejected else "ACCEPT" - - -def _rejection_color(rejected: bool) -> str: - return _REJECT_COLOR if rejected else _ACCEPT_COLOR - - -def _style_rejection_status_button(button: Button, rejected: bool) -> None: - button.label.set_text(_rejection_label(rejected)) - button.ax.set_facecolor(_rejection_color(rejected)) - button.color = _rejection_color(rejected) - button.hovercolor = 
_rejection_color(rejected) - - -def _commit_rejection_state(figure: Any) -> None: - state = figure.eegprep_dashboard_state - if not _has_rejection_controls(state): - return - EEG = state["EEG"] - total = component_count(EEG) - committed: dict[int, bool] = {} - for index in sorted(state["rejection_pending"]): - rejected = bool(state["rejection_pending"][index]) - set_component_rejection_flag(EEG, index, rejected, total) - _style_rejection_winhandle(state["winhandle"], index, rejected) - committed[int(index)] = rejected - callback = state.get("reject_callback") - if callback is not None: - callback(EEG, dict(committed)) - - -def _style_rejection_winhandle(winhandle: Any, component_index: int, rejected: bool) -> None: - if _is_empty_winhandle(winhandle): - return - handle = winhandle - if isinstance(winhandle, Mapping): - handle = winhandle.get(int(component_index)) - if isinstance(handle, Button): - handle.ax.set_facecolor(_rejection_color(rejected)) - - -def _is_empty_winhandle(winhandle: Any) -> bool: - if winhandle is None: - return True - if isinstance(winhandle, (int, float, np.integer, np.floating)): - return bool(winhandle == 0) or bool(np.isnan(float(winhandle))) - return False - - -def _component_values_available(EEG: dict[str, Any]) -> bool: - stats = EEG.get("stats") - if not isinstance(stats, dict): - return False - return np.asarray(stats.get("compenta", [])).size > 0 - - -def _show_component_values(EEG: dict[str, Any], component_index: int, rejected: bool) -> None: - values = _component_value_lines(EEG, component_index, rejected) - figure = plt.figure(figsize=(3.4, 3.4)) - manager = getattr(figure.canvas, "manager", None) - if manager is not None: - manager.set_window_title("Statistics of the component") - axis = figure.add_subplot(1, 1, 1) - axis.axis("off") - axis.text(0.04, 0.96, "\n".join(values), va="top", ha="left", fontsize=9, family="monospace") - close_button = Button(figure.add_axes((0.375, 0.03, 0.25, 0.08)), "Close", color=_CONTROL_COLOR) - 
close_button.on_clicked(lambda _event=None: plt.close(figure)) - setattr(figure, "eegprep_component_values_button", close_button) - figure.tight_layout(rect=(0, 0.12, 1, 1)) - figure.canvas.draw_idle() - - -def _component_value_lines(EEG: dict[str, Any], component_index: int, rejected: bool) -> list[str]: - raw_stats = EEG.get("stats") - stats = raw_stats if isinstance(raw_stats, dict) else {} - raw_reject = EEG.get("reject") - reject = raw_reject if isinstance(raw_reject, dict) else {} - index = int(component_index) - return [ - "(", - f"Entropy of component activity {_indexed_stat(stats.get('compenta'), index):>8}", - f"> Rejection threshold {_scalar_stat(reject.get('threshentropy')):>8}", - "", - " AND ----", - "", - f"Kurtosis of component activity {_indexed_stat(stats.get('compkurta'), index):>8}", - f"> Rejection threshold {_scalar_stat(reject.get('threshkurtact')):>8}", - "", - ") OR ----", - "", - f"Kurtosis distribution {_indexed_stat(stats.get('compkurtdist'), index):>8}", - f"> Rejection threshold {_scalar_stat(reject.get('threshkurtdist')):>8}", - "", - f"Current thresholds suggest to {_rejection_label(rejected)} the component", - "", - "After manually accepting/rejecting the component, recalibrate", - "thresholds before applying automatic rejection to other datasets.", - ] - - -def _indexed_stat(values: Any, component_index: int) -> str: - vector = np.asarray(values, dtype=float).ravel() - index = int(component_index) - 1 - if index < 0 or index >= vector.size or not np.isfinite(vector[index]): - return "----" - return f"{float(vector[index]):2.2f}" - - -def _scalar_stat(value: Any) -> str: - vector = np.asarray(value, dtype=float).ravel() - if vector.size == 0 or not np.isfinite(vector[0]): - return "----" - return f"{float(vector[0]):2.2f}" - - -def _channel_dashboard_data( - EEG: dict[str, Any], - data: np.ndarray, - times_ms: np.ndarray, - index: int, - spec_opt: Any, -) -> ExtendedPropertyData: - labels = channel_labels(EEG) - if index < 1 or 
index > data.shape[0]: - raise ValueError("channel index is outside available channels") - label = labels[index - 1] if index - 1 < len(labels) else str(index) - activity = np.array(data[index - 1 : index], dtype=float, copy=True) - spectrum_freqs, spectrum_power = _spectrum(activity, EEG, spec_opt) - image_data, image_extent, image_title = _activity_image(activity, times_ms, f"Epoched Channel {label} Activity") - return ExtendedPropertyData( - typecomp=1, - index=index, - label=label, - figure_title=f"Channel {label} - pop_prop_extended()", - topography_title=f"Channel {label}", - topography_values=index, - topography_chanlocs=list(EEG.get("chanlocs", []) or []), - activity=activity, - times_ms=times_ms, - activity_title="Channel Time Series", - image_data=image_data, - image_extent=image_extent, - image_title=image_title, - spectrum_freqs=spectrum_freqs, - spectrum_power=spectrum_power, - spectrum_title="Channel Activity Power Spectrum", - classifier=None, - class_probabilities=None, - pvaf=None, - dipfit=None, - rejected=None, - ) - - -def _component_dashboard_data( - EEG: dict[str, Any], - data: np.ndarray, - times_ms: np.ndarray, - index: int, - spec_opt: Any, - classifier_name: str, -) -> ExtendedPropertyData: - activity_all = component_activations(EEG) - maps, map_chanlocs = component_map_data(EEG) - if index < 1 or index > activity_all.shape[0]: - raise ValueError("component index is outside available ICA components") - activity = np.array(activity_all[index - 1 : index], dtype=float, copy=True) - classifier = resolve_classifier_data(EEG, classifier_name, component_total=activity_all.shape[0], require=False) - probabilities = ( - None if classifier is None else np.array(classifier.probabilities[index - 1], dtype=float, copy=True) - ) - spectrum_freqs, spectrum_power = _spectrum(activity, EEG, spec_opt) - image_data, image_extent, image_title = _activity_image(activity, times_ms, f"Epoched IC{index} Activity") - return ExtendedPropertyData( - typecomp=0, - 
index=index, - label=f"IC{index}", - figure_title=f"IC{index} - pop_prop_extended()", - topography_title=f"IC{index}", - topography_values=maps[:, index - 1], - topography_chanlocs=map_chanlocs, - activity=activity, - times_ms=times_ms, - activity_title=f"Scrolling IC{index} Activity", - image_data=image_data, - image_extent=image_extent, - image_title=image_title, - spectrum_freqs=spectrum_freqs, - spectrum_power=spectrum_power, - spectrum_title=f"IC{index} Activity Power Spectrum", - classifier=classifier, - class_probabilities=probabilities, - pvaf=_component_pvaf(EEG, data, maps, activity, index), - dipfit=resolve_dipfit_data(EEG, index), - rejected=component_rejection_status(EEG, index, component_total=activity_all.shape[0]), - ) - - -def _spectrum(activity: np.ndarray, EEG: dict[str, Any], spec_opt: Any) -> tuple[np.ndarray, np.ndarray]: - options = parse_plot_options_text(spec_opt) - flat = np.asarray(activity, dtype=float).reshape(1, -1) - spectra, freqs, _std = compute_spectra( - flat, - int(EEG.get("pnts", flat.shape[1]) or flat.shape[1]), - float(EEG.get("srate", 1.0) or 1.0), - winsize=_first_int(options.get("winsize")), - overlap=_first_int(options.get("overlap")) or 0, - nfft=_first_int(options.get("nfft")), - ) - return freqs, spectra[0] - - -def _activity_image( - activity: np.ndarray, - times_ms: np.ndarray, - epoched_title: str, -) -> tuple[np.ndarray, tuple[float, float, float, float], str]: - trace = np.asarray(activity[0], dtype=float) - trace = trace - float(np.nanmean(trace)) - if trace.ndim == 1: - trace = trace[:, np.newaxis] - if trace.shape[1] > 1: - image = trace.T - extent = (float(times_ms[0]), float(times_ms[-1]), 1.0, float(trace.shape[1])) - return image, extent, epoched_title - flat = trace[:, 0] - line_count = min(200, max(1, int(np.floor(np.sqrt(flat.size))))) - frame_count = max(1, flat.size // line_count) - image = flat[: line_count * frame_count].reshape(line_count, frame_count) - extent = (0.0, float(frame_count - 1), 1.0, 
float(line_count)) - return image, extent, "Continuous Data" - - -def _activity_trace(dashboard: ExtendedPropertyData) -> ActivityTraceData: - trace = np.asarray(dashboard.activity[0], dtype=float) - if trace.ndim == 1: - trace = trace[:, np.newaxis] - pnts = int(dashboard.times_ms.size) - srate = _srate_from_times(dashboard.times_ms) - window_samples = max(1, int(round(_SCROLL_SECONDS * srate))) - if trace.shape[1] == 1: - sample_count = min(pnts, window_samples) - return ActivityTraceData( - x_values=np.arange(1, sample_count + 1, dtype=float), - y_values=trace[:sample_count, 0], - times_ms=dashboard.times_ms, - pnts=pnts, - epoched=False, - ) - flat = trace.T.reshape(-1) - sample_count = min(flat.size, window_samples) - return ActivityTraceData( - x_values=np.arange(1, sample_count + 1, dtype=float), - y_values=flat[:sample_count], - times_ms=dashboard.times_ms, - pnts=pnts, - epoched=True, - ) - - -def _plot_event_markers( - axis: Any, - trace: ActivityTraceData, - events: Any, -) -> None: - event_items = _event_items(events) - if not event_items: - return - first = float(np.nanmin(trace.x_values)) - last = float(np.nanmax(trace.x_values)) - colors = _event_color_map(event_items) - for event in event_items: - if "latency" not in event: - continue - try: - latency = float(_event_scalar(event["latency"])) - except (TypeError, ValueError): - continue - x_value = latency - if first <= x_value <= last: - label = str(_event_scalar(event.get("type", ""))) - axis.axvline(x_value, color=colors[label], linestyle="--", linewidth=0.75) - axis.text( - x_value, - 0.98, - label, - color=colors[label], - transform=axis.get_xaxis_transform(), - rotation=45, - fontsize=8, - ) - - -def _plot_epoch_markers(axis: Any, trace: ActivityTraceData) -> None: - if not trace.epoched: - return - first = float(np.nanmin(trace.x_values)) - last = float(np.nanmax(trace.x_values)) - start = int(np.ceil(first / trace.pnts) * trace.pnts) - for boundary in range(start, int(np.floor(last / 
trace.pnts) * trace.pnts) + 1, trace.pnts): - if boundary < first or boundary > last: - continue - axis.axvline(float(boundary), color="red", linestyle="-", linewidth=0.75) - axis.text( - float(boundary), - 0.98, - f"epoch {boundary // trace.pnts}", - color="red", - transform=axis.get_xaxis_transform(), - rotation=45, - fontsize=8, - ) - - -def _format_scrollplot_axis(axis: Any, trace: ActivityTraceData) -> None: - first = float(trace.x_values[0]) - last = float(trace.x_values[-1]) - axis.set_xlim(first, last) - tick_count = min(6, trace.x_values.size) - ticks = np.unique(np.linspace(first, last, tick_count).round().astype(int)) - labels = [] - for tick in ticks: - sample = (int(tick) - 1) % trace.pnts if trace.epoched else min(max(int(tick) - 1, 0), trace.pnts - 1) - labels.append(f"{trace.times_ms[sample]:g}") - axis.set_xticks(ticks) - axis.set_xticklabels(labels) - - -def _event_items(events: Any) -> list[dict[str, Any]]: - if events is None: - return [] - if isinstance(events, dict): - if "latency" not in events: - return [] - latencies = np.asarray(events["latency"], dtype=object).ravel() - types = np.asarray(events.get("type", [""] * latencies.size), dtype=object).ravel() - epochs = np.asarray(events.get("epoch", [None] * latencies.size), dtype=object).ravel() - event_items = [] - for index, latency in enumerate(latencies): - event = { - "type": _event_scalar(types[min(index, types.size - 1)]), - "latency": _event_scalar(latency), - } - epoch = _event_scalar(epochs[min(index, epochs.size - 1)]) - if epoch is not None: - event["epoch"] = epoch - event_items.append(event) - return event_items - if isinstance(events, np.ndarray): - events = events.tolist() - try: - event_values = list(events) - except TypeError: - return [] - return [event for event in event_values if isinstance(event, dict)] - - -def _event_scalar(value: Any) -> Any: - array = np.asarray(value) - if array.shape == (): - return array.item() - if array.size == 1: - item = array.ravel()[0] - 
return item.item() if hasattr(item, "item") else item - return value - - -def _event_color_map(event_items: list[dict[str, Any]]) -> dict[str, str]: - labels: list[str] = [] - for event in event_items: - label = str(_event_scalar(event.get("type", ""))) - if label not in labels: - labels.append(label) - return {label: _EVENT_COLORS[index % len(_EVENT_COLORS)] for index, label in enumerate(labels)} - - -def _dipfit_matrix(values: np.ndarray, field_name: str, component_index: int) -> np.ndarray: - matrix = values - if matrix.ndim == 1: - matrix = matrix.reshape(1, -1) - if matrix.ndim != 2 or matrix.shape[1] < 3: - raise ValueError( - f"DIPFIT model for component {component_index} must contain {field_name} rows with 3 coordinates" - ) - return np.array(matrix[:, :3], dtype=float, copy=True) - - -def _dipfit_moments(values: Any, position_count: int, component_index: int) -> np.ndarray | None: - raw_moments = np.asarray(values, dtype=float) - if raw_moments.size == 0: - return None - moments = _dipfit_matrix(raw_moments, "momxyz", component_index) - if moments.shape[0] != position_count: - raise ValueError(f"DIPFIT model for component {component_index} must have matching posxyz and momxyz rows") - if not np.all(np.isfinite(moments)): - raise ValueError(f"DIPFIT model for component {component_index} contains non-finite momxyz values") - return moments - - -def _dipole_moment_ratio(moments: np.ndarray | None) -> float | None: - if moments is None or moments.shape[0] != 2: - return None - norms = np.linalg.norm(moments, axis=1) - if not np.all(np.isfinite(norms)) or np.any(norms <= 0.0): - return None - ratio = float(norms[0] / norms[1]) - return ratio if ratio >= 1.0 else 1.0 / ratio - - -def _finite_float(value: Any) -> float | None: - try: - numeric = float(np.asarray(value).reshape(())) - except (TypeError, ValueError): - return None - return numeric if np.isfinite(numeric) else None - - -def _component_pvaf( - EEG: dict[str, Any], - data: np.ndarray, - maps: 
np.ndarray, - activity: np.ndarray, - index: int, -) -> float | None: - if maps.shape[0] == 0: - return None - icachansind = component_channel_indices(EEG, data.shape[0]) - if maps.shape[0] != icachansind.size: - return None - flat_data = data[icachansind, :, :].reshape(icachansind.size, -1) - component_trace = activity.reshape(1, -1) - projection = maps[:, index - 1 : index] @ component_trace - datavar = float(np.nanmean(np.nanvar(flat_data, axis=1))) - if not np.isfinite(datavar) or datavar <= 0: - return None - projvar = float(np.nanmean(np.nanvar(flat_data - projection, axis=1))) - if not np.isfinite(projvar): - return None - return 100.0 * (1.0 - projvar / datavar) - - -def _resolve_classifier_name(classifiers: list[str], classifier_name: str) -> str: - if classifier_name: - for name in classifiers: - if name.lower() == str(classifier_name).lower(): - return name - raise ValueError(f"Classifier {classifier_name!r} was not found in EEG.etc.ic_classification") - return classifiers[classifier_default_index(classifiers) - 1] - - -def _classifier_classes(record: dict[str, Any], classifier_name: str, class_count: int) -> tuple[str, ...]: - raw_classes = record.get("classes", []) - classes = [str(item) for item in np.asarray(raw_classes, dtype=object).ravel().tolist() if str(item)] - if not classes: - if classifier_name.lower() == "iclabel" and class_count == len(DEFAULT_ICLABEL_CLASSES): - return DEFAULT_ICLABEL_CLASSES - return tuple(f"Class {index}" for index in range(1, class_count + 1)) - if len(classes) != class_count: - raise ValueError( - f"Classifier {classifier_name!r} has {class_count} probability columns but {len(classes)} class names" - ) - return tuple(classes) - - -def _first_int(value: Any) -> int | None: - vector = numeric_vector(value) - if vector.size == 0: - return None - return int(vector[0]) - - -def _srate_from_times(times_ms: np.ndarray) -> float: - if times_ms.size < 2: - return 1.0 - interval_ms = float(np.nanmedian(np.diff(times_ms))) - if 
not np.isfinite(interval_ms) or interval_ms <= 0: - return 1.0 - return 1000.0 / interval_ms - - -def _set_window_title(figure: Any, title: str) -> None: - figure.set_label(title) - manager = getattr(getattr(figure, "canvas", None), "manager", None) - if manager is not None and hasattr(manager, "set_window_title"): - manager.set_window_title(title) - - def _history_command( typecomp: int, indices: list[int], @@ -1358,6 +170,7 @@ def _history_command( __all__ = [ + "DEFAULT_ICLABEL_CLASSES", "ClassifierData", "DipfitData", "ExtendedPropertyData", diff --git a/src/eegprep/plugins/clean_rawdata/clean_channels.py b/src/eegprep/plugins/clean_rawdata/clean_channels.py index 713af4a3..30b05bb7 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_channels.py +++ b/src/eegprep/plugins/clean_rawdata/clean_channels.py @@ -5,9 +5,12 @@ import numpy as np +from eegprep.plugins.firfilt.design import design_fir + from ...functions.miscfunc.misc import finite_matmul, round_mat +from .private.channel_removal import remove_channels_without_pop_select, update_clean_channel_mask from .private.ransac import calc_projector -from .private.sigproc import design_fir, filtfilt_fast +from .private.sigproc import filtfilt_fast from .private.stats import mad logger = logging.getLogger(__name__) @@ -158,24 +161,7 @@ def clean_channels( logger.debug("Exception traceback:", exc_info=True) logger.info(f'Removing {np.sum(removed_channels)} channels and dropping signal meta-data.') - if len(EEG['chanlocs']) == EEG['data'].shape[0]: - EEG['chanlocs'] = np.asarray([ch for i, ch in enumerate(EEG['chanlocs']) if not removed_channels[i]]) - # pop_select() by default truncates the data to float32, so we need to do the same - EEG['data'] = np.asarray(EEG['data'], dtype=np.float32) - EEG['data'] = EEG['data'][~removed_channels, :] - EEG['nbchan'] = EEG['data'].shape[0] - - # Clear other fields - for field in ['icawinv', 'icasphere', 'icaweights', 'icaact', 'stats', 'specdata', 'specicaact']: - if field in 
EEG: - EEG[field] = np.array([]) - - # Update clean_channel_mask - if 'etc' in EEG and 'clean_channel_mask' in EEG['etc']: - EEG['etc']['clean_channel_mask'][EEG['etc']['clean_channel_mask']] = ~removed_channels - else: - if 'etc' not in EEG: - EEG['etc'] = {} - EEG['etc']['clean_channel_mask'] = ~removed_channels + EEG = remove_channels_without_pop_select(EEG, removed_channels) + update_clean_channel_mask(EEG, removed_channels) return EEG diff --git a/src/eegprep/plugins/clean_rawdata/clean_channels_nolocs.py b/src/eegprep/plugins/clean_rawdata/clean_channels_nolocs.py index 88284f92..e20900a8 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_channels_nolocs.py +++ b/src/eegprep/plugins/clean_rawdata/clean_channels_nolocs.py @@ -5,7 +5,10 @@ import numpy as np -from .private.sigproc import design_fir, design_kaiser, filtfilt_fast +from eegprep.plugins.firfilt.design import design_fir, design_kaiser + +from .private.channel_removal import remove_channels_without_pop_select, update_clean_channel_mask +from .private.sigproc import filtfilt_fast logger = logging.getLogger(__name__) @@ -116,30 +119,7 @@ def clean_channels_nolocs( logger.debug('Exception traceback:', exc_info=True) logger.info('Falling back to a basic substitute and dropping signal meta-data.') - # Manual channel removal - if len(EEG['chanlocs']) == EEG['data'].shape[0]: - EEG['chanlocs'] = np.asarray([ch for i, ch in enumerate(EEG['chanlocs']) if not removed_channels[i]]) - # pop_select() by default truncates the data to float32, so we need to do the same - EEG['data'] = np.asarray(EEG['data'], dtype=np.float32) - EEG['data'] = EEG['data'][~removed_channels, :] - EEG['nbchan'] = EEG['data'].shape[0] - - # Clear other fields - for field in ['icawinv', 'icasphere', 'icaweights', 'icaact', 'stats', 'specdata', 'specicaact']: - if field in EEG: - EEG[field] = np.array([]) - - # Update clean_channel_mask - if ( - 'etc' in EEG - and 'clean_channel_mask' in EEG['etc'] - and 
sum(EEG['etc']['clean_channel_mask']) == len(removed_channels) - ): - mask = EEG['etc']['clean_channel_mask'] - EEG['etc']['clean_channel_mask'] = np.logical_and(mask, ~removed_channels[mask]) - else: - if 'etc' not in EEG: - EEG['etc'] = {} - EEG['etc']['clean_channel_mask'] = ~removed_channels + EEG = remove_channels_without_pop_select(EEG, removed_channels) + update_clean_channel_mask(EEG, removed_channels) return EEG, removed_channels diff --git a/src/eegprep/plugins/clean_rawdata/clean_drifts.py b/src/eegprep/plugins/clean_rawdata/clean_drifts.py index 2408bda2..1d850474 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_drifts.py +++ b/src/eegprep/plugins/clean_rawdata/clean_drifts.py @@ -6,7 +6,9 @@ import numpy as np from scipy.signal import filtfilt -from .private.sigproc import design_fir, design_kaiser, filtfilt_fast +from eegprep.plugins.firfilt.design import design_fir, design_kaiser + +from .private.sigproc import filtfilt_fast logger = logging.getLogger(__name__) diff --git a/src/eegprep/plugins/clean_rawdata/clean_flatlines.py b/src/eegprep/plugins/clean_rawdata/clean_flatlines.py index e3ea44a6..bb2c4207 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_flatlines.py +++ b/src/eegprep/plugins/clean_rawdata/clean_flatlines.py @@ -5,6 +5,8 @@ import numpy as np +from .private.channel_removal import remove_channels_without_pop_select, update_clean_channel_mask + logger = logging.getLogger(__name__) @@ -59,18 +61,7 @@ def clean_flatlines(EEG: Dict[str, Any], max_flatline_duration: float = 5.0, max logger.error('Could not select channels using EEGLAB\'s pop_select(); details: %s', str(e)) logger.debug('Exception traceback:', exc_info=True) logger.info('Falling back to a basic substitute and dropping signal meta-data.') - # pop_select() by default truncates the data to float32, so we need to do the same - EEG['data'] = np.asarray(EEG['data'], dtype=np.float32) - EEG['data'] = EEG['data'][np.logical_not(removed_channels), :] - if len(EEG['chanlocs']) 
== len(removed_channels): - EEG['chanlocs'] = EEG['chanlocs'][np.logical_not(removed_channels)] - EEG['nbchan'] = EEG['data'].shape[0] - for fn in EEG.keys() & {'icawinv', 'icasphere', 'icaweights', 'icaact', 'stats', 'specdata', 'specicaact'}: - EEG[fn] = np.array([]) - CCM = EEG['etc'].get('clean_channel_mask') - if CCM is not None: - CCM[CCM] = ~removed_channels - else: - EEG['etc']['clean_channel_mask'] = ~removed_channels + EEG = remove_channels_without_pop_select(EEG, removed_channels) + update_clean_channel_mask(EEG, removed_channels) return EEG diff --git a/src/eegprep/plugins/clean_rawdata/private/channel_removal.py b/src/eegprep/plugins/clean_rawdata/private/channel_removal.py new file mode 100644 index 00000000..eceb7717 --- /dev/null +++ b/src/eegprep/plugins/clean_rawdata/private/channel_removal.py @@ -0,0 +1,42 @@ +"""Shared clean_rawdata channel-removal helpers.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np + +CHANNEL_DEPENDENT_FIELDS = ("icawinv", "icasphere", "icaweights", "icaact", "stats", "specdata", "specicaact") + + +def remove_channels_without_pop_select(EEG: dict[str, Any], removed_channels: np.ndarray) -> dict[str, Any]: + """Remove channels when ``pop_select`` is unavailable.""" + removed = np.asarray(removed_channels, dtype=bool).ravel() + data = np.asarray(EEG["data"], dtype=np.float32) + chanlocs = EEG.get("chanlocs", []) + if len(chanlocs) == data.shape[0]: + if isinstance(chanlocs, np.ndarray): + EEG["chanlocs"] = chanlocs[~removed] + else: + EEG["chanlocs"] = np.asarray([chanloc for index, chanloc in enumerate(chanlocs) if not removed[index]]) + EEG["data"] = data[~removed, ...] 
+ EEG["nbchan"] = EEG["data"].shape[0] + for field in CHANNEL_DEPENDENT_FIELDS: + if field in EEG: + EEG[field] = np.array([]) + return EEG + + +def update_clean_channel_mask(EEG: dict[str, Any], removed_channels: np.ndarray) -> None: + """Update ``EEG.etc.clean_channel_mask`` after current-channel removal.""" + removed = np.asarray(removed_channels, dtype=bool).ravel() + etc = EEG.setdefault("etc", {}) + mask = etc.get("clean_channel_mask") + if mask is not None: + existing = np.asarray(mask, dtype=bool).ravel() + if int(np.sum(existing)) == removed.size: + updated = np.array(existing, dtype=bool, copy=True) + updated[updated] = ~removed + etc["clean_channel_mask"] = updated + return + etc["clean_channel_mask"] = ~removed diff --git a/src/eegprep/plugins/clean_rawdata/private/sigproc.py b/src/eegprep/plugins/clean_rawdata/private/sigproc.py index d5dc6fb6..3102f70b 100644 --- a/src/eegprep/plugins/clean_rawdata/private/sigproc.py +++ b/src/eegprep/plugins/clean_rawdata/private/sigproc.py @@ -1,118 +1,11 @@ """Signal processing utilities.""" -from typing import Optional, Sequence, Union +from typing import Union import numpy as np from scipy.signal import fftconvolve -from ....functions.miscfunc.misc import round_mat -__all__ = ['design_kaiser', 'design_fir', 'filtfilt_fast'] - - -def design_kaiser(lo: float, hi: float, atten: float, want_odd: bool, use_scipy: bool = False) -> np.ndarray: - """Design a Kaiser window for a low-pass FIR filter. - - Parameters - ---------- - lo : float - Normalized lower edge of the transition band. - hi : float - Normalized upper edge of the transition band. - atten : float - Stop-band attenuation in dB (-20log10(ratio)). - want_odd : bool - Whether the desired window length shall be odd. - use_scipy : bool, optional - Whether to use scipy's kaiserord() function, which gives - an approx. 2x longer window than the original function clean_rawdata. - - Returns - ------- - np.ndarray - The Kaiser window. 
- """ - from scipy.signal import kaiserord - from scipy.signal.windows import kaiser - - if not use_scipy: - # determine beta of the kaiser window - if atten < 21: - beta = 0 - elif atten < 50: - beta = 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21) - else: - beta = 0.1102 * (atten - 8.7) - - # determine the number of points - N = int(round_mat((atten - 7.95) / (2 * np.pi * 2.285 * (hi - lo))) + 1) - else: - N, beta = kaiserord(atten, hi - lo) - - if want_odd and N % 2 == 0: - N = N + 1 - # design the actual window - return kaiser(N, beta, sym=True) - - -def design_fir( - n: int, - f: Union[np.ndarray, Sequence[float]], - a: Union[np.ndarray, Sequence[float]], - *, - nfft: Optional[int] = None, - w: Optional[np.ndarray] = None, - compat: bool = True, -) -> np.ndarray: - """Design an FIR filter using the frequency-sampling method. - - The frequency response is interpolated cubically between the specified - frequency points. - - Parameters - ---------- - n : int - Order of the filter. - f : array_like - Vector of frequencies at which amplitudes shall be defined - (starts with 0 and goes up to 1; try to avoid too sharp transitions). - a : array_like - Vector of amplitudes, one value per specified frequency. - nfft : int, optional - Optionally number of FFT bins to use. - w : array_like, optional - Optionally the window function to use. - compat : bool, optional - Whether to use the original MATLAB-compatible filter design - (where the window is off by 1 sample). - - Returns - ------- - np.ndarray - The filter coefficients. 
- """ - from scipy.interpolate import PchipInterpolator - - f, a = np.asarray(f), np.asarray(a) - if nfft is None: - nfft = max([512, 2 ** np.ceil(np.log(n) / np.log(2))]) - if w is None: - if compat: - w = 0.54 - 0.46 * np.cos(2 * np.pi * np.arange(n + 1) / n) - else: - from scipy.signal.windows import hamming - - w = hamming(n) - - # calculate interpolated frequency response - # noinspection PyTypeChecker - f = PchipInterpolator(round_mat(f * nfft), a)(np.arange(nfft + 1)) - - # set phase & transform into time domain - f = f * np.exp(-(0.5 * n) * 1j * np.pi * np.arange(nfft + 1) / nfft) - b = np.real(np.fft.ifft(np.concatenate((f, np.conj(f[::-1][1:-1]))))) - - # apply window to kernel - return b[: len(w)] * w +__all__ = ['filtfilt_fast', 'moving_average'] def filtfilt_fast( diff --git a/src/eegprep/plugins/firfilt/__init__.py b/src/eegprep/plugins/firfilt/__init__.py index a926cc42..c4d3a4bf 100644 --- a/src/eegprep/plugins/firfilt/__init__.py +++ b/src/eegprep/plugins/firfilt/__init__.py @@ -6,6 +6,8 @@ from typing import Any _LAZY_EXPORTS = { + "design_fir": ("eegprep.plugins.firfilt.design", "design_fir"), + "design_kaiser": ("eegprep.plugins.firfilt.design", "design_kaiser"), "findboundaries": ("eegprep.plugins.firfilt.findboundaries", "findboundaries"), "fir_filterdcpadded": ("eegprep.plugins.firfilt.fir_filterdcpadded", "fir_filterdcpadded"), "firfiltreport": ("eegprep.plugins.firfilt.firfiltreport", "firfiltreport"), diff --git a/src/eegprep/plugins/firfilt/design.py b/src/eegprep/plugins/firfilt/design.py new file mode 100644 index 00000000..7bd3e9af --- /dev/null +++ b/src/eegprep/plugins/firfilt/design.py @@ -0,0 +1,107 @@ +"""FIR design helpers shared by firfilt and clean_rawdata.""" + +from __future__ import annotations + +from typing import Optional, Sequence, Union + +import numpy as np + +from eegprep.functions.miscfunc.misc import round_mat + +__all__ = ["design_fir", "design_kaiser"] + + +def design_kaiser(lo: float, hi: float, atten: float, 
want_odd: bool, use_scipy: bool = False) -> np.ndarray: + """Design a Kaiser window for a low-pass FIR filter. + + Parameters + ---------- + lo : float + Normalized lower edge of the transition band. + hi : float + Normalized upper edge of the transition band. + atten : float + Stop-band attenuation in dB (-20log10(ratio)). + want_odd : bool + Whether the desired window length shall be odd. + use_scipy : bool, optional + Whether to use scipy's kaiserord() function, which gives + an approx. 2x longer window than the original function clean_rawdata. + + Returns + ------- + np.ndarray + The Kaiser window. + """ + from scipy.signal import kaiserord + from scipy.signal.windows import kaiser + + if not use_scipy: + if atten < 21: + beta = 0 + elif atten < 50: + beta = 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21) + else: + beta = 0.1102 * (atten - 8.7) + n_points = int(round_mat((atten - 7.95) / (2 * np.pi * 2.285 * (hi - lo))) + 1) + else: + n_points, beta = kaiserord(atten, hi - lo) + + if want_odd and n_points % 2 == 0: + n_points += 1 + return kaiser(n_points, beta, sym=True) + + +def design_fir( + n: int, + f: Union[np.ndarray, Sequence[float]], + a: Union[np.ndarray, Sequence[float]], + *, + nfft: Optional[int] = None, + w: Optional[np.ndarray] = None, + compat: bool = True, +) -> np.ndarray: + """Design an FIR filter using the frequency-sampling method. + + The frequency response is interpolated cubically between the specified + frequency points. + + Parameters + ---------- + n : int + Order of the filter. + f : array_like + Vector of frequencies at which amplitudes shall be defined + (starts with 0 and goes up to 1; try to avoid too sharp transitions). + a : array_like + Vector of amplitudes, one value per specified frequency. + nfft : int, optional + Optionally number of FFT bins to use. + w : array_like, optional + Optionally the window function to use. 
+ compat : bool, optional + Whether to use the original MATLAB-compatible filter design + (where the window is off by 1 sample). + + Returns + ------- + np.ndarray + The filter coefficients. + """ + from scipy.interpolate import PchipInterpolator + + f, a = np.asarray(f), np.asarray(a) + if nfft is None: + nfft = max([512, 2 ** np.ceil(np.log(n) / np.log(2))]) + if w is None: + if compat: + w = 0.54 - 0.46 * np.cos(2 * np.pi * np.arange(n + 1) / n) + else: + from scipy.signal.windows import hamming + + w = hamming(n) + + response = PchipInterpolator(round_mat(f * nfft), a)(np.arange(nfft + 1)) + response = response * np.exp(-(0.5 * n) * 1j * np.pi * np.arange(nfft + 1) / nfft) + b = np.real(np.fft.ifft(np.concatenate((response, np.conj(response[::-1][1:-1]))))) + return b[: len(w)] * w diff --git a/src/eegprep/plugins/firfilt/findboundaries.py b/src/eegprep/plugins/firfilt/findboundaries.py index dbfce13c..ef73c498 100644 --- a/src/eegprep/plugins/firfilt/findboundaries.py +++ b/src/eegprep/plugins/firfilt/findboundaries.py @@ -6,6 +6,7 @@ import numpy as np +from eegprep.functions.miscfunc.event_utils import boundary_event_indices from eegprep.functions.popfunc._file_io import events_to_records @@ -14,14 +15,8 @@ def findboundaries(event: Any) -> np.ndarray: events = events_to_records(event) boundaries: list[int] = [] if events and all("type" in record and "latency" in record for record in events): - first_type = events[0].get("type") - for record in events: - event_type = record.get("type") - is_boundary = ( - isinstance(first_type, str) and isinstance(event_type, str) and event_type.startswith("boundary") - ) or (not isinstance(first_type, str) and event_type == -99) - if not is_boundary: - continue + for index in boundary_event_indices(events): + record = events[index] try: boundaries.append(int(np.fix(float(record.get("latency")) + 0.5))) except (TypeError, ValueError): diff --git a/src/eegprep/resources/skills/eegprep-cli.md 
b/src/eegprep/resources/skills/eegprep-cli.md index 7b4119e7..97a32c53 100644 --- a/src/eegprep/resources/skills/eegprep-cli.md +++ b/src/eegprep/resources/skills/eegprep-cli.md @@ -58,6 +58,10 @@ steps: - name: qc - name: report format: html + +Pipeline transform steps share direct CLI defaults. For `clean`, unspecified +flatline/channel/line-noise/window/high-pass criteria stay `off`; pass +`highpass: [0.25, 0.75]` when ASR should high-pass first. ``` ## Extended Reference diff --git a/tests/test_bids_load_frombids_helpers.py b/tests/test_bids_load_frombids_helpers.py new file mode 100644 index 00000000..a0bfb8d9 --- /dev/null +++ b/tests/test_bids_load_frombids_helpers.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np + + +def test_raw_set_loader_returns_eeg_and_timing_metadata() -> None: + from eegprep.plugins.EEG_BIDS.raw import load_raw_eeg_file + + dataset = Path(__file__).resolve().parents[1] / "sample_data" / "eeglab_data.set" + warnings: list[str] = [] + + eeg, srate, times_sec, report = load_raw_eeg_file( + str(dataset), + dtype=np.float64, + numeric_null=np.array([]), + warning=warnings.append, + verbose=False, + ) + + assert report["ImporterUsed"] == "pop_loadset" + assert warnings == [] + assert srate == eeg["srate"] + assert eeg["data"].dtype == np.float64 + np.testing.assert_allclose(times_sec, np.asarray(eeg["times"], dtype=float) / 1000.0) + + +def test_montage_inference_uses_packaged_montage_resources() -> None: + from eegprep.plugins.EEG_BIDS.montage import apply_montage_inference + + numeric_null = np.array([]) + eeg = { + "chanlocs": [ + { + "labels": "'Fp1'", + "X": numeric_null, + "Y": numeric_null, + "Z": numeric_null, + "sph_radius": numeric_null, + "sph_theta": numeric_null, + "sph_phi": numeric_null, + "theta": numeric_null, + "radius": numeric_null, + }, + { + "labels": "Fpz", + "X": numeric_null, + "Y": numeric_null, + "Z": numeric_null, + "sph_radius": numeric_null, + 
"sph_theta": numeric_null, + "sph_phi": numeric_null, + "theta": numeric_null, + "radius": numeric_null, + }, + { + "labels": "Fp2", + "X": numeric_null, + "Y": numeric_null, + "Z": numeric_null, + "sph_radius": numeric_null, + "sph_theta": numeric_null, + "sph_phi": numeric_null, + "theta": numeric_null, + "radius": numeric_null, + }, + ], + "chaninfo": {"nosedir": "+Y"}, + "etc": {}, + } + report: dict[str, object] = {} + warnings: list[str] = [] + errors: list[str] = [] + + apply_montage_inference( + eeg, + "standard-10-5-342ch.locs", + numeric_null=numeric_null, + report=report, + warning=warnings.append, + error=errors.append, + ) + + assert warnings == [] + assert errors == [] + assert eeg["chaninfo"]["nosedir"] == "+X" + assert eeg["etc"]["labelscheme"] == "10-20" + assert report["ChanlocsFrom"] == "standard-10-5-342ch.locs" + for chanloc in eeg["chanlocs"]: + assert np.isfinite([chanloc["X"], chanloc["Y"], chanloc["Z"]]).all() diff --git a/tests/test_clean_rawdata.py b/tests/test_clean_rawdata.py index 86b3e4ce..364d0c15 100644 --- a/tests/test_clean_rawdata.py +++ b/tests/test_clean_rawdata.py @@ -76,14 +76,14 @@ def setUp(self): self.eeglab = eeglabcompat.get_eeglab('MAT') def test_design_kaiser(self): - from eegprep.plugins.clean_rawdata.private.sigproc import design_kaiser + from eegprep.plugins.firfilt.design import design_kaiser observed = design_kaiser(0.06, 0.08, 75, True) expected = np.asarray(self.eeglab.design_kaiser(0.06, 0.08, 75.0, True)) np.testing.assert_almost_equal(observed.flatten(), expected.flatten(), err_msg='design_kaiser() test failed') def test_design_fir_default_wnd(self): - from eegprep.plugins.clean_rawdata.private.sigproc import design_fir + from eegprep.plugins.firfilt.design import design_fir observed = design_fir(234, [0.0, 0.06, 0.08, 1.0], [0, 0, 1, 1]) expected = np.asarray( @@ -94,7 +94,7 @@ def test_design_fir_default_wnd(self): ) def test_design_fir_custom_wnd(self): - from eegprep.plugins.clean_rawdata.private.sigproc 
import design_fir, design_kaiser + from eegprep.plugins.firfilt.design import design_fir, design_kaiser wnd = design_kaiser(0.06, 0.08, 75.0, True) observed = design_fir(234, [0.0, 0.06, 0.08, 1.0], [0, 0, 1.0, 1.0], w=wnd) diff --git a/tests/test_clean_rawdata_channel_removal.py b/tests/test_clean_rawdata_channel_removal.py new file mode 100644 index 00000000..d47fc7b7 --- /dev/null +++ b/tests/test_clean_rawdata_channel_removal.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import numpy as np + +from eegprep.plugins.clean_rawdata.private.channel_removal import update_clean_channel_mask + + +def test_update_clean_channel_mask_composites_original_channel_mask(): + eeg = {"etc": {"clean_channel_mask": np.array([True, False, True, True])}} + removed_channels = np.array([False, True, False]) + + update_clean_channel_mask(eeg, removed_channels) + + np.testing.assert_array_equal(eeg["etc"]["clean_channel_mask"], [True, False, False, True]) + + +def test_update_clean_channel_mask_resets_mismatched_existing_mask(): + eeg = {"etc": {"clean_channel_mask": np.array([True, False, True, True])}} + removed_channels = np.array([False, True]) + + update_clean_channel_mask(eeg, removed_channels) + + np.testing.assert_array_equal(eeg["etc"]["clean_channel_mask"], [True, False]) + + +def test_update_clean_channel_mask_creates_zero_and_all_removal_masks(): + no_removal = {"etc": {}} + all_removed = {} + + update_clean_channel_mask(no_removal, np.zeros(3, dtype=bool)) + update_clean_channel_mask(all_removed, np.ones(3, dtype=bool)) + + np.testing.assert_array_equal(no_removal["etc"]["clean_channel_mask"], [True, True, True]) + np.testing.assert_array_equal(all_removed["etc"]["clean_channel_mask"], [False, False, False]) diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index d772f9d4..086b0ff4 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -93,6 +93,18 @@ def test_json_parse_errors_return_stable_error_code(): assert payload["code"] == 
"CONFIG_SCHEMA_ERROR" +def test_nested_qc_parse_errors_return_stable_json_error(): + result = _run_cli("qc", "report", "--json") + + assert result.returncode == 2 + assert result.stderr == "" + payload = _json_stdout(result) + assert payload["status"] == "error" + assert payload["schema_version"] == "eegprep.error.v1" + assert payload["code"] == "CONFIG_SCHEMA_ERROR" + assert "arguments are required" in payload["message"] + + def test_batch_run_dry_run_and_qc_pipeline(tmp_path): config = tmp_path / "pipeline.yaml" config.write_text( diff --git a/tests/test_cli_pipeline_qc_report.py b/tests/test_cli_pipeline_qc_report.py index 1f2d18a6..2c2f3d14 100644 --- a/tests/test_cli_pipeline_qc_report.py +++ b/tests/test_cli_pipeline_qc_report.py @@ -4,6 +4,7 @@ import numpy as np import yaml +from eegprep.cli.commands import transforms as transforms_cli from eegprep.cli.commands.pipeline import ( _channel_indices, plan_pipeline_config, @@ -116,25 +117,88 @@ def test_pipeline_clean_after_epoch_uses_pop_clean_continuous_data_guard(tmp_pat assert "Input data must be continuous" in result["message"] -def test_pipeline_filter_history_uses_modern_eegfiltnew(monkeypatch, tmp_path): - from eegprep.cli.commands import pipeline as pipeline_cli +def test_pipeline_clean_uses_direct_cli_defaults(monkeypatch, tmp_path): + calls = [] + + def fake_pop_clean_rawdata(eeg, **kwargs): + calls.append(kwargs) + return eeg, "EEG = pop_clean_rawdata(EEG, 'BurstCriterion', 20);" + + monkeypatch.setattr(transforms_cli, "pop_clean_rawdata", fake_pop_clean_rawdata) + config_path = _write_pipeline_config( + tmp_path, + steps=[{"name": "clean", "method": "asr"}], + ) + + result = run_pipeline_config(config_path) + + assert result["status"] == "ok" + assert calls == [ + { + "FlatlineCriterion": "off", + "ChannelCriterion": "off", + "LineNoiseCriterion": "off", + "Highpass": "off", + "BurstCriterion": 20.0, + "BurstRejection": False, + "WindowCriterion": "off", + "Distance": "Euclidean", + "gui": False, + 
"return_com": True, + } + ] + + +def test_pipeline_clean_rejects_scalar_highpass(tmp_path): + config_path = _write_pipeline_config( + tmp_path, + steps=[{"name": "clean", "method": "asr", "highpass": 0.5}], + ) + + result = validate_pipeline_config(config_path) + assert result["status"] == "error" + assert result["error"]["code"] == "CONFIG_SCHEMA_ERROR" + assert result["error"]["details"]["errors"][0]["path"] == "steps[0].highpass" + + +def test_pipeline_filter_history_uses_modern_eegfiltnew(monkeypatch, tmp_path): calls = [] def fake_pop_eegfiltnew(eeg, **kwargs): calls.append(kwargs) return eeg, "EEG = pop_eegfiltnew(EEG, 'locutoff', 1);" - monkeypatch.setattr(pipeline_cli, "pop_eegfiltnew", fake_pop_eegfiltnew) + monkeypatch.setattr(transforms_cli, "pop_eegfiltnew", fake_pop_eegfiltnew) config_path = _write_pipeline_config( tmp_path, - steps=[{"name": "filter", "highpass": 1.0}], + steps=[ + { + "name": "filter", + "highpass": 1.0, + "lowpass": 40.0, + "order": 128, + "minphase": True, + "usefftfilt": True, + } + ], ) result = run_pipeline_config(config_path) assert result["status"] == "ok" - assert calls and calls[0]["locutoff"] == 1.0 + assert calls == [ + { + "locutoff": 1.0, + "hicutoff": 40.0, + "filtorder": 128, + "plotfreqz": False, + "minphase": True, + "usefftfilt": True, + "gui": False, + "return_com": True, + } + ] assert "pop_eegfiltnew" in result["history"][0] diff --git a/tests/test_cli_transforms.py b/tests/test_cli_transforms.py index d91d42aa..51de08e7 100644 --- a/tests/test_cli_transforms.py +++ b/tests/test_cli_transforms.py @@ -102,7 +102,7 @@ def test_transform_requires_output_unless_overwrite_is_explicit(): assert result.returncode == 1 payload = _json_stdout(result) assert payload["status"] == "error" - assert payload["error"]["code"] == "OUTPUT_REQUIRED" + assert payload["code"] == "OUTPUT_REQUIRED" assert result.stdout.strip().startswith("{") @@ -138,8 +138,8 @@ def test_transform_refuses_existing_output_without_overwrite(tmp_path): 
assert result.returncode == 1 payload = _json_stdout(result) - assert payload["error"]["code"] == "OUTPUT_EXISTS" - assert payload["error"]["path"] == str(output) + assert payload["code"] == "OUTPUT_EXISTS" + assert payload["path"] == str(output) assert output.read_text(encoding="utf-8") == "already here" @@ -160,7 +160,7 @@ def test_transform_refuses_output_manifest_path_collision(tmp_path): assert result.returncode == 1 payload = _json_stdout(result) - assert payload["error"]["code"] == "OUTPUT_PATH_COLLISION" + assert payload["code"] == "OUTPUT_PATH_COLLISION" assert not output.exists() diff --git a/tests/test_console_workspace.py b/tests/test_console_workspace.py index cc85f054..f45dd7d6 100644 --- a/tests/test_console_workspace.py +++ b/tests/test_console_workspace.py @@ -427,6 +427,28 @@ def test_console_eegh_string_command_notifies_session_listeners(): workspace.close() +def test_console_eegh_clear_and_remove_notify_session_listeners(): + session = EEGPrepSession() + workspace = EEGPrepConsoleWorkspace(session, exports={}) + session.add_history("EEG = first;") + session.add_history("EEG = second;") + notified: list[tuple[list[str], str]] = [] + session.add_change_listener(lambda _session: notified.append((list(session.ALLCOM), session.LASTCOM))) + + assert workspace.namespace["eegh"](-1) == "" + assert session.ALLCOM == ["EEG = first;"] + assert session.LASTCOM == "EEG = first;" + assert workspace.namespace["LASTCOM"] == "EEG = first;" + + assert workspace.namespace["eegh"](0) == "" + assert session.ALLCOM == [] + assert session.LASTCOM == "" + assert workspace.namespace["ALLCOM"] == [] + assert workspace.namespace["LASTCOM"] == "" + assert notified == [(["EEG = first;"], "EEG = first;"), ([], "")] + workspace.close() + + def test_menu_actions_reuses_console_pop_result_decoders(): # The GUI extension-result path delegates to the canonical console decoders # instead of keeping its own copies. 
diff --git a/tests/test_eeg_eegrej.py b/tests/test_eeg_eegrej.py index 23f814e5..fed11680 100644 --- a/tests/test_eeg_eegrej.py +++ b/tests/test_eeg_eegrej.py @@ -6,6 +6,7 @@ # Assume eeg_eegrej is defined as in your module that imports: from eegrej import eegrej from eegprep import eeg_eegrej from eegprep.functions.adminfunc.eeglabcompat import get_eeglab +from eegprep.functions.adminfunc.eeg_options import EEG_OPTIONS from eegprep.functions.popfunc.pop_loadset import pop_loadset from eegprep.functions.adminfunc.eeg_checkset import eeg_checkset @@ -315,6 +316,21 @@ def test_eeg_eegrej_boundary_event_insertion(self): self.assertLessEqual(inserted_boundary['latency'], result['pnts'] + 1.0) self.assertEqual(inserted_boundary['duration'], 5.0) # Length of removed region + def test_eeg_eegrej_preserves_numeric_boundary99_when_enabled(self): + """Numeric -99 boundaries use the canonical boundary contract.""" + EEG = self.base_eeg.copy() + EEG["event"][1]["type"] = -99 + old = EEG_OPTIONS["option_boundary99"] + EEG_OPTIONS["option_boundary99"] = 1 + try: + result = eeg_eegrej(EEG, np.array([[6, 10]], dtype=int)) + finally: + EEG_OPTIONS["option_boundary99"] = old + + preserved = [event for event in result["event"] if event.get("type") == -99] + self.assertEqual(len(preserved), 1) + self.assertEqual(preserved[0]["latency"], 6.0) + def test_eeg_eegrej_event_latency_preservation(self): """Test eeg_eegrej preserves 1-based event latencies correctly.""" EEG = self.base_eeg.copy() diff --git a/tests/test_eeg_runica.py b/tests/test_eeg_runica.py index 61c995f1..6bef85f3 100644 --- a/tests/test_eeg_runica.py +++ b/tests/test_eeg_runica.py @@ -143,6 +143,28 @@ def test_finalize_ica_fields_shared_sort_and_sign_normalization(): np.testing.assert_allclose(out["icawinv"], pinv(out["icaweights"] @ out["icasphere"])) +def test_finalize_ica_fields_recomputes_stale_backend_inverse_after_sorting(): + """Final ICA fields must be invariant even if a backend reports stale maps.""" + from 
eegprep.functions.popfunc._ica_utils import finalize_ica_fields + + rng = np.random.default_rng(17) + nbchan, pnts, trials = 3, 5, 2 + sphere = np.array([[1.0, 0.2, 0.0], [0.0, 1.5, 0.1], [0.3, 0.0, 2.0]]) + weights = np.array([[2.0, 0.0, 0.5], [0.1, 1.0, 0.0], [0.0, -0.4, 1.5]]) + data = rng.standard_normal((nbchan, pnts * trials)) + icaact = (weights @ sphere) @ data + eeg = { + "icaweights": weights.copy(), + "icasphere": sphere.copy(), + "icawinv": np.full((nbchan, nbchan), 99.0), + "icaact": icaact.reshape(nbchan, pnts, trials, order="F"), + } + + out = finalize_ica_fields(eeg, sortcomps=True, posact=False) + + np.testing.assert_allclose(out["icawinv"], pinv(out["icaweights"] @ out["icasphere"])) + + def test_pop_runica_concatenates_epoched_datasets_in_eeglab_order(monkeypatch): first = _epoched_eeg() second = _epoched_eeg(offset=100) diff --git a/tests/test_extension_catalog.py b/tests/test_extension_catalog.py index d656ea6a..a723abe7 100644 --- a/tests/test_extension_catalog.py +++ b/tests/test_extension_catalog.py @@ -12,10 +12,13 @@ import pytest +import eegprep.extension_catalog as manager_catalog from eegprep.extension_catalog import ( CATALOG_KIND_CURATION, CATALOG_KIND_MANAGER, CATALOG_SCHEMA_VERSION, +) +from eegprep.extension_catalog_validation import ( CatalogValidationOptions, load_catalog_entries, main, @@ -103,6 +106,14 @@ def test_catalog_cli_emits_json_report(tmp_path: Path, capsys: pytest.CaptureFix assert json.loads(captured.out) == {"ok": True, "errors": [], "warnings": []} +def test_catalog_validator_entry_point_is_owned_by_validation_module() -> None: + pyproject_text = (Path(__file__).resolve().parents[1] / "pyproject.toml").read_text(encoding="utf-8") + + assert 'eegprep-validate-extension-catalog = "eegprep.extension_catalog_validation:main"' in pyproject_text + assert not hasattr(manager_catalog, "main") + assert not hasattr(manager_catalog, "validate_catalog_file") + + def test_schema_version_mismatch_is_reported(tmp_path: Path) -> 
None: catalog = tmp_path / "catalog.json" catalog.write_text( diff --git a/tests/test_extensions.py b/tests/test_extensions.py index ce7d4f83..97060251 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -21,7 +21,12 @@ ExtensionSpec, ExtensionStatus, LazyImport, + extension_api_major_version, + extension_entry_point_package_name, + extension_status_is_active, + extension_status_is_installed, extension_version_satisfies, + select_extension_entry_points, validate_extension_spec, ) @@ -137,6 +142,19 @@ def provider() -> SelectableEntryPoints: assert [record.name for record in records] == ["selectable_extension"] +def test_extension_metadata_helpers_are_shared_for_registry_and_validation() -> None: + entry_point = FakeEntryPoint("helper", "helper_pkg.register:register") + + assert select_extension_entry_points(_provider(entry_point), EXTENSION_ENTRY_POINT_GROUP) == (entry_point,) + assert extension_entry_point_package_name(entry_point) == "helper" + assert extension_api_major_version("1.2.3") == 1 + assert extension_status_is_active(ExtensionStatus.INSTALLED) + assert extension_status_is_active(ExtensionStatus.BUNDLED.value) + assert not extension_status_is_active(ExtensionStatus.FAILED_IMPORT) + assert extension_status_is_installed(ExtensionStatus.FAILED_IMPORT) + assert not extension_status_is_installed(ExtensionStatus.CURATED) + + def test_registry_records_are_deterministic(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: _write_basic_extension(tmp_path, monkeypatch, "beta_pkg", "beta_extension", action="beta_action") _write_basic_extension(tmp_path, monkeypatch, "alpha_pkg", "alpha_extension", action="alpha_action") diff --git a/tests/test_firfilt_helpers.py b/tests/test_firfilt_helpers.py index bddfc6da..8db63469 100644 --- a/tests/test_firfilt_helpers.py +++ b/tests/test_firfilt_helpers.py @@ -8,6 +8,7 @@ import pytest from eegprep.functions.adminfunc.eeglabcompat import get_eeglab +from eegprep.functions.adminfunc.eeg_options import 
EEG_OPTIONS from eegprep.plugins.firfilt.fir_filterdcpadded import fir_filterdcpadded from eegprep.plugins.firfilt.findboundaries import findboundaries from eegprep.plugins.firfilt.firfiltreport import firfiltreport @@ -62,6 +63,22 @@ def test_firws_and_firwsord_are_owned_by_firfilt_plugin(): np.testing.assert_allclose(b, b[::-1], atol=1e-12) +def test_clean_rawdata_fir_design_helpers_are_owned_by_firfilt_plugin(): + """Clean rawdata imports FIR design helpers downward from firfilt.""" + from eegprep.plugins.clean_rawdata.private import sigproc + from eegprep.plugins.firfilt.design import design_fir, design_kaiser + + assert design_fir.__module__ == "eegprep.plugins.firfilt.design" + assert design_kaiser.__module__ == "eegprep.plugins.firfilt.design" + assert not hasattr(sigproc, "design_fir") + assert not hasattr(sigproc, "design_kaiser") + + window = design_kaiser(0.06, 0.08, 75.0, True) + coeffs = design_fir(234, [0.0, 0.06, 0.08, 1.0], [0, 0, 1, 1], w=window) + assert window.size % 2 == 1 + assert coeffs.size == 235 + + def test_invfirwsord_returns_transition_width_and_window_deviation(): df, dev = invfirwsord("hamming", 500, 826) @@ -108,7 +125,14 @@ def test_findboundaries_returns_eeglab_boundary_latencies(): ] np.testing.assert_array_equal(findboundaries(events), [1, 40, 101]) - np.testing.assert_array_equal(findboundaries([{"type": -99, "latency": 12.6}]), [1, 13]) + old = EEG_OPTIONS["option_boundary99"] + try: + EEG_OPTIONS["option_boundary99"] = 0 + np.testing.assert_array_equal(findboundaries([{"type": -99, "latency": 12.6}]), [1]) + EEG_OPTIONS["option_boundary99"] = 1 + np.testing.assert_array_equal(findboundaries([{"type": -99, "latency": 12.6}]), [1, 13]) + finally: + EEG_OPTIONS["option_boundary99"] = old np.testing.assert_array_equal(findboundaries([]), [1]) diff --git a/tests/test_gui_pop_firfilt.py b/tests/test_gui_pop_firfilt.py index 9ed3d174..f8bcff44 100644 --- a/tests/test_gui_pop_firfilt.py +++ b/tests/test_gui_pop_firfilt.py @@ -5,7 +5,8 
@@ import numpy as np from eegprep.functions.guifunc.spec import controls_by_tag -from eegprep.functions.guifunc.qt import _firpm_order_shape +from eegprep.functions.guifunc import qt as qt_renderer +from eegprep.functions.guifunc.qt import QtDialogRenderer, _firpm_order_shape from eegprep.functions.popfunc.pop_eegfilt import pop_eegfilt, pop_eegfilt_dialog_spec from eegprep.plugins.firfilt.pop_eegfiltnew import pop_eegfiltnew, pop_eegfiltnew_dialog_spec from eegprep.plugins.firfilt.pop_firma import pop_firma, pop_firma_dialog_spec @@ -175,6 +176,10 @@ def test_firfilt_order_dialog_specs_are_eeglab_labeled(self): self.assertNotIn("f", firpm_controls) self.assertNotIn("a", firpm_controls) + def test_qt_renderer_stateless_helpers_have_module_ownership(self): + self.assertIs(QtDialogRenderer._read_widget, qt_renderer._read_widget) + self.assertIs(QtDialogRenderer._validation_message, qt_renderer._validation_message) + def test_firpm_estimate_order_shape_uses_paired_edges_for_single_cutoff_filters(self): highpass_edges, highpass_amplitudes = _firpm_order_shape([8], 4, "highpass", 200) lowpass_edges, lowpass_amplitudes = _firpm_order_shape([30], 4, "lowpass", 200) diff --git a/tests/test_phase1b_file_edit_pop_functions.py b/tests/test_phase1b_file_edit_pop_functions.py index c38c75aa..9315b40b 100644 --- a/tests/test_phase1b_file_edit_pop_functions.py +++ b/tests/test_phase1b_file_edit_pop_functions.py @@ -8,6 +8,7 @@ import pytest from eegprep.functions.adminfunc.console import _console_python_command +from eegprep.functions.adminfunc.eeg_options import EEG_OPTIONS from eegprep.functions.guifunc.select_multiple_datasets import select_multiple_datasets from eegprep.functions.guifunc.session import EEGPrepSession from eegprep.functions.popfunc.pop_chanedit import pop_chanedit @@ -198,13 +199,17 @@ def test_pop_selectevent_renames_events_before_epoched_trial_selection(): def test_pop_selectevent_keeps_numeric_boundary_when_deleting_continuous_events(): eeg = _eeg() - 
eeg["event"].insert(1, {"type": -1, "latency": 25.0, "duration": 0.0, "urevent": 3}) - eeg["urevent"].append({"type": -1, "latency": 25.0, "duration": 0.0}) - - out, selected = pop_selectevent(eeg, "type", "stim", "deleteevents", "on") + eeg["event"].insert(1, {"type": -99, "latency": 25.0, "duration": 0.0, "urevent": 3}) + eeg["urevent"].append({"type": -99, "latency": 25.0, "duration": 0.0}) + old = EEG_OPTIONS["option_boundary99"] + EEG_OPTIONS["option_boundary99"] = 1 + try: + out, selected = pop_selectevent(eeg, "type", "stim", "deleteevents", "on") + finally: + EEG_OPTIONS["option_boundary99"] = old assert selected == [1, 2] - assert [event["type"] for event in out["event"]] == ["stim", -1] + assert [event["type"] for event in out["event"]] == ["stim", -99] def test_pop_rmdat_removes_or_keeps_continuous_windows_around_events(): @@ -219,6 +224,21 @@ def test_pop_rmdat_removes_or_keeps_continuous_windows_around_events(): _assert_python_echo_is_parseable(command) +def test_pop_rmdat_uses_numeric_boundary99_to_limit_windows(): + eeg = _eeg() + eeg["event"].insert(1, {"type": -99, "latency": 15.0, "duration": 0.0, "urevent": 3}) + old = EEG_OPTIONS["option_boundary99"] + try: + EEG_OPTIONS["option_boundary99"] = 1 + limited = pop_rmdat(eeg, ["stim"], [-0.05, 0.1], 1) + EEG_OPTIONS["option_boundary99"] = 0 + unlimited = pop_rmdat(eeg, ["stim"], [-0.05, 0.1], 1) + finally: + EEG_OPTIONS["option_boundary99"] = old + + assert limited["pnts"] > unlimited["pnts"] + + def test_pop_rmdat_matches_sorted_event_behavior_when_events_are_unsorted(): unsorted_eeg = _eeg() unsorted_eeg["event"] = [ diff --git a/tests/test_phase4_timefreq_statistics.py b/tests/test_phase4_timefreq_statistics.py index eafacd20..6ee5154c 100644 --- a/tests/test_phase4_timefreq_statistics.py +++ b/tests/test_phase4_timefreq_statistics.py @@ -33,17 +33,25 @@ from eegprep.functions.sigprocfunc.signalstat import signalstat from eegprep.functions.guifunc.tf_cycle_calc_dialog import 
tf_cycle_calc_dialog_spec from eegprep.functions.timefreqfunc.angtimewarp import angtimewarp -from eegprep.functions.timefreqfunc.bootstat import bootstat, bootstrap_threshold +from eegprep.functions.timefreqfunc._bootstrap import ( + bootstrap_indices, + resample_trials, + threshold_vector, + thresholds_by_frequency, +) +from eegprep.functions.timefreqfunc.bootstat import bootstat, bootstrap_threshold, exact_p_values from eegprep.functions.timefreqfunc.correct_mc import correct_mc from eegprep.functions.timefreqfunc.correctfit import correctfit from eegprep.functions.timefreqfunc.dftfilt import dftfilt from eegprep.functions.timefreqfunc.dftfilt2 import dftfilt2 from eegprep.functions.timefreqfunc.dftfilt3 import dftfilt3 from eegprep.functions.statistics.fdr import fdr +from eegprep.functions.statistics.stat_surrogate_pvals import stat_surrogate_pvals from eegprep.functions.timefreqfunc.newcrossf import _is_on as newcrossf_is_on from eegprep.functions.timefreqfunc.newcrossf import _threshold_vector as newcrossf_threshold_vector from eegprep.functions.timefreqfunc.newcrossf import _upper_thresholds_by_frequency from eegprep.functions.timefreqfunc.newcrossf import newcrossf +from eegprep.functions.timefreqfunc._pac_support import _empirical_pvalue as pac_empirical_pvalue from eegprep.functions.timefreqfunc.newtimef import _is_on as newtimef_is_on from eegprep.functions.timefreqfunc.newtimef import _significance_mask, _thresholds_by_frequency from eegprep.functions.timefreqfunc.newtimef import _threshold_vector as newtimef_threshold_vector @@ -106,6 +114,8 @@ def test_newtimef_is_on_uses_whitelist_semantics(): assert newtimef_is_on("yes-please") is False assert newtimef_is_on("display") is False assert newtimef_is_on("off") is False + assert newtimef_is_on([0, 1]) is False + assert newtimef_is_on(np.array([0, 1])) is False def test_newtimef_fails_loudly_on_unimplemented_overlap_and_plotphase(): @@ -577,6 +587,35 @@ def 
test_timefreq_threshold_helpers_pool_through_canonical_bootstrap_threshold() assert _upper_thresholds_by_frequency(single, alpha=0.1).shape == (1,) +def test_timefreq_shared_bootstrap_helpers_cover_newtimef_and_newcrossf_paths(): + times = np.asarray([-100.0, 0.0, 0.5, 100.0, 200.0]) + baseln = np.asarray([0, 1], dtype=int) + + np.testing.assert_array_equal(bootstrap_indices(times, baseline=0, baseboot=[], baseln=baseln), baseln) + np.testing.assert_array_equal(bootstrap_indices(times, baseline=np.nan, baseboot=1, baseln=baseln), [0, 1]) + np.testing.assert_array_equal(bootstrap_indices(times, baseline=np.nan, baseboot=0, baseln=baseln), []) + np.testing.assert_array_equal(bootstrap_indices(times, baseline=np.nan, baseboot=[50, 200], baseln=baseln), [3, 4]) + np.testing.assert_array_equal(bootstrap_indices(times, baseboot=1, baseln=None, limit_to_baseboot=True), [0, 1, 2]) + + surrogates = np.arange(24, dtype=float).reshape(2, 3, 4) + np.testing.assert_allclose( + thresholds_by_frequency(surrogates, alpha=0.1, bootside="both"), + _thresholds_by_frequency(surrogates, alpha=0.1, both=True), + ) + np.testing.assert_allclose( + thresholds_by_frequency(surrogates, alpha=0.1, bootside="upper"), + _upper_thresholds_by_frequency(surrogates, alpha=0.1), + ) + assert threshold_vector(2.0, (3, 4)).shape == (3, 4) + assert threshold_vector(np.asarray([1.0, 2.0, 3.0]), (3, 4)).shape == (3, 1) + + values = np.arange(24, dtype=float).reshape(2, 3, 4) + shuffled = resample_trials(values, np.random.default_rng(0), "shuffle") + randomized = resample_trials(values.astype(complex), np.random.default_rng(0), "rand", complex_phase=True) + assert shuffled.shape == values.shape + np.testing.assert_allclose(np.abs(randomized), np.abs(values)) + + def test_newtimef_fdr_branch_matches_canonical_fdr_threshold(): rng = np.random.default_rng(11) pvalues = rng.random(size=(4, 6)) @@ -591,6 +630,7 @@ def test_newtimef_fdr_branch_matches_canonical_fdr_threshold(): def 
test_timefreq_is_on_and_threshold_vector_are_the_canonical_shared_helpers(): # newcrossf reuses newtimef's threshold helper and the canonical is_on whitelist. assert newcrossf_threshold_vector is newtimef_threshold_vector + assert newtimef_threshold_vector is threshold_vector assert newcrossf_is_on is newtimef_is_on assert newcrossf_is_on("on") is True assert newcrossf_is_on("display") is False @@ -611,6 +651,17 @@ def statistic(value): bootstat(data, statistic=statistic, basevect=[0], naccu=1, rng=0) +def test_empirical_pvalue_conventions_are_intentionally_distinct(): + distribution = np.asarray([1.0, 2.0, 3.0, 4.0]) + observed = 5.0 + + assert pac_empirical_pvalue(distribution, observed) == pytest.approx(1 / 5) + np.testing.assert_allclose( + stat_surrogate_pvals(distribution[np.newaxis, :], np.asarray([observed]), "right"), [0.0] + ) + np.testing.assert_allclose(exact_p_values(observed, distribution, center=0.0), 0.0) + + def test_correct_mc_returns_phase4_standalone_shapes(): rng = np.random.default_rng(0) eeg = { diff --git a/tests/test_plugin_menu.py b/tests/test_plugin_menu.py index 466abe03..95f005f2 100644 --- a/tests/test_plugin_menu.py +++ b/tests/test_plugin_menu.py @@ -58,6 +58,17 @@ def test_bundled_plugins_match_extension_registry_records() -> None: assert plugin["tags"] == record.spec.capabilities +def test_bundled_plugins_use_registry_menu_projection() -> None: + plugin_inventory = {plugin["plugin"]: plugin for plugin in bundled_plugins()} + manager_inventory = { + plugin["plugin"]: plugin for plugin in plugin_menu(catalog=_catalog(), include_entry_points=False, show=False) + } + + assert plugin_inventory.keys() == manager_inventory.keys() + for name, plugin in plugin_inventory.items(): + assert plugin["menu"] == manager_inventory[name]["menu"] + + def test_bundled_plugins_returns_copies() -> None: plugins = bundled_plugins() plugins[0]["status"] = "changed" @@ -101,7 +112,7 @@ def test_format_plugin_menu_includes_external_plugin_exclusion() -> 
None: assert "Available EEGPrep extensions" in text assert "ICLabel" in text - assert "File > Import data > import data > From BIDS folder structure" in text + assert "File > Import data / Export / BIDS tools" in text assert EXTERNAL_PLUGIN_NOTICE in text assert INSTALL_TRUST_WARNING in text diff --git a/tests/test_pop_prop_extended.py b/tests/test_pop_prop_extended.py index c90d5ff4..2f5bb433 100644 --- a/tests/test_pop_prop_extended.py +++ b/tests/test_pop_prop_extended.py @@ -6,6 +6,8 @@ import numpy as np import pytest +from eegprep.plugins.ICLabel import _prop_browser, _prop_numerics +from eegprep.plugins.ICLabel import pop_prop_extended as pop_prop_extended_module from eegprep.plugins.ICLabel.pop_prop_extended import ( DEFAULT_ICLABEL_CLASSES, build_extended_property_data, @@ -180,3 +182,10 @@ def test_missing_classifier_falls_back_to_lightweight_viewprops_display() -> Non assert len(figures[0].eegprep_activity_views) == 2 assert figures[0].eegprep_activity_views[0].state.events plt.close(figures[0]) + + +def test_pop_prop_extended_facade_preserves_public_helper_imports() -> None: + assert pop_prop_extended_module.build_extended_property_data is _prop_numerics.build_extended_property_data + assert pop_prop_extended_module.resolve_classifier_data is _prop_numerics.resolve_classifier_data + assert pop_prop_extended_module.resolve_dipfit_data is _prop_numerics.resolve_dipfit_data + assert _prop_browser.build_navigable_dashboard.__module__.endswith("._prop_browser") diff --git a/tests/test_pop_resample_python.py b/tests/test_pop_resample_python.py index 95c84b47..7c4a13d3 100644 --- a/tests/test_pop_resample_python.py +++ b/tests/test_pop_resample_python.py @@ -2,6 +2,7 @@ import numpy as np +from eegprep.functions.adminfunc.eeg_options import EEG_OPTIONS from eegprep.functions.popfunc.pop_resample import pop_resample @@ -59,6 +60,20 @@ def test_resample_preserves_numpy_event_containers(self): self.assertIsInstance(out["urevent"], np.ndarray) 
np.testing.assert_allclose([event["latency"] for event in out["urevent"]], [3.5, 5.5, 8.5]) + def test_numeric_boundary99_splits_continuous_segments_when_enabled(self): + eeg = _continuous_eeg() + eeg["event"][1]["type"] = -99 + eeg["urevent"][1]["type"] = -99 + old = EEG_OPTIONS["option_boundary99"] + EEG_OPTIONS["option_boundary99"] = 1 + try: + out = pop_resample(eeg, 50, engine="scipy") + finally: + EEG_OPTIONS["option_boundary99"] = old + + np.testing.assert_allclose([event["latency"] for event in out["event"]], [3.5, 5.5, 8.5]) + np.testing.assert_allclose([event["latency"] for event in out["urevent"]], [3.5, 5.5, 8.5]) + def test_epoched_data_resamples_each_epoch_and_clears_urevents(self): eeg = _continuous_eeg() eeg["data"] = np.arange(80, dtype=np.float32).reshape(2, 20, 2) diff --git a/tests/test_pop_utils.py b/tests/test_pop_utils.py index 80027409..660ad448 100644 --- a/tests/test_pop_utils.py +++ b/tests/test_pop_utils.py @@ -40,6 +40,7 @@ def test_parse_numeric_sequence_handles_eeglab_colon_ranges(self): self.assertEqual(parse_numeric_sequence("1:3", dtype=int), [1, 2, 3]) self.assertEqual(parse_numeric_sequence("5:-2:1", dtype=int), [5, 3, 1]) self.assertEqual(parse_numeric_sequence("[1, 2.5 4]", dtype=float), [1.0, 2.5, 4.0]) + self.assertEqual(parse_numeric_sequence("[1 2; 3 4]", dtype=int), [1, 2, 3, 4]) self.assertEqual(parse_numeric_sequence(["1:2", 4], dtype=int), [1, 2, 4]) parsed = parse_numeric_sequence("nan Inf -Inf", dtype=float) @@ -60,9 +61,13 @@ def test_is_on_normalizes_eeglab_on_off_values(self): self.assertTrue(is_on("on")) self.assertTrue(is_on("1")) self.assertTrue(is_on(True)) + self.assertTrue(is_on([1, 0])) + self.assertTrue(is_on(np.array([1, 0]))) self.assertFalse(is_on("off")) self.assertFalse(is_on("0")) self.assertFalse(is_on(False)) + self.assertFalse(is_on([0, 1])) + self.assertFalse(is_on(np.array([0, 1]))) def test_format_history_value_defaults_to_eeglab_like_literals(self): self.assertEqual(format_history_value("F'z"), 
"'F''z'") diff --git a/tests/test_rejection_workflows.py b/tests/test_rejection_workflows.py index 0e78225d..1dc0855f 100644 --- a/tests/test_rejection_workflows.py +++ b/tests/test_rejection_workflows.py @@ -672,6 +672,18 @@ def test_channel_and_continuous_rejection_work_on_sample_data_without_ica(): pop_eegthresh(sample, 0, [1], -10, 10, 0, 1) +def test_rejection_component_threshold_recomputes_stale_stored_icaact(): + eeg = _epoched_eeg() + eeg["icaweights"] = 2.0 * np.eye(4) + eeg["icasphere"] = np.eye(4) + eeg["icaact"] = np.zeros((4, eeg["pnts"], eeg["trials"])) + + out, rejected = pop_eegthresh(eeg, 0, [1], -40, 40, 0, 0.79, 0, 0) + + assert rejected == [2] + assert out["reject"]["icarejthresh"].tolist() == [False, True, False, False, False] + + def test_pop_rejchan_default_threshold_matches_gui_zscore_default(): eeg = create_test_eeg(n_channels=2, n_samples=20, n_trials=1, srate=100) eeg["data"] = np.zeros((2, 20)) diff --git a/tests/test_statistics_package.py b/tests/test_statistics_package.py index 26916b47..0ee57430 100644 --- a/tests/test_statistics_package.py +++ b/tests/test_statistics_package.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib import os from pathlib import Path from typing import Any @@ -267,6 +268,20 @@ def test_teststat_smoke_helper_runs_deterministic_checks(): assert sorted(result) == ["one_way_f_mean", "paired_t_mean", "two_way_interaction_mean"] +def test_statistics_package_exports_remain_functions_after_submodule_imports(): + import eegprep.functions.statistics as statistics + + fdr_module = importlib.import_module("eegprep.functions.statistics.fdr") + statcond_module = importlib.import_module("eegprep.functions.statistics.statcond") + core_module = importlib.import_module("eegprep.functions.statistics._core") + + assert statistics.fdr is fdr_module.fdr + assert statistics.statcond is statcond_module.statcond + assert statistics.FDRResult is fdr_module.FDRResult + assert statistics.StatcondResult is 
statcond_module.StatcondResult + assert core_module.fdr is fdr_module.fdr + + @pytest.fixture(scope="module") def matlab_statistics_engine(): stats_dir = _eeglab_statistics_reference_dir() diff --git a/tests/test_study_clustering.py b/tests/test_study_clustering.py index e172df7f..67c1792a 100644 --- a/tests/test_study_clustering.py +++ b/tests/test_study_clustering.py @@ -1,5 +1,7 @@ from __future__ import annotations +import ast + import numpy as np import pytest from matplotlib import pyplot as plt @@ -16,6 +18,7 @@ from eegprep.functions.studyfunc.robust_kmeans import robust_kmeans from eegprep.functions.studyfunc.std_apcluster import std_apcluster from eegprep.functions.studyfunc.std_centroid import std_centroid +from eegprep.functions.studyfunc.std_clustplot import std_clustplot from eegprep.functions.studyfunc.std_createclust import std_createclust from eegprep.functions.studyfunc.std_findoutlierclust import std_findoutlierclust from eegprep.functions.studyfunc.std_mergeclust import std_mergeclust @@ -225,6 +228,36 @@ def test_pop_clust_outlier_threshold_uses_mean_distance_guard(): assert not any(str(cluster["name"]).startswith("outlier") for cluster in clustered["cluster"]) +def test_pop_clust_finite_outliers_uses_robust_kmeans_provenance_and_labels(): + study, alleeg = _preclustered_study() + data = np.asarray([[0.0, 0.0], [0.1, 0.0], [0.2, 0.0], [5.0, 5.0], [5.1, 5.0], [20.0, 20.0]]) + study["etc"]["preclust"]["preclustdata"] = data.tolist() + expected_labels, _centers, _sumd, _distances, expected_outliers = robust_kmeans( + data, + 2, + STD=2.5, + MAXiter=5, + method="kmeans", + random_state=7, + ) + + clustered, command = pop_clust(study, alleeg, clus_num=2, outliers=2.5, random_state=7, return_com=True) + + outlier_clusters = [ + cluster for cluster in clustered["cluster"] if str(cluster.get("name") or "").startswith("outlier") + ] + assert len(outlier_clusters) == 1 + np.testing.assert_allclose(outlier_clusters[0]["preclust"]["preclustdata"], 
data[expected_outliers - 1]) + assert all(cluster.get("algorithm") == ["robust_kmeans", 2] for cluster in clustered["cluster"][1:]) + assert "outliers=2.5" in command + ast.parse(command) + + rows_by_cluster = [] + for cluster in clustered["cluster"][1:]: + rows_by_cluster.extend(np.asarray(cluster["preclust"]["preclustdata"], dtype=float).tolist()) + np.testing.assert_allclose(sorted(rows_by_cluster), sorted(data[expected_labels >= 0].tolist())) + + def test_std_createclust_numbers_outliers_separately_from_clusters(): study, alleeg = _preclustered_study() @@ -274,6 +307,21 @@ def test_cluster_gui_all_selection_expands_to_all_clusters(): plt.close(figure) +def test_std_clustplot_history_omits_empty_cluster_selection_and_replays(): + study, alleeg = _preclustered_study() + study = pop_clust(study, alleeg, clus_num=2, random_state=11) + + _study, command, figure = std_clustplot(study, alleeg, clusters=[], return_com=True) + + assert "clusters=" not in command + ast.parse(command) + namespace = {"STUDY": study, "ALLEEG": alleeg, "std_clustplot": std_clustplot} + exec(command, namespace) + assert "FIGURE" in namespace + plt.close(figure) + plt.close(namespace["FIGURE"]) + + def test_moveoutlier_reuses_outlier_cluster_after_source_rename(): study, alleeg = _preclustered_study() study = pop_clust(study, alleeg, clus_num=2, random_state=11) diff --git a/tests/test_study_long_tail_helpers.py b/tests/test_study_long_tail_helpers.py index 84243248..786245f1 100644 --- a/tests/test_study_long_tail_helpers.py +++ b/tests/test_study_long_tail_helpers.py @@ -81,6 +81,40 @@ def test_independent_variable_selection_and_trialinfo_queries_are_1_based(): assert [row["type"] for row in combined] == ["rare", "standard", "rare"] +def test_columnar_trialinfo_rows_are_shared_by_study_helpers(): + datasetinfo = [ + { + "subject": "S01", + "condition": "target", + "trialinfo": {"type": ["rare", "standard"], "rt": [320.0, 410.0]}, + }, + { + "subject": "S02", + "condition": "standard", + 
"trialinfo": {"type": ["rare"], "rt": [350.0]}, + }, + ] + study = {"datasetinfo": deepcopy(datasetinfo)} + + combined = std_combtrialinfo(datasetinfo, [1, 2]) + selected, trials = std_selectdataset(study, None, "type", "rare") + factors, factorvals, subjects, paired = std_getindvar(study, mode="trialinfo") + rt_trials, rt_values = std_gettrialsind( + {"trialinfo": datasetinfo[0]["trialinfo"]}, "rt", "300<400", return_values=True + ) + + assert [row["condition"] for row in combined] == ["target", "target", "standard"] + assert [row["type"] for row in combined] == ["rare", "standard", "rare"] + assert selected == [1, 2] + assert trials == [[1], [1]] + assert set(factors) == {"rt", "type"} + assert factorvals[factors.index("type")] == ["rare", "standard"] + assert subjects[factors.index("type")] == [["S01", "S02"], ["S01"]] + assert paired[factors.index("type")] == "off" + assert rt_trials == [1] + assert rt_values == [[320.0]] + + def test_indvarmatch_and_gettrialsind_validate_standalone_inputs(): assert std_indvarmatch("target", ["standard", "target", "target"]) == [2, 3] assert std_indvarmatch([2, 3], [1, 2, 3]) == [2, 3] diff --git a/tools/validate_extension_catalog.py b/tools/validate_extension_catalog.py index 891c4ac7..6b2ce28e 100644 --- a/tools/validate_extension_catalog.py +++ b/tools/validate_extension_catalog.py @@ -2,7 +2,7 @@ from __future__ import annotations -from eegprep.extension_catalog import main +from eegprep.extension_catalog_validation import main if __name__ == "__main__":