diff --git a/.agents/skills/oc-autoreview-adapted/SKILL.md b/.agents/skills/oc-autoreview-adapted/SKILL.md deleted file mode 100644 index a0810028..00000000 --- a/.agents/skills/oc-autoreview-adapted/SKILL.md +++ /dev/null @@ -1,166 +0,0 @@ ---- -name: oc-autoreview-adapted -description: Run an autonomous EEGPrep-focused structured autoreview on local changes, branches, commits, or PRs using the bundled Codex helper. Use when the user asks for autoreview, OC autoreview, closeout review, second-pass review, final review before commit/push/PR, or when non-trivial EEGPrep code changes need a high-signal correctness, EEGLAB parity, GUI/session, tests, and repo-instruction check. ---- - -# OC Autoreview Adapted - -Run the bundled structured review helper as an autonomous closeout check for -EEGPrep. This skill adapts the OpenClaw autoreview principles to this project: -one frozen diff bundle, one structured JSON result, validated changed-file -findings, read-only inspection, heartbeat progress, optional parallel tests, and -repeat-until-clean behavior. - -## Contract - -- Run the helper for real unless the user explicitly asks for a plan or manual - review only. -- Treat review output as advisory. Verify every accepted finding by reading the - real code path and adjacent files before fixing or reporting it. -- Keep going until the helper exits cleanly with no accepted/actionable findings - or until you consciously reject a remaining finding with a concrete reason. -- If a review-triggered fix changes code, rerun focused tests and rerun the - helper on the same target. -- Do not run nested review tools from inside a review. The helper builds one - bundle, calls Codex in read-only mode, validates the result, and exits. -- Do not push, stage, commit, or open a PR just to run autoreview. Do those only - when the user requested that action. -- Be patient. The helper prints heartbeat lines such as - `review still running: codex elapsed=... pid=...`; those are healthy progress. - -## Helper - -Use the repo-local helper: - -```bash -.agents/skills/oc-autoreview-adapted/scripts/autoreview --help -``` - -The helper: - -- defaults to Codex with read-only sandboxing and web search enabled; -- chooses dirty local changes first in `--mode auto`; -- otherwise uses the current PR base when discoverable, then `origin/develop`; -- accepts `--mode local`, `--mode branch --base origin/develop`, and - `--mode commit --commit HEAD`; -- includes root/scoped `AGENTS.md` instructions in the review bundle; -- validates structured JSON against an EEGPrep-specific schema; -- filters findings to changed files only; -- exits nonzero when accepted/actionable findings remain; -- supports `--prompt`, `--prompt-file`, `--dataset`, `--json-output`, - `--output`, `--parallel-tests`, `--require-finding`, `--expect-findings`, - `--no-web-search`, `--model`, and `--thinking`. - -The smoke harness creates a temporary EEG-style fixture repo: - -```bash -.agents/skills/oc-autoreview-adapted/scripts/test-review-harness --dry-run -``` - -Run the full harness only when it is acceptable to spend a real Codex review: - -```bash -.agents/skills/oc-autoreview-adapted/scripts/test-review-harness --fixture buggy -``` - -## Pick Target - -Use the smallest target that covers the request. - -Dirty local work: - -```bash -.agents/skills/oc-autoreview-adapted/scripts/autoreview --mode local -``` - -Branch or PR work: - -```bash -.agents/skills/oc-autoreview-adapted/scripts/autoreview --mode branch --base origin/develop -``` - -If an open PR exists, prefer its actual base: - -```bash -base=$(gh pr view --json baseRefName --jq .baseRefName) -.agents/skills/oc-autoreview-adapted/scripts/autoreview --mode branch --base "origin/$base" -``` - -Committed single change: - -```bash -.agents/skills/oc-autoreview-adapted/scripts/autoreview --mode commit --commit HEAD -``` - -Do not force local mode after committing. A clean local review only proves there -is no dirty patch. - -## Parallel Closeout - -It is OK to run focused tests concurrently with review after formatting-sensitive -work is done: - -```bash -.agents/skills/oc-autoreview-adapted/scripts/autoreview \ - --parallel-tests "uv run pytest tests/test_pop_select.py" -``` - -If tests or review findings lead to edits, rerun the affected tests and rerun -autoreview. Stop when the final helper run exits 0 with no accepted/actionable -findings. Do not run another review only for cleaner wording. - -## EEGPrep Review Surface - -The helper prompt asks Codex to prioritize: - -- correctness bugs, import/runtime failures, wrong numerical results, and broken - common workflows; -- EEGLAB parity in APIs, `pop_*` wrappers, history commands, GUI behavior, event - semantics, and expected data structures; -- EEG dict fields including `data`, `nbchan`, `pnts`, `trials`, `srate`, - `xmin`, `xmax`, `times`, `chanlocs`, `event`, `urevent`, `epoch`, `history`, - `icaact`, `icawinv`, `icasphere`, `icaweights`, and `icachansind`; -- MATLAB/Python indexing boundaries, especially 1-based EEGLAB latencies and - user-facing indices versus 0-based Python arrays; -- channel-major shape assumptions: continuous `(nbchan, pnts)` and epoched - `(nbchan, pnts, trials)`; -- GUI plus `eegprep-console` synchronization through `EEGPrepSession`; -- `return_com=True`, `(EEG, com)` returns, history strings, and session update - paths for user-facing `pop_*` functions; -- runtime independence from `src/eegprep/eeglab/`; -- packaged Markdown help resources for GUI Help or `pophelp`; -- missing tests tied to changed behavior; -- concrete security, path, file I/O, and dependency risks; -- realistic EEG-size performance regressions. - -## Triage Findings - -Accept findings only when they are concrete and introduced or exposed by the -reviewed change. Reject: - -- pre-existing issues outside the diff; -- generic linter/formatter comments; -- broad refactors and speculative abstractions; -- unlikely edge cases that would complicate the code without protecting real - workflows; -- subjective MATLAB-vs-Python style preferences that do not break EEGPrep's - parity contract. - -For each accepted finding, fix the smallest ownership boundary that addresses -the bug. For each rejected finding, record the reason briefly in the final -report. Add an inline code comment only when it documents a real invariant that -future reviewers need to know. - -## Final Report - -Include: - -- review command used; -- tests/proof run; -- findings accepted, fixed, or rejected, briefly why; -- the clean result from the final helper run, or the exact remaining risk if a - finding was consciously left open. - -If the final helper run exits 0 and prints -`autoreview clean: no accepted/actionable findings reported`, report that run as -clean and stop. diff --git a/.agents/skills/oc-autoreview-adapted/agents/openai.yaml b/.agents/skills/oc-autoreview-adapted/agents/openai.yaml deleted file mode 100644 index 08ca5ab1..00000000 --- a/.agents/skills/oc-autoreview-adapted/agents/openai.yaml +++ /dev/null @@ -1,4 +0,0 @@ -interface: - display_name: "OC Autoreview Adapted" - short_description: "EEGPrep-focused autoreview workflow" - default_prompt: "Use $oc-autoreview-adapted to review the current EEGPrep changes before closeout." diff --git a/.agents/skills/oc-autoreview-adapted/scripts/autoreview b/.agents/skills/oc-autoreview-adapted/scripts/autoreview deleted file mode 100755 index 751fb963..00000000 --- a/.agents/skills/oc-autoreview-adapted/scripts/autoreview +++ /dev/null @@ -1,895 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -# Adapted for EEGPrep from OpenClaw's MIT-licensed autoreview helper. -# Original copyright (c) 2026 openclaw. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import argparse -import json -import os -import subprocess -import sys -import tempfile -import textwrap -import time -from pathlib import Path -from typing import Any - - -DEFAULT_BASE = "origin/develop" -TRUNK_BRANCHES = {"develop", "main", "master"} -REPORT_KEYS = { - "findings", - "overall_correctness", - "overall_explanation", - "overall_confidence", -} -FINDING_KEYS = { - "title", - "body", - "priority", - "confidence", - "category", - "code_location", -} -CATEGORIES = { - "bug", - "security", - "regression", - "test_gap", - "maintainability", - "eeglab_parity", - "data_structure", - "gui_session", - "docs_help", - "performance", -} - -SCHEMA: dict[str, Any] = { - "type": "object", - "additionalProperties": False, - "required": [ - "findings", - "overall_correctness", - "overall_explanation", - "overall_confidence", - ], - "properties": { - "findings": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": False, - "required": [ - "title", - "body", - "priority", - "confidence", - "category", - "code_location", - ], - "properties": { - "title": {"type": "string", "minLength": 1, "maxLength": 140}, - "body": {"type": "string", "minLength": 1, "maxLength": 2400}, - "priority": {"type": "string", "enum": ["P0", "P1", "P2", "P3"]}, - "confidence": {"type": "number", "minimum": 0, "maximum": 1}, - "category": {"type": "string", "enum": sorted(CATEGORIES)}, - "code_location": { - "type": "object", - "additionalProperties": False, - "required": ["file_path", "line"], - "properties": { - "file_path": {"type": "string", "minLength": 1}, - "line": {"type": "integer", "minimum": 1}, - }, - }, - }, - }, - }, - "overall_correctness": { - "type": "string", - "enum": ["patch is correct", "patch is incorrect"], - }, - "overall_explanation": {"type": "string", "minLength": 1, "maxLength": 3000}, - "overall_confidence": {"type": "number", "minimum": 0, "maximum": 1}, - }, -} - - -def run( - args: list[str], - cwd: Path, - *, - input_text: str | None = None, - check: bool = True, -) -> subprocess.CompletedProcess[str]: - result = subprocess.run( - args, - cwd=cwd, - input=input_text, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - if check and result.returncode != 0: - command = " ".join(args) - raise SystemExit( - f"command failed ({result.returncode}): {command}\n" - f"{result.stderr or result.stdout}" - ) - return result - - -def run_with_heartbeat( - args: list[str], - cwd: Path, - *, - input_text: str, - label: str, - heartbeat_seconds: int, -) -> subprocess.CompletedProcess[str]: - started = time.monotonic() - proc = subprocess.Popen( - args, - cwd=cwd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - first_communicate = True - while True: - try: - stdout, stderr = proc.communicate( - input=input_text if first_communicate else None, - timeout=heartbeat_seconds, - ) - return subprocess.CompletedProcess(args, int(proc.returncode or 0), stdout, stderr) - except subprocess.TimeoutExpired: - first_communicate = False - elapsed = int(time.monotonic() - started) - print( - f"review still running: {label} elapsed={elapsed}s pid={proc.pid}", - file=sys.stderr, - flush=True, - ) - - -def repo_root() -> Path: - start = Path.cwd().resolve() - unsafe_root = discover_repo_root(start) or start - git_bin = find_command("git", unsafe_root) - if not git_bin: - raise SystemExit("git executable not found. Install Git or add it to PATH.") - result = subprocess.run( - [git_bin, "rev-parse", "--show-toplevel"], - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - if result.returncode != 0: - raise SystemExit("autoreview must run inside a git repository") - return Path(result.stdout.strip()).resolve() - - -def discover_repo_root(start: Path) -> Path | None: - current = start - while True: - if (current / ".git").exists(): - return current - if current.parent == current: - return None - current = current.parent - - -def git(repo: Path, *args: str, check: bool = True) -> str: - return run([resolve_command("git", repo), *args], repo, check=check).stdout - - -def current_branch(repo: Path) -> str: - branch = git(repo, "branch", "--show-current", check=False).strip() - return branch or "detached" - - -def is_dirty(repo: Path) -> bool: - return bool(git(repo, "status", "--porcelain").strip()) - - -def choose_target(repo: Path, mode: str, base_ref: str | None) -> tuple[str, str | None]: - normalized = "local" if mode == "uncommitted" else mode - branch = current_branch(repo) - if normalized == "local" or (normalized == "auto" and is_dirty(repo)): - return "local", None - if normalized == "commit": - return "commit", None - if normalized == "branch" or (normalized == "auto" and branch not in TRUNK_BRANCHES): - return "branch", base_ref or detect_pr_base(repo) or DEFAULT_BASE - raise SystemExit( - "no review target: clean trunk checkout and no forced mode. " - "Pass --mode branch --base or --mode commit --commit ." - ) - - -def detect_pr_base(repo: Path) -> str | None: - gh_bin = find_command("gh", repo) - if not gh_bin: - return None - result = run( - [gh_bin, "pr", "view", "--json", "baseRefName", "--jq", ".baseRefName"], - repo, - check=False, - ) - base = result.stdout.strip() - if result.returncode != 0 or not base: - return None - return f"origin/{base}" - - -def resolve_command(name: str, repo: Path) -> str: - resolved = find_command(name, repo) - if resolved: - return resolved - raise SystemExit( - f"executable not found: {name}. Install it or pass an explicit trusted path." - ) - - -def find_command(name: str, repo: Path) -> str | None: - command = Path(name) - if has_directory_component(name, command): - base = command if command.is_absolute() else repo / command - return first_executable_candidate(base) - for part in os.environ.get("PATH", "").split(os.pathsep): - if not part or part == ".": - continue - path_part = Path(part) - if not path_part.is_absolute(): - continue - try: - resolved_part = path_part.resolve() - resolved_repo = repo.resolve() - except OSError: - continue - if is_within(resolved_part, resolved_repo): - continue - found = first_executable_candidate(resolved_part / name, reject_root=resolved_repo) - if found: - return found - return None - - -def has_directory_component(name: str, command: Path) -> bool: - separators = [separator for separator in (os.sep, os.altsep) if separator] - return command.is_absolute() or bool(command.drive) or any( - separator in name for separator in separators - ) - - -def first_executable_candidate(path: Path, *, reject_root: Path | None = None) -> str | None: - if os.name == "nt" and not path.suffix: - extensions = [ - ext for ext in os.environ.get("PATHEXT", ".COM;.EXE;.BAT;.CMD").split(";") if ext - ] - candidates = [path.with_suffix(ext.lower()) for ext in extensions] - candidates.extend(path.with_suffix(ext.upper()) for ext in extensions) - candidates.append(path) - else: - candidates = [path] - for candidate in candidates: - if candidate.is_file() and os.access(candidate, os.X_OK): - if reject_root is not None: - try: - if is_within(candidate.resolve(), reject_root): - continue - except OSError: - continue - return str(candidate) - return None - - -def is_within(path: Path, root: Path) -> bool: - return path == root or path.is_relative_to(root) - - -def bounded(text: str, limit: int = 200_000) -> str: - if len(text) <= limit: - return text - return text[:limit] + f"\n\n[truncated at {limit} characters]\n" - - -def bounded_field(text: str, limit: int) -> str: - if len(text) <= limit: - return text - suffix = "\n\n[truncated]" - return text[: max(0, limit - len(suffix))] + suffix - - -def read_text(path: Path, limit: int = 50_000) -> str: - try: - data = path.read_bytes() - except OSError as exc: - return f"[unreadable: {exc}]" - if b"\0" in data: - return "[binary file omitted]" - return bounded(data.decode("utf-8", errors="replace"), limit) - - -def local_bundle(repo: Path) -> str: - parts = [ - "# Git Status", - git(repo, "status", "--short"), - "# Staged Diff", - git(repo, "diff", "--cached", "--stat"), - bounded(git(repo, "diff", "--cached", "--patch", "--find-renames")), - "# Unstaged Diff", - git(repo, "diff", "--stat"), - bounded(git(repo, "diff", "--patch", "--find-renames")), - ] - untracked = [ - line for line in git(repo, "ls-files", "--others", "--exclude-standard").splitlines() if line - ] - if untracked: - parts.append("# Untracked Files") - for rel in untracked: - parts.append(f"## {rel}\n{read_text(repo / rel)}") - return "\n\n".join(parts) - - -def branch_bundle(repo: Path, base_ref: str, *, skip_fetch: bool) -> str: - if not skip_fetch: - git(repo, "fetch", "origin", "--quiet", check=False) - return "\n\n".join( - [ - "# Branch Diff", - f"base: {base_ref}", - git(repo, "diff", "--stat", f"{base_ref}...HEAD"), - bounded(git(repo, "diff", "--patch", "--find-renames", f"{base_ref}...HEAD")), - ] - ) - - -def commit_bundle(repo: Path, commit_ref: str) -> str: - return "\n\n".join( - [ - "# Commit Diff", - f"commit: {commit_ref}", - git(repo, "show", "--stat", "--format=fuller", commit_ref), - bounded(git(repo, "show", "--patch", "--find-renames", "--format=fuller", commit_ref)), - ] - ) - - -def review_paths(repo: Path, target: str, target_ref: str | None, commit_ref: str) -> set[str]: - names: set[str] = set() - if target == "local": - sources = [ - git(repo, "diff", "--name-only", "--cached"), - git(repo, "diff", "--name-only"), - git(repo, "ls-files", "--others", "--exclude-standard"), - ] - elif target == "branch": - if target_ref is None: - raise SystemExit("internal error: branch target missing base ref") - sources = [git(repo, "diff", "--name-only", f"{target_ref}...HEAD")] - else: - sources = [git(repo, "show", "--name-only", "--format=", commit_ref)] - for source in sources: - for line in source.splitlines(): - path = line.strip() - if path: - names.add(path) - return names - - -def instruction_paths(repo: Path, changed_paths: set[str]) -> list[Path]: - paths = {repo / "AGENTS.md"} - for rel in changed_paths: - rel_path = Path(rel) - if rel_path.is_absolute() or ".." in rel_path.parts: - continue - current = (repo / rel_path).parent - while True: - candidate = current / "AGENTS.md" - if candidate.exists(): - paths.add(candidate) - if current == repo or current.parent == current: - break - current = current.parent - return sorted(path for path in paths if path.exists()) - - -def instruction_bundle(repo: Path, changed_paths: set[str]) -> str: - paths = instruction_paths(repo, changed_paths) - if not paths: - return "# Repository Instructions\n[no AGENTS.md files found]" - parts = ["# Repository Instructions"] - for path in paths: - rel = path.relative_to(repo) - parts.append(f"## {rel}\n{read_text(path)}") - return "\n\n".join(parts) - - -def load_extra_prompt(args: argparse.Namespace) -> str: - chunks: list[str] = [] - for value in args.prompt or []: - chunks.append(value) - for path in args.prompt_file or []: - chunks.append(Path(path).read_text()) - return "\n\n".join(chunks) - - -def load_datasets(args: argparse.Namespace) -> str: - chunks: list[str] = [] - for spec in args.dataset or []: - path = Path(spec) - if path.is_dir(): - raise SystemExit(f"--dataset must be a file, got directory: {path}") - chunks.append(f"# Dataset: {path}\n{read_text(path)}") - return "\n\n".join(chunks) - - -def build_prompt( - repo: Path, - target: str, - target_ref: str | None, - changed_paths: set[str], - instructions: str, - bundle: str, - extra_prompt: str, - datasets: str, -) -> str: - target_line = f"{target} {target_ref}" if target_ref else target - changed = "\n".join(f"- {path}" for path in sorted(changed_paths)) or "[no changed paths]" - return textwrap.dedent( - f""" - You are a senior EEGPrep code reviewer. Review the provided git change bundle only. - Be autonomous: inspect files as needed, reason through the changed behavior, and return a - structured result without asking follow-up questions. - - Hard rules: - - Return exactly one JSON object and nothing else. Do not wrap it in Markdown. - - The JSON object must match this schema exactly: - {json.dumps(SCHEMA, indent=2)} - - Do not modify files. - - Do not invoke nested reviewers or review tools. Forbidden commands include: - codex review, autoreview, oracle review, and any reviewer-panel workflow. - - You may use read-only tools and web search to inspect source files, dependency docs, - EEGLAB reference behavior, current APIs, and security implications. - - Shell commands, if available, must be read-only inspection commands. Do not run tests, - formatters, package installs, generators, git mutation commands, or commands that write files. - - Report only actionable defects introduced or exposed by this change. - - Prefer high-signal findings over style feedback. False positives waste maintainer time. - - For each finding, use the smallest file/line location that demonstrates the issue. - - If there are no actionable findings, return an empty findings array and mark the patch correct. - - EEGPrep review priorities: - - Correctness bugs, import/runtime failures, wrong numerical results, and broken common workflows. - - EEGLAB parity regressions in public API behavior, pop_* wrappers, history commands, - GUI layout/behavior, event semantics, and expected data structures. - - EEG dict semantics for data, nbchan, pnts, trials, srate, xmin, xmax, times, - chanlocs, event, urevent, epoch, history, icaact, icawinv, icasphere, - icaweights, and icachansind. - - MATLAB/Python boundary mistakes, especially 1-based EEGLAB event latencies and - user-facing indices versus 0-based Python arrays. - - Channel-major data shape assumptions: continuous data is usually (nbchan, pnts), - and epoched data is usually (nbchan, pnts, trials). - - GUI plus eegprep-console session sync: EEG, ALLEEG, CURRENTSET, LASTCOM, ALLCOM, - STUDY, and CURRENTSTUDY must stay synchronized through EEGPrepSession helpers. - - User-facing pop_* contracts: return_com=True, (EEG, com) returns, history strings, - and GUI/session update paths. - - Runtime code must not depend on src/eegprep/eeglab existing. Use it only as a - development reference. - - User-facing GUI Help or pophelp behavior needs packaged Markdown help resources. - - Tests should cover realistic regressions. Suggest exact missing tests only when the - gap is tied to changed behavior. - - Security findings must be concrete: path traversal, unsafe shell/filesystem use, - unsafe deserialization, credential/privacy leaks, or trust-boundary validation loss. - - Performance findings must be realistic for EEG data sizes. - - Do not flag: - - Pre-existing issues outside the reviewed change. - - Generic linter/formatter comments. - - Broad refactors or speculative abstractions. - - Unlikely edge cases that would complicate the code without protecting real workflows. - - Subjective MATLAB-vs-Python style preferences unless they break EEGPrep's parity contract. - - Review target: {target_line} - Repository: {repo} - - # Changed Paths - {changed} - - {extra_prompt} - - {datasets} - - {instructions} - - # Change Bundle - {bundle} - """ - ).strip() - - -def write_json_temp(data: dict[str, Any]) -> Path: - handle = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) - with handle: - json.dump(data, handle) - return Path(handle.name) - - -def run_codex(args: argparse.Namespace, repo: Path, prompt: str) -> str: - schema_path = write_json_temp(SCHEMA) - output_path = Path(tempfile.NamedTemporaryFile("w", suffix=".json", delete=False).name) - cmd = [resolve_command(args.codex_bin, repo), "--ask-for-approval", "never"] - if args.web_search: - cmd.append("--search") - if args.model: - cmd.extend(["--model", args.model]) - if args.thinking: - cmd.extend(["-c", f'model_reasoning_effort="{args.thinking}"']) - cmd.extend( - [ - "exec", - "--ephemeral", - "-C", - str(repo), - "-s", - "read-only", - "--output-schema", - str(schema_path), - "--output-last-message", - str(output_path), - "-", - ] - ) - result = run_with_heartbeat( - cmd, - repo, - input_text=prompt, - label="codex", - heartbeat_seconds=args.heartbeat_seconds, - ) - try: - output = output_path.read_text() - finally: - schema_path.unlink(missing_ok=True) - output_path.unlink(missing_ok=True) - if result.returncode != 0: - raise SystemExit(f"codex engine failed ({result.returncode})\n{result.stderr or result.stdout}") - return output or result.stdout - - -def extract_json(text: str) -> dict[str, Any]: - stripped = text.strip() - if not stripped: - raise SystemExit("review engine returned empty output") - try: - parsed = json.loads(stripped) - except json.JSONDecodeError as exc: - candidate = parse_json_candidate(stripped) - if isinstance(candidate, dict) and "findings" in candidate: - return candidate - jsonl_report = extract_json_from_jsonl(stripped) - if jsonl_report: - return jsonl_report - raise SystemExit(f"review engine returned non-JSON output: {exc}\n{stripped[:2000]}") - if isinstance(parsed, dict) and "findings" in parsed: - return parsed - if isinstance(parsed, dict) and isinstance(parsed.get("structured_output"), dict): - return parsed["structured_output"] - if isinstance(parsed, dict) and isinstance(parsed.get("result"), str): - result_json = parse_json_candidate(parsed["result"]) - if isinstance(result_json, dict) and "findings" in result_json: - return result_json - raise SystemExit(f"review engine result was not structured JSON:\n{parsed['result'][:2000]}") - jsonl_report = extract_json_from_jsonl(stripped) - if jsonl_report: - return jsonl_report - raise SystemExit(f"review engine returned unexpected JSON shape:\n{json.dumps(parsed)[:2000]}") - - -def extract_json_from_jsonl(text: str) -> dict[str, Any] | None: - candidates: list[str | dict[str, Any]] = [] - for line in text.splitlines(): - line = line.strip() - if not line: - continue - try: - event = json.loads(line) - except json.JSONDecodeError: - continue - if not isinstance(event, dict): - continue - part = event.get("part") - if isinstance(part, dict) and isinstance(part.get("text"), str): - candidates.append(part["text"]) - data = event.get("data") - if isinstance(data, dict) and isinstance(data.get("content"), str): - candidates.append(data["content"]) - if isinstance(event.get("result"), str): - candidates.append(event["result"]) - if isinstance(event.get("structured_output"), dict): - candidates.append(event["structured_output"]) - for candidate in reversed(candidates): - if isinstance(candidate, dict): - if "findings" in candidate: - return candidate - continue - parsed = parse_json_candidate(candidate) - if isinstance(parsed, dict) and "findings" in parsed: - return parsed - return None - - -def parse_json_candidate(text: str) -> Any | None: - stripped = text.strip() - if stripped.startswith("```"): - lines = stripped.splitlines() - if lines and lines[0].startswith("```") and lines[-1].strip() == "```": - stripped = "\n".join(lines[1:-1]).strip() - try: - parsed = json.loads(stripped) - except json.JSONDecodeError: - return None - if isinstance(parsed, str) and parsed != text: - nested = parse_json_candidate(parsed) - return nested if nested is not None else parsed - return parsed - - -def validate_report( - report: dict[str, Any], - changed_paths: set[str], - required: list[str], -) -> None: - extra_top = set(report) - REPORT_KEYS - if extra_top: - raise SystemExit(f"review JSON has unexpected top-level keys: {sorted(extra_top)}") - for key in SCHEMA["required"]: - if key not in report: - raise SystemExit(f"review JSON missing required key: {key}") - if not isinstance(report["findings"], list): - raise SystemExit("review JSON findings must be an array") - if report.get("overall_correctness") not in {"patch is correct", "patch is incorrect"}: - raise SystemExit(f"review JSON has invalid overall_correctness: {report.get('overall_correctness')}") - if not isinstance(report.get("overall_explanation"), str) or not report["overall_explanation"]: - raise SystemExit("review JSON overall_explanation must be a non-empty string") - if len(report["overall_explanation"]) > 3000: - raise SystemExit("review JSON overall_explanation is too long") - if not number_in_range(report.get("overall_confidence")): - raise SystemExit("review JSON overall_confidence must be numeric") - - kept_findings: list[dict[str, Any]] = [] - ignored_findings: list[tuple[int, dict[str, Any], str, int]] = [] - finding_text = "" - for index, finding in enumerate(report["findings"]): - validate_finding(index, finding) - location = finding["code_location"] - rel = str(location["file_path"]).strip() - line = int(location["line"]) - if rel not in changed_paths: - ignored_findings.append((index, finding, rel, line)) - continue - kept_findings.append(finding) - finding_text += "\n" + json.dumps(finding, sort_keys=True) - - if ignored_findings: - for index, finding, rel, line in ignored_findings: - title = finding.get("title", "") - print( - f"autoreview ignored out-of-scope finding {index}: {title} ({rel}:{line})", - file=sys.stderr, - ) - print(bounded_field(str(finding.get("body", "")), 500), file=sys.stderr) - report["findings"] = kept_findings - if not kept_findings and report["overall_correctness"] == "patch is incorrect": - note = f"Ignored {len(ignored_findings)} out-of-scope finding(s) outside the reviewed change." - explanation = report["overall_explanation"].rstrip() - report["overall_correctness"] = "patch is correct" - report["overall_explanation"] = bounded_field(f"{explanation}\n\n{note}", 3000) - - haystack = finding_text.lower() - for needle in required: - if needle.lower() not in haystack: - raise SystemExit(f"required finding text not found: {needle}") - - -def validate_finding(index: int, finding: Any) -> None: - if not isinstance(finding, dict): - raise SystemExit(f"finding {index} must be an object") - extra_finding = set(finding) - FINDING_KEYS - if extra_finding: - raise SystemExit(f"finding {index} has unexpected keys: {sorted(extra_finding)}") - for key in FINDING_KEYS: - if key not in finding: - raise SystemExit(f"finding {index} missing required key: {key}") - title = finding.get("title") - if not isinstance(title, str) or not title or len(title) > 140: - raise SystemExit(f"finding {index} has invalid title") - body = finding.get("body") - if not isinstance(body, str) or not body or len(body) > 2400: - raise SystemExit(f"finding {index} has invalid body") - priority = finding.get("priority") - if priority not in {"P0", "P1", "P2", "P3"}: - raise SystemExit(f"finding {index} has invalid priority: {priority}") - if not number_in_range(finding.get("confidence")): - raise SystemExit(f"finding {index} has invalid confidence") - category = finding.get("category") - if category not in CATEGORIES: - raise SystemExit(f"finding {index} has invalid category: {category}") - location = finding.get("code_location") - if not isinstance(location, dict): - raise SystemExit(f"finding {index} missing code_location") - rel = str(location.get("file_path", "")).strip() - line = location.get("line") - if not rel or not isinstance(line, int) or line < 1: - raise SystemExit(f"finding {index} has invalid location: {location}") - path = Path(rel) - if path.is_absolute() or ".." in path.parts: - raise SystemExit(f"finding {index} uses invalid file path: {rel}") - - -def number_in_range(value: Any) -> bool: - return isinstance(value, (int, float)) and not isinstance(value, bool) and 0 <= value <= 1 - - -def print_report(report: dict[str, Any]) -> None: - findings = report["findings"] - if findings: - print(f"autoreview findings: {len(findings)}") - elif report["overall_correctness"] == "patch is incorrect": - print("autoreview verdict: patch is incorrect without discrete findings") - else: - print("autoreview clean: no accepted/actionable findings reported") - for finding in findings: - loc = finding["code_location"] - print(f"[{finding['priority']}] {finding['title']} ({finding['category']})") - print(f"{loc['file_path']}:{loc['line']}") - print(finding["body"]) - print() - print(f"overall: {report['overall_correctness']} ({report['overall_confidence']})") - print(report["overall_explanation"]) - - -def start_parallel_tests(command: str, repo: Path) -> tuple[subprocess.Popen[Any], float]: - print(f"tests: {command}") - return subprocess.Popen(command, cwd=repo, shell=True), time.time() - - -def finish_parallel_tests(proc: subprocess.Popen[Any], started: float) -> int: - proc.wait() - print(f"tests exit: {proc.returncode} after {int(time.time() - started)}s") - return int(proc.returncode or 0) - - -def parse_args(argv: list[str]) -> argparse.Namespace: - parser = argparse.ArgumentParser(description="EEGPrep bundle-driven autonomous code review.") - parser.add_argument("--mode", choices=["auto", "local", "uncommitted", "branch", "commit"], default="auto") - parser.add_argument("--base") - parser.add_argument("--commit", default="HEAD") - parser.add_argument("--codex-bin", default=os.environ.get("CODEX_BIN", "codex")) - parser.add_argument("--model", default=os.environ.get("AUTOREVIEW_MODEL")) - parser.add_argument( - "--thinking", - choices=["low", "medium", "high", "xhigh"], - default=os.environ.get("AUTOREVIEW_THINKING", "high"), - ) - parser.add_argument("--no-web-search", dest="web_search", action="store_false", default=True) - parser.add_argument("--prompt", action="append", help="Additional review instruction text.") - parser.add_argument("--prompt-file", action="append", help="Additional review instruction file.") - parser.add_argument("--dataset", action="append", help="Extra evidence file to include in the bundle.") - parser.add_argument("--output", help="Write human output to a file as well as stdout.") - parser.add_argument("--json-output", help="Write validated structured review JSON.") - parser.add_argument("--parallel-tests", help="Run a focused test command concurrently with review.") - parser.add_argument("--require-finding", action="append", default=[], help="Require finding text to contain this substring.") - parser.add_argument("--expect-findings", action="store_true", help="Treat findings as success for harness checks.") - parser.add_argument("--skip-fetch", action="store_true", help="Do not fetch origin before branch diffs.") - parser.add_argument("--heartbeat-seconds", type=int, default=60) - parser.add_argument("--dry-run", action="store_true", help="Resolve target and bundle context without calling Codex.") - return parser.parse_args(argv) - - -def main(argv: list[str]) -> int: - args = parse_args(argv) - repo = repo_root() - target, target_ref = choose_target(repo, args.mode, args.base) - print(f"autoreview target: {target}") - print(f"branch: {current_branch(repo)}") - print("engine: codex") - if args.model: - print(f"model: {args.model}") - print(f"thinking: {args.thinking}") - print(f"web_search: {'on' if args.web_search else 'off'}") - display_ref = args.commit if target == "commit" else target_ref - if display_ref: - print(f"ref: {display_ref}") - - if target == "local": - bundle = local_bundle(repo) - elif target == "branch": - if target_ref is None: - raise SystemExit("internal error: branch target missing base ref") - bundle = branch_bundle(repo, target_ref, skip_fetch=args.skip_fetch) - else: - bundle = commit_bundle(repo, args.commit) - target_ref = args.commit - changed_paths = review_paths(repo, target, target_ref, args.commit) - instructions = instruction_bundle(repo, changed_paths) - prompt = build_prompt( - repo, - target, - target_ref, - changed_paths, - instructions, - bundle, - load_extra_prompt(args), - load_datasets(args), - ) - print(f"changed paths: {len(changed_paths)}") - print(f"bundle: {len(prompt)} chars") - if args.dry_run: - return 0 - - tests_proc: tuple[subprocess.Popen[Any], float] | None = None - if args.parallel_tests: - tests_proc = start_parallel_tests(args.parallel_tests, repo) - try: - raw = run_codex(args, repo, prompt) - report = extract_json(raw) - validate_report(report, changed_paths, args.require_finding) - if args.json_output: - Path(args.json_output).write_text(json.dumps(report, indent=2) + "\n") - if args.output: - original_stdout = sys.stdout - with Path(args.output).open("w") as handle: - sys.stdout = Tee(original_stdout, handle) - print_report(report) - sys.stdout = original_stdout - else: - print_report(report) - finally: - tests_status = finish_parallel_tests(*tests_proc) if tests_proc else 0 - - has_findings = bool(report["findings"]) - overall_incorrect = report["overall_correctness"] == "patch is incorrect" - if tests_status != 0: - return 1 - if args.expect_findings: - return 0 if has_findings else 1 - return 1 if has_findings or overall_incorrect else 0 - - -class Tee: - def __init__(self, *streams: Any) -> None: - self.streams = streams - - def write(self, data: str) -> None: - for stream in self.streams: - stream.write(data) - - def flush(self) -> None: - for stream in self.streams: - stream.flush() - - -if __name__ == "__main__": - raise SystemExit(main(sys.argv[1:])) diff --git a/.agents/skills/oc-autoreview-adapted/scripts/test-review-harness b/.agents/skills/oc-autoreview-adapted/scripts/test-review-harness deleted file mode 100755 index ab98338a..00000000 --- a/.agents/skills/oc-autoreview-adapted/scripts/test-review-harness +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) -harness="$script_dir/test-review-harness.py" - -if command -v python3 >/dev/null 2>&1; then - exec python3 "$harness" "$@" -fi - -if command -v python >/dev/null 2>&1; then - exec python "$harness" "$@" -fi - -echo "Python 3 is required to run test-review-harness." >&2 -exit 127 diff --git a/.agents/skills/oc-autoreview-adapted/scripts/test-review-harness.py b/.agents/skills/oc-autoreview-adapted/scripts/test-review-harness.py deleted file mode 100755 index 8a7540d5..00000000 --- a/.agents/skills/oc-autoreview-adapted/scripts/test-review-harness.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import os -import shutil -import stat -import subprocess -import sys -import tempfile -from collections.abc import Callable -from pathlib import Path - - -SAFE_INITIAL = """import numpy as np - - -def trim_eeg(eeg, start_sample, stop_sample): - data = np.asarray(eeg["data"]) - start = int(start_sample) - 1 - stop = int(stop_sample) - trimmed = data[:, start:stop] - out = dict(eeg) - out["data"] = trimmed - out["pnts"] = trimmed.shape[1] - out["xmin"] = start / float(eeg["srate"]) - out["xmax"] = (stop - 1) / float(eeg["srate"]) - return out -""" - -BUGGY_CHANGED = """import numpy as np - - -def trim_eeg(eeg, start_sample, stop_sample): - data = np.asarray(eeg["data"]) - trimmed = data[start_sample:stop_sample, :] - out = dict(eeg) - out["data"] = trimmed - out["pnts"] = stop_sample - start_sample - return out -""" - -BENIGN_CHANGED = """import numpy as np - - -def trim_eeg(eeg, start_sample, stop_sample): - data = np.asarray(eeg["data"]) - start = int(start_sample) - 1 - stop = int(stop_sample) - if start < 0 or stop <= start or stop > data.shape[1]: - raise ValueError("sample range is outside EEG data") - trimmed = data[:, start:stop] - out = dict(eeg) - out["data"] = trimmed - out["pnts"] = trimmed.shape[1] - out["xmin"] = start / float(eeg["srate"]) - out["xmax"] = (stop - 1) / float(eeg["srate"]) - return out -""" - -BUGGY_PROMPT = ( - "Acceptance fixture: this EEG change contains a real EEGPrep-style bug. " - "Review normally and report only concrete defects introduced by the patch." -) -BENIGN_PROMPT = ( - "Calibration fixture: this EEG change intentionally validates sample bounds " - "and preserves channel-major data. Do not flag it unless there is a concrete bug." -) - - -def parse_args(argv: list[str]) -> argparse.Namespace: - parser = argparse.ArgumentParser( - description=( - "Create a temporary EEG-style git repo and run the adapted autoreview helper " - "against a buggy or benign patch." - ) - ) - parser.add_argument("--fixture", choices=("buggy", "benign"), default="buggy") - parser.add_argument("--dry-run", action="store_true", help="Only verify helper target selection.") - return parser.parse_args(argv) - - -def run(command: list[str], cwd: Path) -> None: - subprocess.run(command, cwd=cwd, check=True) - - -def write_fixture_file(repo: Path, content: str) -> None: - (repo / "eeg_ops.py").write_text(content, encoding="utf-8", newline="\n") - - -def create_fixture_repo(repo: Path, fixture: str) -> None: - run(["git", "init", "--quiet"], repo) - run(["git", "config", "user.name", "Review Fixture"], repo) - run(["git", "config", "user.email", "review-fixture@example.com"], repo) - write_fixture_file(repo, SAFE_INITIAL) - run(["git", "add", "eeg_ops.py"], repo) - run(["git", "commit", "--quiet", "-m", "initial safe EEG trim"], repo) - write_fixture_file(repo, BUGGY_CHANGED if fixture == "buggy" else BENIGN_CHANGED) - - -def run_review(repo: Path, script_dir: Path, fixture: str, *, dry_run: bool) -> None: - autoreview = script_dir / "autoreview" - command = [ - sys.executable, - str(autoreview), - "--mode", - "local", - "--prompt", - BUGGY_PROMPT if fixture == "buggy" else BENIGN_PROMPT, - ] - if fixture == "buggy": - command.extend(["--require-finding", "channel", "--expect-findings"]) - if dry_run: - command.append("--dry-run") - run(command, repo) - - -def cleanup_repo(repo: Path) -> None: - def make_writable_and_retry( - function: Callable[[str], object], - path: str, - _exc_info: object, - ) -> None: - try: - os.chmod(path, stat.S_IREAD | stat.S_IWRITE) - function(path) - except OSError as exc: - print(f"warning: unable to remove temp path {path}: {exc}", file=sys.stderr) - - if not repo.exists(): - return - shutil.rmtree(repo, onerror=make_writable_and_retry) - - -def main(argv: list[str]) -> int: - args = parse_args(argv) - script_dir = Path(__file__).resolve().parent - repo = Path(tempfile.mkdtemp(prefix="eegprep-autoreview-fixture.")) - try: - create_fixture_repo(repo, args.fixture) - run_review(repo, script_dir, args.fixture, dry_run=args.dry_run) - except subprocess.CalledProcessError as exc: - return int(exc.returncode or 1) - finally: - cleanup_repo(repo) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main(sys.argv[1:])) diff --git a/.agents/skills/thermo-nuclear-code-quality-review/SKILL.md b/.agents/skills/thermo-nuclear-code-quality-review/SKILL.md new file mode 100644 index 00000000..a5d463fe --- /dev/null +++ b/.agents/skills/thermo-nuclear-code-quality-review/SKILL.md @@ -0,0 +1,100 @@ +--- +name: thermo-nuclear-code-quality-review +description: Run an unusually strict EEGPrep code-quality review for architecture, maintainability, abstraction quality, file sprawl, spaghetti branching, and missed simplification. Use for thermo-nuclear review, thermonuclear review, deep maintainability audit, strict architecture review, or when code technically works but may make EEGPrep harder to ship. +--- + +# Thermo-Nuclear Code Quality Review + +Use this for a demanding maintainability review, not a routine bug pass. The goal is to make EEGPrep shippable: standalone, EEGLAB-familiar, easy for EEG researchers, and structurally simple enough for future agents and humans to extend safely. + +Inspired by Cursor's MIT-licensed `thermo-nuclear-code-quality-review` skill: https://github.com/cursor/plugins/blob/main/cursor-team-kit/skills/thermo-nuclear-code-quality-review/SKILL.md + +## Contract + +- Read `AGENTS.md` first and keep its EEGLAB parity, GUI/console, testing, docs, and style rules in force. +- Review the current diff, PR, branch, or named scope. Do not turn this into a whole-codebase rewrite unless asked. +- Be ambitious about simplification, but only flag structural issues with a concrete failure mode or clear maintainability cost. +- Prefer fewer high-conviction findings over a long list of taste comments. +- If asked to fix findings, verify each one from first principles, make the smallest durable change, run focused tests, and commit only when requested. +- Do not rubber-stamp code because tests pass. Passing behavior can still be architecturally wrong. + +## Review Bar + +Ask these questions for every meaningful change: + +- Is there a simpler framing that deletes branches, modes, wrappers, flags, or helper layers? +- Did this add special cases to an already busy flow instead of moving logic to the owning module? +- Is the logic in the canonical EEGPrep layer? +- Does the code preserve EEG dict invariants and EEGLAB-facing semantics without hidden global state? +- Does GUI/menu/console code update `EEGPrepSession`, history, and visible state atomically? +- Is this abstraction earning its keep, or is it a pass-through wrapper? +- Did the change create duplicate helpers instead of reusing an existing contract? +- Did it make data boundaries weaker through unnecessary optionality, casts, duck typing, or silent fallbacks? +- Did a file cross or approach roughly 1000 lines because new concepts were not decomposed? +- Does the code remain understandable to an EEG researcher migrating from EEGLAB? + +## EEGPrep Architecture Boundaries + +Keep ownership clear: + +- `popfunc`: user-facing `pop_*` wrappers, history strings, dialogs, `return_com=True`, and EEGLAB-compatible command surfaces. +- `sigprocfunc`: low-level signal processing and numerical transforms. No GUI/session orchestration here. +- `guifunc`: Qt/inputgui/menu rendering and GUI coordination. No low-level numerical algorithm ownership here. +- `adminfunc`: session, options, console, history, storage, and administrative runtime behavior. +- `plugins/*`: bundled plugin ports and plugin-owned helpers. +- `resources/help`: EEGPrep-owned Help Markdown for user-facing dialogs/functions. +- `eeglab/`: development reference only. Runtime code must not depend on it. + +Flag layer leaks aggressively when they make future behavior harder to reason about. + +## What To Flag + +Flag issues such as: + +- A "works but messy" implementation where a clear code-judo move would delete complexity. +- One-off booleans, nullable modes, or scattered feature checks. +- Repeated conditionals that indicate a missing helper, model, or dispatcher. +- Partial session/history/dataset updates that can leave GUI and console out of sync. +- EEGLAB user-facing indices mixed with Python 0-based indices without an explicit boundary. +- Channel-major EEG data assumptions hidden behind generic array handling. +- New runtime dependency on vendored EEGLAB, local paths, optional files, or unstated environment state. +- Thin wrappers, identity helpers, or generic magic that obscure simple EEG invariants. +- Copy-pasted parsing/history/dialog helpers when a canonical helper exists. +- Large-file growth that should be split into focused modules. +- Tests that only assert implementation details while missing user-observable behavior. + +## Preferred Remedies + +Prefer remedies that: + +- Delete concepts rather than rename them. +- Collapse duplicate branches into one explicit flow. +- Move logic to the module that owns the concept. +- Extract small pure helpers for repeated parsing, shape, or indexing contracts. +- Make state transitions atomic through `EEGPrepSession` helpers. +- Make data boundaries explicit before converting between EEGLAB-facing and Python-facing indices. +- Split large modules by stable ownership, not by arbitrary line count. +- Replace loose optional/cast-heavy code with a concrete contract. +- Add focused tests for externally observable EEG dict, GUI/session, history, or file behavior. + +## Tone And Output + +Lead with findings. For each finding include: + +- file and line; +- why it matters for correctness, parity, or maintainability; +- the concrete scenario or future failure mode; +- a preferred fix direction. + +Order findings by severity: + +1. Structural regressions that can create bugs or block maintainability. +2. Missed simplification that would remove significant complexity. +3. Wrong ownership/layering or canonical-helper duplication. +4. State/session/history atomicity risks. +5. File sprawl and decomposition concerns. +6. Lower-level legibility issues with real cost. + +If there are no actionable issues, say so directly and mention any residual test or review limits. + +Do not soften major structural problems into vague nits. Also do not invent architecture work without a real failure mode. diff --git a/.agents/skills/thermo-nuclear-code-quality-review/agents/openai.yaml b/.agents/skills/thermo-nuclear-code-quality-review/agents/openai.yaml new file mode 100644 index 00000000..9de350b0 --- /dev/null +++ b/.agents/skills/thermo-nuclear-code-quality-review/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "Thermo-Nuclear Code Quality Review" + short_description: "Strict EEGPrep architecture and maintainability review" + default_prompt: "Use $thermo-nuclear-code-quality-review to review the current EEGPrep branch for structural regressions, spaghetti branching, wrong ownership layers, and missed simplifications." diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2e0e80a7..6d65683e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,9 @@ name: Tests on: push: - branches: [ main, develop ] + branches: [ master, develop ] pull_request: - branches: [ main, develop ] + branches: [ master, develop ] env: # Set MATLAB batch licensing token for private repos or when using MATLAB Engine diff --git a/.notes/implementation-notes.html b/.notes/implementation-notes.html index f2d38960..0a23504f 100644 --- a/.notes/implementation-notes.html +++ b/.notes/implementation-notes.html @@ -542,5 +542,29 @@

Verification Notes

EEGLAB-facing 1-based public QC indices, and compare nonfinite data masks in eegprep eeglab compare. +

Async GUI ICA Progress Notes

+

Design Decisions

+
    +
  • Kept pop_runica itself synchronous for scripts, tests, CLI, + and console calls. The background worker is owned by the Qt menu action + layer because only the GUI needs to keep repainting while the computation + runs.
  • +
  • Split pop_runica GUI option collection from ICA execution + so the options dialog always opens on the main Qt thread and only the pure + computation runs in the worker.
  • +
  • Session mutation remains on the main thread. The worker returns an + updated EEG object and command string; EEGPrepSession is updated + only from the success callback.
  • +
+

Tradeoffs

+
    +
  • The progress dialog is indeterminate because runica progress is + iteration/log-message based rather than a reliable percentage. Cancellation + is intentionally not exposed until the ICA backends support safe + interruption.
  • +
  • The reusable long-task helper captures EEGPrep log messages for the + dialog, while the console action boundary keeps command echo and progress + output ordered for mixed GUI-plus-console workflows.
  • +
diff --git a/README.md b/README.md index 4a1b8ae0..bf23fd9b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,6 @@ # What is EEGPrep? [![Documentation Status](https://github.com/sccn/eegprep/actions/workflows/docs.yml/badge.svg)](https://github.com/sccn/eegprep/actions/workflows/docs.yml) -[![GitHub Pages](https://github.com/sccn/eegprep/actions/workflows/pages.yml/badge.svg)](https://sccn.github.io/eegprep/) EEGPrep is a Python package that reproduces the EEGLAB default preprocessing pipeline with numerical accuracy down to 1e-5 uV, including clean_rawdata and ICLabel, enabling MATLAB-to-Python equivalence for EEG analysis. It takes BIDS data as input and produces BIDS derivative dataset as output, which can then be reimported into other packages as needed (EEGLAB, Fieldtrip, Brainstorm, MNE). It does produce plots. The package will be fully documented for conversion, packaging, and testing workflows, with installation available via PyPI. diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 147deecf..d5226922 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -58,6 +58,7 @@ If you only need documentation dependencies, sync the docs extra: - The eegprep package in editable mode - Repo tooling dependencies +- GUI and ``eegprep-console`` runtime dependencies - Documentation dependencies when ``--extra docs`` is used Code Style Guidelines @@ -121,19 +122,32 @@ When adding new features, include tests: .. code-block:: python + import numpy as np from eegprep import EEGobj def test_new_feature(): """Test description of what this tests.""" - # Setup - eeg = EEGobj() - - # Execute - result = eeg.new_feature() + # Setup: EEGobj wraps an EEG dict (or a .set file path). + eeg_dict = { + "data": np.zeros((4, 100), dtype=np.float32), + "nbchan": 4, + "pnts": 100, + "trials": 1, + "srate": 128.0, + "xmin": 0.0, + "xmax": 99 / 128.0, + "chanlocs": [{"labels": f"Ch{i + 1}"} for i in range(4)], + "event": [], + "epoch": [], + } + eeg = EEGobj(eeg_dict) + + # Execute: EEGobj dispatches pop_* operations, e.g. eeg.pop_reref([]). + result = eeg.pop_reref([]) # Assert assert result is not None - assert len(result) > 0 + assert result["nbchan"] == 4 Documentation Standards ======================= diff --git a/docs/source/development.rst b/docs/source/development.rst index f04c9171..ed86b16b 100644 --- a/docs/source/development.rst +++ b/docs/source/development.rst @@ -58,8 +58,10 @@ Install the default development environment: uv sync --group dev ``uv sync`` creates ``.venv/`` and installs EEGPrep in editable mode from the -locked dependency set. Use ``uv run`` for commands so they execute inside this -environment. +locked dependency set. The development environment includes the GUI and +``eegprep-console`` runtime dependencies so ``uv run eegprep-console --full`` +works from a fresh checkout. Use ``uv run`` for commands so they execute inside +this environment. Install Documentation Dependencies ---------------------------------- diff --git a/docs/source/user_guide/agent_cli.rst b/docs/source/user_guide/agent_cli.rst index b4bef338..f97fb9df 100644 --- a/docs/source/user_guide/agent_cli.rst +++ b/docs/source/user_guide/agent_cli.rst @@ -123,18 +123,19 @@ QC results include stable recommendation codes that an agent can reason over. HTML reports are for human review; the paired JSON and manifests are for automation. -BIDS And EEGLAB Migration -========================= +BIDS And Migration +================== .. code-block:: bash eegprep bids validate bids_root --json eegprep bids import bids_root --subject 01 --task rest --output sub-01.set --json eegprep bids export clean.set --bids-root bids_out --subject 01 --task rest --json - eegprep eeglab history old_pipeline.set --json - eegprep eeglab compare matlab_output.set eegprep_output.set --json - eegprep eeglab convert-script old_pipeline.m --output preprocess.yaml --json - -The EEGLAB helpers are migration aids. Script conversion is intentionally -best-effort and reports unsupported commands instead of silently inventing -behavior. + eegprep migrate history old_pipeline.set --json + eegprep migrate compare matlab_output.set eegprep_output.set --json + eegprep migrate convert-script old_pipeline.m --output preprocess.yaml --json + +Migration helpers can inspect EEGLAB command histories and compare datasets +without making normal EEGPrep CLI usage depend on MATLAB or an EEGLAB checkout. +Script conversion is intentionally best-effort and reports unsupported commands +instead of silently inventing behavior. diff --git a/docs/source/user_guide/gui_console_session.rst b/docs/source/user_guide/gui_console_session.rst index 77acda0b..16f33382 100644 --- a/docs/source/user_guide/gui_console_session.rst +++ b/docs/source/user_guide/gui_console_session.rst @@ -101,6 +101,14 @@ GUI actions should update state through session helpers such as ``notify_changed()``. They should not mutate a GUI-only copy of ``EEG`` that the console cannot see. +Long-running GUI actions use the same session boundary. For example, +GUI-launched ICA opens the EEGLAB-like ``pop_runica`` options dialog on the main +thread, runs the ICA computation behind a progress dialog, then stores the +updated dataset and history only after the worker finishes successfully. While +the worker is running, progress messages are buffered safely for +``eegprep-console`` so the replayable command remains visible before related +output. + ``eegprep-console`` wraps registered ``pop_*`` functions. When a bare call such as ``pop_resample(EEG, 64)`` returns a dataset and command string, the wrapper stores the returned dataset, updates ``LASTCOM`` and ``ALLCOM``, and tells the diff --git a/docs/source/user_guide/installation.rst b/docs/source/user_guide/installation.rst index 25f4a0ce..49a98696 100644 --- a/docs/source/user_guide/installation.rst +++ b/docs/source/user_guide/installation.rst @@ -47,22 +47,6 @@ environments: pip install eegprep -Using conda ------------ - -If you prefer conda, you can install eegprep from the conda-forge channel: - -.. code-block:: bash - - conda install -c conda-forge eegprep - -To create a new conda environment with eegprep: - -.. code-block:: bash - - conda create -n eegprep-env python=3.10 eegprep - conda activate eegprep-env - From Source ----------- @@ -75,7 +59,9 @@ To install eegprep from source for development: uv sync --group dev ``uv sync`` creates the project environment, installs EEGPrep in editable mode, -and uses ``uv.lock`` for reproducible dependency resolution. +and uses ``uv.lock`` for reproducible dependency resolution. The development +environment includes the GUI and console runtime dependencies, so a fresh +checkout can immediately launch ``uv run eegprep-console --full``. To develop or build documentation from source, include the docs extra: @@ -109,24 +95,6 @@ For CPU-only PyTorch: uv add torch --index-url https://download.pytorch.org/whl/cpu -EEGLAB I/O Support ------------------- - -To enable reading and writing EEGLAB .set files: - -.. code-block:: bash - - uv add eeglabio - -MNE-Python Integration ----------------------- - -For integration with MNE-Python: - -.. code-block:: bash - - uv add mne - AMICA ----- @@ -179,7 +147,7 @@ Or with specific extras: .. code-block:: bash - uv add "eegprep[torch,eeglabio,gui,docs]" + uv add "eegprep[torch,gui,docs]" Verification ============ @@ -268,18 +236,23 @@ EEGLAB File Format Issues **Problem**: Cannot read .set files -**Solution**: Install eeglabio: - -.. code-block:: bash - - uv add eeglabio +EEGPrep reads ``.set`` files directly with SciPy (``scipy.io.loadmat``) and +``h5py`` for newer MATLAB v7.3 files, and writes them with +``scipy.io.savemat``. These libraries are installed automatically, so a load +failure is almost always a path or format problem rather than a missing +dependency. -Then verify: +**Solution**: Check that the path points at an existing ``.set`` file and that +any companion ``.fdt`` data file sits next to it: .. code-block:: python + from pathlib import Path from eegprep import pop_loadset - # Should work without errors + + path = Path("sub-01_task-rest_eeg.set") + assert path.is_file(), f"no such file: {path}" + EEG = pop_loadset(str(path)) Memory Issues ------------- diff --git a/docs/source/user_guide/visual_parity.rst b/docs/source/user_guide/visual_parity.rst index c6245ec0..0db14b19 100644 --- a/docs/source/user_guide/visual_parity.rst +++ b/docs/source/user_guide/visual_parity.rst @@ -97,7 +97,7 @@ Capture only EEGPREP using a command supplied by the caller: uv run --no-sync python tools/visual_parity/capture.py \ --case file_menu \ --target eegprep \ - --eegprep-command "python -m eegprep.functions.guifunc.visual_capture --case {case_id} --output {output}" + --eegprep-command "python -m tools.visual_parity.visual_capture --case {case_id} --output {output}" Capture commands receive these environment variables: diff --git a/pyproject.toml b/pyproject.toml index 5fb8d58f..074ac85c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,9 +46,6 @@ dependencies = [ torch = [ "torch>=2.0" ] -eeglabio = [ - "eeglabio>=0.1.2" -] gui = [ "pyqtgraph>=0.13.7", "PySide6>=6.6", @@ -71,7 +68,6 @@ docs = [ ] all = [ "eegprep[torch]", - "eegprep[eeglabio]", "eegprep[gui]", "eegprep[console]", "eegprep[docs]", @@ -85,7 +81,12 @@ eegprep-validate-extension-catalog = "eegprep.extension_catalog:main" [dependency-groups] dev = [ + # Keep `uv run eegprep-console --full` working from a fresh source checkout. + # Published installs still keep GUI/console dependencies behind extras. + "ipython>=8.0", "pytest>=8.0", + "pyqtgraph>=0.13.7", + "PySide6>=6.6", "ruff>=0.15.14", "tomli>=2.0; python_version < '3.11'", "ty>=0.0.39", diff --git a/scripts/pop_reref_helper.py b/scripts/pop_reref_helper.py new file mode 100644 index 00000000..dc7ca012 --- /dev/null +++ b/scripts/pop_reref_helper.py @@ -0,0 +1,20 @@ +"""Helper script for re-referencing EEG data.""" + +import sys + +from eegprep.functions.popfunc.pop_loadset import pop_loadset +from eegprep.functions.popfunc.pop_reref import pop_reref +from eegprep.functions.popfunc.pop_saveset import pop_saveset + +if __name__ == "__main__": + # check if a parameter is present and if it is assign eeglab_file_path to it + if len(sys.argv) > 2: + eeglab_file_path_in = sys.argv[1] + eeglab_file_path_out = sys.argv[2] + else: + eeglab_file_path_in = './eeglab_data_with_ica_tmp.set' + eeglab_file_path_out = './eeglab_data_with_ica_tmp_averef.set' + + EEG = pop_loadset(eeglab_file_path_in) + EEG = pop_reref(EEG, []) + pop_saveset(EEG, eeglab_file_path_out) diff --git a/src/eegprep/cli/commands/bids.py b/src/eegprep/cli/commands/bids.py index 22053e51..32a1444f 100644 --- a/src/eegprep/cli/commands/bids.py +++ b/src/eegprep/cli/commands/bids.py @@ -93,7 +93,7 @@ def validate_dataset(root: str | Path) -> dict[str, Any]: if not (root_path / "dataset_description.json").exists(): warnings.append(_issue("BIDS_VALIDATION_WARNING", "dataset_description.json is missing", root_path)) if not files: - warnings.append(_issue("BIDS_VALIDATION_WARNING", "No supported EEG files were found", root_path)) + errors.append(_issue("BIDS_EEG_FILES_MISSING", "No supported EEG files were found", root_path)) status = "error" if errors else "warning" if warnings else "ok" return { "status": status, @@ -294,12 +294,11 @@ def _output_record(path: Path) -> dict[str, Any]: def _import_bids_file(eeg_file: Path) -> tuple[dict[str, Any] | list[dict[str, Any]], str, list[dict[str, str]]]: - try: - EEG, history = pop_importbids(eeg_file, return_com=True) - return EEG, history, [] - except IndexError as exc: - if "only integers" not in str(exc) or eeg_file.suffix.lower() != ".set": - raise + # The BIDS metadata importer (pop_importbids -> pop_load_frombids) only applies sidecar + # metadata to raw recording formats (.edf/.bdf/.vhdr). EEGLAB .set files carry their own + # metadata and cannot be routed through the sidecar-application path, so load them directly + # with the EEGLAB .set loader instead of attempting BIDS sidecar import. + if eeg_file.suffix.lower() == ".set": EEG = load_dataset(eeg_file) history = f"EEG = pop_importbids('{eeg_file.as_posix()}');" EEG["history"] = history @@ -310,13 +309,15 @@ def _import_bids_file(eeg_file: Path) -> tuple[dict[str, Any] | list[dict[str, A { "code": "BIDS_SIDECARS_SKIPPED", "message": ( - "BIDS sidecar application failed for this EEGLAB .set file; " - "data import was retried with the EEGLAB .set loader." + "EEGLAB .set files are imported with the EEGLAB .set loader; " + "BIDS sidecar metadata is not applied for this format." ), "path": str(eeg_file), } ], ) + EEG, history = pop_importbids(eeg_file, return_com=True) + return EEG, history, [] def _dataset_cli_summary(EEG: dict[str, Any], path: Path) -> dict[str, Any]: diff --git a/src/eegprep/cli/commands/eeglab.py b/src/eegprep/cli/commands/migrate.py similarity index 90% rename from src/eegprep/cli/commands/eeglab.py rename to src/eegprep/cli/commands/migrate.py index 51da4bf7..3cbd6e34 100644 --- a/src/eegprep/cli/commands/eeglab.py +++ b/src/eegprep/cli/commands/migrate.py @@ -1,4 +1,4 @@ -"""EEGLAB migration and compatibility commands for the EEGPrep CLI.""" +"""Migration and compatibility commands for the EEGPrep CLI.""" from __future__ import annotations @@ -33,22 +33,22 @@ def register(subparsers: argparse._SubParsersAction) -> argparse.ArgumentParser: - """Register ``eeglab`` compatibility commands.""" - parser = subparsers.add_parser("eeglab", help="Inspect EEGLAB history and migration compatibility.") - eeglab_sub = parser.add_subparsers(dest="eeglab_command", required=True) + """Register migration compatibility commands.""" + parser = subparsers.add_parser("migrate", help="Inspect old EEGLAB histories and migration compatibility.") + migrate_sub = parser.add_subparsers(dest="migrate_command", required=True) - history_parser = eeglab_sub.add_parser("history", help="Inspect mapped EEGLAB history operations.") + history_parser = migrate_sub.add_parser("history", help="Inspect mapped EEGLAB history operations.") history_parser.add_argument("input") history_parser.add_argument("--json", action="store_true") history_parser.set_defaults(handler=lambda args: history(args.input)) - compare_parser = eeglab_sub.add_parser("compare", help="Compare two EEGLAB .set datasets.") + compare_parser = migrate_sub.add_parser("compare", help="Compare two EEGLAB .set datasets.") compare_parser.add_argument("left") compare_parser.add_argument("right") compare_parser.add_argument("--json", action="store_true") compare_parser.set_defaults(handler=lambda args: compare(args.left, args.right)) - convert = eeglab_sub.add_parser("convert-script", help="Best-effort conversion of simple EEGLAB scripts to YAML.") + convert = migrate_sub.add_parser("convert-script", help="Best-effort conversion of simple EEGLAB scripts to YAML.") convert.add_argument("script") convert.add_argument("--to", choices=("eegprep-yaml",), default="eegprep-yaml") convert.add_argument("--output") @@ -95,7 +95,7 @@ def history(input_path: str | Path) -> dict[str, Any]: operations.append(record) return { "status": "ok", - "schema_version": "eegprep.eeglab.history.v1", + "schema_version": "eegprep.migrate.history.v1", "input": str(input_path), "history_detected": bool(operations), "operations": operations, @@ -148,7 +148,7 @@ def compare(left: str | Path, right: str | Path) -> dict[str, Any]: differences.append({"path": "data", "code": "DATA_VALUE_MISMATCH", "max_abs_diff": max_abs_diff}) return { "status": "ok", - "schema_version": "eegprep.eeglab.compare.v1", + "schema_version": "eegprep.migrate.compare.v1", "left": str(left), "right": str(right), "equivalent": not differences, @@ -207,7 +207,7 @@ def convert_script( output_path_value = str(destination) return { "status": "ok", - "schema_version": "eegprep.eeglab.convert_script.v1", + "schema_version": "eegprep.migrate.convert_script.v1", "source": str(source), "target": target, "output": output_path_value, @@ -219,10 +219,10 @@ def convert_script( def main(argv: list[str] | None = None) -> int: """Standalone module harness for tests and local debugging.""" - parser = argparse.ArgumentParser(prog="eegprep eeglab") + parser = argparse.ArgumentParser(prog="eegprep migrate") subparsers = parser.add_subparsers(dest="command", required=True) register(subparsers) - args = parser.parse_args(["eeglab", *(sys.argv[1:] if argv is None else argv)]) + args = parser.parse_args(["migrate", *(sys.argv[1:] if argv is None else argv)]) result = args.handler(args) print(json.dumps(json_safe(result), sort_keys=True)) return 0 if result.get("status") in {"ok", "warning"} else 1 diff --git a/src/eegprep/cli/commands/pipeline.py b/src/eegprep/cli/commands/pipeline.py index b850da33..7b22548b 100644 --- a/src/eegprep/cli/commands/pipeline.py +++ b/src/eegprep/cli/commands/pipeline.py @@ -4,7 +4,9 @@ import argparse import importlib.util +import logging import sys +from contextlib import redirect_stdout from pathlib import Path from typing import Any @@ -195,12 +197,14 @@ def handle_registered(args: argparse.Namespace) -> dict[str, Any]: if args.pipeline_action == "plan": return plan_pipeline_config(args.config) if args.pipeline_action == "run": - return run_pipeline_config( - args.config, - dry_run=args.dry_run, - manifest_path=args.manifest, - overwrite=True if args.overwrite else None, - ) + _configure_logging(args) + with redirect_stdout(sys.stderr): + return run_pipeline_config( + args.config, + dry_run=args.dry_run, + manifest_path=args.manifest, + overwrite=True if args.overwrite else None, + ) raise CommandError("COMMAND_NOT_IMPLEMENTED", f"Unknown pipeline action: {args.pipeline_action}") @@ -525,9 +529,17 @@ def _apply_filter(EEG: dict[str, Any], parameters: dict[str, Any]) -> tuple[dict notch = parameters.get("notch") if notch is not None: width = float(parameters.get("notch_width") or 2.0) + lower_edge = float(notch) - width / 2 + if lower_edge <= 0: + raise CommandError( + "CONFIG_SCHEMA_ERROR", + "notch minus half notch_width must be positive.", + path="steps[].notch", + suggestion="Increase notch or decrease notch_width so the notch stop band stays above 0 Hz.", + ) EEG, history = pop_eegfiltnew( EEG, - locutoff=float(notch) - width / 2, + locutoff=lower_edge, hicutoff=float(notch) + width / 2, revfilt=True, plotfreqz=False, @@ -729,5 +741,15 @@ def _warning(code: str, path: str, message: str) -> dict[str, str]: return {"code": code, "path": path, "message": message} +def _configure_logging(args: argparse.Namespace) -> None: + if getattr(args, "quiet", False) or getattr(args, "no_progress", False): + level = logging.WARNING + elif getattr(args, "verbose", False): + level = logging.DEBUG + else: + level = logging.INFO + logging.basicConfig(level=level, format="%(message)s", stream=sys.stderr, force=True) + + if __name__ == "__main__": raise SystemExit(main()) diff --git a/src/eegprep/cli/commands/transforms.py b/src/eegprep/cli/commands/transforms.py index 55a70874..fc3f1813 100644 --- a/src/eegprep/cli/commands/transforms.py +++ b/src/eegprep/cli/commands/transforms.py @@ -1,9 +1,8 @@ """Core EEG dataset transform commands for the headless CLI. -This module is intentionally dispatcher-neutral. The eventual top-level CLI can -call :func:`register_subcommands`; tests can use the module-level harness with -``python -m eegprep.cli.commands.transforms`` until the shared CLI foundation is -available. +This module is dispatcher-neutral: the top-level CLI mounts these commands via +:func:`register_subcommands`, and a standalone ``python -m eegprep.cli.commands.transforms`` +harness runs the same handlers for local testing. """ from __future__ import annotations @@ -80,7 +79,7 @@ def register_subcommands(subparsers: argparse._SubParsersAction, *, include_comm def build_parser() -> argparse.ArgumentParser: - """Build a standalone parser for this module's temporary harness.""" + """Build a standalone parser for this module's local test harness.""" parser = argparse.ArgumentParser(prog="python -m eegprep.cli.commands.transforms") subparsers = parser.add_subparsers(dest="transform_command", required=True) @@ -97,7 +96,7 @@ def run_transform_command(args: argparse.Namespace) -> dict[str, Any]: def main(argv: list[str] | None = None) -> int: - """Temporary module harness until the shared CLI dispatcher is available.""" + """Standalone module entry point for local transform testing.""" parser = build_parser() args = parser.parse_args(argv) diff --git a/src/eegprep/cli/core.py b/src/eegprep/cli/core.py index 6e0be4ba..e5378438 100644 --- a/src/eegprep/cli/core.py +++ b/src/eegprep/cli/core.py @@ -111,6 +111,7 @@ def command_error(command: str, error: EEGPrepCLIError) -> dict[str, Any]: "command": command, "code": error.code, "message": error.message, + "exit_code": error.exit_code, "error": payload, } @@ -122,7 +123,9 @@ def emit_command_result(result: dict[str, Any], *, json_output: bool = True) -> print(result.get("status", "ok")) if result.get("status") == "error": print(result.get("error", {}).get("message", ""), file=sys.stderr) - return 0 if result.get("status") == "ok" else 1 + if result.get("status") == "ok": + return 0 + return int(result.get("exit_code", 1) or 1) def print_result(result: dict[str, Any], *, as_json: bool) -> None: @@ -288,44 +291,3 @@ def _input_file_record(path: Path | dict[str, Any]) -> dict[str, Any]: if isinstance(path, dict): return json_safe(path) return {"path": str(path), "sha256": sha256_file(path)} - - -def run_main(handler: Any, args: Any) -> int: - try: - result = handler(args) - print_result(result, as_json=bool(getattr(args, "json", False))) - return 1 if result.get("status") == "error" else 0 - except EEGPrepCLIError as exc: - print_result(exc.to_response(), as_json=bool(getattr(args, "json", False))) - return exc.exit_code - except Exception as exc: - code = getattr(exc, "code", None) - message = getattr(exc, "message", None) - if code and message: - payload = { - "status": "error", - "schema_version": "eegprep.error.v1", - "code": code, - "message": message, - } - path = getattr(exc, "path", None) - suggestion = getattr(exc, "suggestion", None) - if path is not None: - payload["path"] = str(path) - if suggestion is not None: - payload["suggestion"] = suggestion - print_result(payload, as_json=bool(getattr(args, "json", False))) - return int(getattr(exc, "exit_code", 1) or 1) - error = EEGPrepCLIError( - "UNEXPECTED_ERROR", - str(exc), - suggestion="Rerun with --verbose or file an issue if this is reproducible.", - ) - print_result(error.to_response(), as_json=bool(getattr(args, "json", False))) - if bool(getattr(args, "verbose", False)): - raise - return 1 - - -def eprint(message: str) -> None: - print(message, file=sys.stderr) diff --git a/src/eegprep/cli/discovery.py b/src/eegprep/cli/discovery.py index 25b49268..174eb49e 100644 --- a/src/eegprep/cli/discovery.py +++ b/src/eegprep/cli/discovery.py @@ -70,7 +70,7 @@ def capabilities() -> dict[str, Any]: "supports_json": True, "supports_dry_run": False, }, - "eeglab": { + "migrate": { "description": "Inspect EEGLAB history, compare datasets, and convert simple MATLAB histories.", "inputs": ["eeglab_set", "matlab_script"], "outputs": ["json", "eegprep_pipeline_yaml"], @@ -96,12 +96,79 @@ def command_schema(command: str) -> dict[str, Any]: "properties": {"kind": {"enum": ["dataset", "events", "channels", "ica"]}}, }, "pipeline": pipeline_schema()["schema"], - "resample": _command_schema("resample", required=["input", "freq", "output"]), - "rereference": _command_schema("rereference", required=["input", "method", "output"]), - "filter": _command_schema("filter", required=["input", "output"]), - "clean": _command_schema("clean", required=["input", "method", "output"]), - "epoch": _command_schema("epoch", required=["input", "event_type", "tmin", "tmax", "output"]), - "ica": _command_schema("ica", required=["input", "method", "output"]), + "resample": _command_schema( + "resample", + required=["input", "freq"], + properties={ + "freq": {"type": "number"}, + "engine": {"enum": ["poly", "scipy"], "default": "poly"}, + }, + ), + "rereference": _command_schema( + "rereference", + required=["input", "method"], + properties={ + "method": {"enum": ["average", "channels"], "default": "average"}, + "channels": {"type": "array", "items": {"type": "string"}}, + "exclude": {"type": "array", "items": {"type": "string"}}, + "keep_ref": {"type": "boolean", "default": False}, + "huber": {"type": "number"}, + "refica": {"enum": ["on", "off", "backwardcomp", "remove"], "default": "on"}, + }, + ), + "filter": _command_schema( + "filter", + required=["input"], + properties={ + "highpass": {"type": "number"}, + "lowpass": {"type": "number"}, + "notch": {"type": "number"}, + "notch_width": {"type": "number", "default": 2.0}, + "order": {"type": "integer"}, + "minphase": {"type": "boolean", "default": False}, + "usefftfilt": {"type": "boolean", "default": False}, + }, + ), + "clean": _command_schema( + "clean", + required=["input", "method"], + properties={ + "method": {"enum": ["asr", "rawdata"], "default": "asr"}, + "burst_criterion": {"type": "number", "default": 20.0}, + "burst_rejection": {"type": "boolean", "default": False}, + "distance": {"enum": ["euclidean", "riemannian"], "default": "euclidean"}, + "flatline_criterion": {"type": "number"}, + "channel_criterion": {"type": "number"}, + "line_noise_criterion": {"type": "number"}, + "window_criterion": {"type": "number"}, + "highpass": {"type": "array", "items": {"type": "number"}}, + }, + ), + "epoch": _command_schema( + "epoch", + required=["input", "event_type", "tmin", "tmax"], + properties={ + "event_type": {"type": "array", "items": {"type": "string"}}, + "tmin": {"type": "number"}, + "tmax": {"type": "number"}, + "new_name": {"type": "string"}, + }, + ), + "ica": _command_schema( + "ica", + required=["input", "method"], + properties={ + "method": {"enum": ["runica", "picard", "amica", "runamica15"], "default": "runica"}, + "seed": {"type": "integer"}, + "deterministic": {"type": "boolean", "default": True}, + "maxsteps": {"type": "integer"}, + "pca": {"type": "integer"}, + "extended": {"type": "integer"}, + "channels": {"type": "array", "items": {"type": "string"}}, + "option": {"type": "array", "items": {"type": "string"}}, + "reorder": {"type": "boolean", "default": True}, + }, + ), "batch": { "schema_version": "eegprep.schema.command.batch.v1", "syntax": "eegprep batch run --pipeline --output-dir --json", @@ -149,9 +216,9 @@ def command_schema(command: str) -> dict[str, Any]: "task": {"type": "string"}, }, }, - "eeglab": { - "schema_version": "eegprep.schema.command.eeglab.v1", - "syntax": "eegprep eeglab ... --json", + "migrate": { + "schema_version": "eegprep.schema.command.migrate.v1", + "syntax": "eegprep migrate ... --json", "required": ["subcommand"], "properties": { "left": {"type": "string"}, @@ -238,9 +305,9 @@ def examples(name: str) -> dict[str, Any]: "eegprep bids validate bids_root --json", "eegprep bids export input.set --bids-root bids_out --subject 01 --task rest --json", ], - "eeglab": [ - "eegprep eeglab history sample_data/eeglab_data.set --json", - "eegprep eeglab compare left.set right.set --json", + "migrate": [ + "eegprep migrate history sample_data/eeglab_data.set --json", + "eegprep migrate compare left.set right.set --json", ], "skills": ["eegprep skills list --json", "eegprep skills get eegprep-cli"], } @@ -291,20 +358,22 @@ def _write_capability(description: str) -> dict[str, Any]: "outputs": ["eeglab_set", "manifest"], "supports_json": True, "supports_dry_run": False, - "requires_output": True, + "requires_output_or_overwrite": True, } -def _command_schema(command: str, *, required: list[str]) -> dict[str, Any]: +def _command_schema(command: str, *, required: list[str], properties: dict[str, Any]) -> dict[str, Any]: return { "schema_version": f"eegprep.schema.command.{command}.v1", "type": "object", "required": required, + "anyOf": [{"required": ["output"]}, {"required": ["overwrite"]}], "properties": { "input": {"type": "string"}, "output": {"type": "string"}, "manifest": {"type": "string"}, "overwrite": {"type": "boolean", "default": False}, "json": {"type": "boolean", "default": False}, + **properties, }, } diff --git a/src/eegprep/cli/main.py b/src/eegprep/cli/main.py index 265af365..f8cf8e12 100644 --- a/src/eegprep/cli/main.py +++ b/src/eegprep/cli/main.py @@ -13,7 +13,7 @@ from eegprep.cli.dataset import inspect_channels, inspect_dataset, inspect_events, inspect_ica, validate_dataset from eegprep.cli.commands import batch as batch_commands from eegprep.cli.commands import bids as bids_commands -from eegprep.cli.commands import eeglab as eeglab_commands +from eegprep.cli.commands import migrate as migrate_commands from eegprep.cli.commands import pipeline as pipeline_commands from eegprep.cli.commands import qc as qc_commands from eegprep.cli.commands import report as report_commands @@ -126,7 +126,7 @@ def build_parser() -> EEGPrepArgumentParser: report_commands.register(subparsers) batch_commands.register(subparsers) bids_commands.register(subparsers) - eeglab_commands.register(subparsers) + migrate_commands.register(subparsers) parser.set_defaults(handler=_handle_root) return parser @@ -228,7 +228,11 @@ def _handle_skill_path(args: argparse.Namespace) -> dict[str, Any]: def _json_requested(args: argparse.Namespace) -> bool: if bool(getattr(args, "json", False)): return True - return "--json" in (getattr(args, "qc_args", []) or []) + # Commands that consume their flags via argparse.REMAINDER (e.g. ``qc``) never bind a + # top-level ``--json`` on ``args``. The root parser already records whether ``--json`` + # appeared anywhere in argv, so consult that command-agnostic flag instead of + # introspecting any single subcommand's argument attribute. + return bool(EEGPrepArgumentParser.json_requested) if __name__ == "__main__": diff --git a/src/eegprep/extension_catalog.py b/src/eegprep/extension_catalog.py index 89dcbc25..770aa816 100644 --- a/src/eegprep/extension_catalog.py +++ b/src/eegprep/extension_catalog.py @@ -1,34 +1,25 @@ -"""Catalog metadata validation for EEGPrep extension submissions.""" +"""Runtime Extension Manager catalog loading for EEGPrep. + +This module loads the metadata-only catalog that the Extension Manager dialog and +console inventory display, and builds copyable (never executed) install/update +commands. The submission-curation CI validator lives in +``eegprep.extension_catalog_validation``; its public names are re-exported here so +the ``eegprep-validate-extension-catalog`` entry point and existing imports keep +working unchanged. +""" from __future__ import annotations -import argparse import json import os -import re import shlex -import sys -from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from enum import Enum -from importlib import metadata, resources +from importlib import resources from pathlib import Path from typing import Any from urllib.parse import urlparse -from eegprep.extensions import ( - EXTENSION_API_VERSION, - EXTENSION_ENTRY_POINT_GROUP, - EXTENSION_NAMING_PREFIX, - ExtensionDependency, - ExtensionRegistry, - ExtensionSpec, - ExtensionStatus, - check_extension_compatibility, - extension_version_satisfies, - extension_version_spec_is_valid, -) - CATALOG_SCHEMA_VERSION = 1 CATALOG_KIND_MANAGER = "extension_manager" CATALOG_KIND_CURATION = "extension_curation" @@ -312,6 +303,11 @@ def _looks_like_archive(url: str) -> bool: return bool(lowered) and lowered.endswith(_ARCHIVE_SUFFIXES) +def _is_web_url(value: str) -> bool: + parsed = urlparse(value) + return parsed.scheme in {"https", "http"} and bool(parsed.netloc) + + def _text(value: Any) -> str: return str(value).strip() if value is not None else "" @@ -330,685 +326,23 @@ def _catalog_normalize_name(name: str) -> str: return str(name).strip().lower() -CATALOG_CURATION_STATUSES = ("submitted", "curated", "private", "internal") -CATALOG_REQUIRED_FIELDS = ( - "id", - "package_name", - "entry_point", - "extension_name", - "version", - "api_version", - "eegprep_requires", - "python_requires", - "license", - "maintainer", - "docs_url", - "source_url", - "description", - "curation", +# The submission-curation CI validator lives in extension_catalog_validation. Its +# public names are re-exported here so the eegprep-validate-extension-catalog entry +# point and existing imports of these symbols from eegprep.extension_catalog keep +# working. The validator imports the shared catalog constants and _is_web_url from +# this module lazily, so this re-import stays one-directional at module-load time. +from eegprep.extension_catalog_validation import ( # noqa: E402 + CATALOG_CURATION_STATUSES, + CATALOG_REQUIRED_FIELDS, + CatalogValidationIssue, + CatalogValidationOptions, + CatalogValidationReport, + load_catalog_entries, + main, + validate_catalog_entries, + validate_catalog_file, ) -_NAME_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_.-]*$") -_PACKAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*$") - - -@dataclass(frozen=True) -class CatalogValidationIssue: - """One catalog metadata validation issue.""" - - message: str - entry_id: str = "" - field: str = "" - - def format(self) -> str: - """Return a concise human-readable issue line.""" - location = self.entry_id or "" - if self.field: - location = f"{location}.{self.field}" - return f"{location}: {self.message}" - - -@dataclass(frozen=True) -class CatalogValidationReport: - """Catalog validation result with blocking errors and non-blocking warnings.""" - - errors: tuple[CatalogValidationIssue, ...] = field(default_factory=tuple) - warnings: tuple[CatalogValidationIssue, ...] = field(default_factory=tuple) - - @property - def ok(self) -> bool: - """Return whether the catalog metadata has no blocking errors.""" - return not self.errors - - def format(self) -> str: - """Return a readable validation summary.""" - lines: list[str] = [] - if self.errors: - lines.append("Errors:") - lines.extend(f"- {issue.format()}" for issue in self.errors) - if self.warnings: - lines.append("Warnings:") - lines.extend(f"- {issue.format()}" for issue in self.warnings) - if not lines: - return "Catalog metadata is valid." - return "\n".join(lines) - - -@dataclass(frozen=True) -class CatalogValidationOptions: - """Validation switches for local and future catalog-CI checks.""" - - allow_private: bool = False - check_installed: bool = False - check_import: bool = False - current_eegprep_version: str | None = None - version_provider: Any = metadata.version - entry_points_provider: Any = metadata.entry_points - - -def load_catalog_entries(path: str | Path) -> tuple[dict[str, Any], ...]: - """Load catalog entries from a JSON file or a directory of JSON files.""" - catalog_path = Path(path) - if catalog_path.is_dir(): - files = _catalog_files(catalog_path) - entries: list[dict[str, Any]] = [] - for file_path in files: - entries.extend(load_catalog_entries(file_path)) - return tuple(entries) - - with catalog_path.open(encoding="utf-8") as handle: - payload = json.load(handle) - return _entries_from_payload(payload, catalog_path) - - -def validate_catalog_file( - path: str | Path, - *, - options: CatalogValidationOptions | None = None, -) -> CatalogValidationReport: - """Validate catalog metadata loaded from ``path``.""" - try: - entries = load_catalog_entries(path) - except (OSError, ValueError, json.JSONDecodeError) as exc: - return CatalogValidationReport(errors=(CatalogValidationIssue(str(exc)),)) - return validate_catalog_entries(entries, options=options) - - -def validate_catalog_entries( - entries: Iterable[Mapping[str, Any]], - *, - options: CatalogValidationOptions | None = None, -) -> CatalogValidationReport: - """Validate catalog metadata entries. - - Static validation does not require an installed extension package. Set - ``check_installed`` to verify distribution metadata and entry points, and - ``check_import`` to import the declared entry point through the registry. - """ - opts = options or CatalogValidationOptions() - errors: list[CatalogValidationIssue] = [] - warnings: list[CatalogValidationIssue] = [] - normalized_entries: list[dict[str, Any]] = [] - for index, entry in enumerate(entries): - if not isinstance(entry, Mapping): - errors.append(CatalogValidationIssue("Catalog entries must be mapping objects", f"entry[{index}]")) - continue - normalized_entries.append(dict(entry)) - - _validate_static_entries(normalized_entries, opts, errors, warnings) - _validate_catalog_conflicts(normalized_entries, errors) - if opts.check_installed or opts.check_import: - _validate_installed_entries(normalized_entries, opts, errors) - - return CatalogValidationReport(errors=tuple(errors), warnings=tuple(warnings)) - - -def _catalog_files(path: Path) -> tuple[Path, ...]: - catalog_json = path / "catalog.json" - if catalog_json.is_file(): - return (catalog_json,) - return tuple(sorted(candidate for candidate in path.rglob("*.json") if candidate.is_file())) - - -def _entries_from_payload(payload: Any, path: Path) -> tuple[dict[str, Any], ...]: - if isinstance(payload, list): - entries = payload - elif isinstance(payload, dict) and "extensions" in payload: - _validate_schema_version(payload, path) - entries = payload["extensions"] - elif isinstance(payload, dict) and "entries" in payload: - _validate_schema_version(payload, path) - entries = payload["entries"] - elif isinstance(payload, dict) and "id" in payload: - entries = [payload] - else: - raise ValueError(f"{path} must contain an extension entry or an 'extensions' list") - - if not isinstance(entries, list): - raise ValueError(f"{path} extensions payload must be a list") - if not all(isinstance(entry, dict) for entry in entries): - raise ValueError(f"{path} extension entries must be JSON objects") - return tuple(dict(entry) for entry in entries) - - -def _validate_schema_version(payload: Mapping[str, Any], path: Path) -> None: - catalog_kind = payload.get("catalog_kind", CATALOG_KIND_CURATION) - if catalog_kind != CATALOG_KIND_CURATION: - raise ValueError(f"{path} catalog_kind must be {CATALOG_KIND_CURATION!r}; got {catalog_kind!r}") - schema_version = payload.get("schema_version", CATALOG_SCHEMA_VERSION) - if schema_version != CATALOG_SCHEMA_VERSION: - raise ValueError(f"{path} schema_version must be {CATALOG_SCHEMA_VERSION}; got {schema_version!r}") - - -def _validate_static_entries( - entries: list[dict[str, Any]], - options: CatalogValidationOptions, - errors: list[CatalogValidationIssue], - warnings: list[CatalogValidationIssue], -) -> None: - for entry in entries: - entry_id = _entry_id(entry) - _validate_required_fields(entry, entry_id, errors) - _validate_names(entry, entry_id, errors, warnings) - _validate_urls(entry, entry_id, errors) - _validate_text_metadata(entry, entry_id, errors) - _validate_curation(entry, entry_id, options, errors, warnings) - _validate_version_policy(entry, entry_id, options, errors) - _validate_dependency_metadata(entry, entry_id, options, errors) - - -def _validate_required_fields(entry: Mapping[str, Any], entry_id: str, errors: list[CatalogValidationIssue]) -> None: - for field_name in CATALOG_REQUIRED_FIELDS: - if field_name not in entry: - errors.append(CatalogValidationIssue("Required catalog metadata is missing", entry_id, field_name)) - continue - value = entry[field_name] - if value in ("", None, [], {}): - errors.append(CatalogValidationIssue("Required catalog metadata must not be empty", entry_id, field_name)) - - -def _validate_names( - entry: Mapping[str, Any], - entry_id: str, - errors: list[CatalogValidationIssue], - warnings: list[CatalogValidationIssue], -) -> None: - for field_name in ("id", "entry_point", "extension_name"): - value = entry.get(field_name) - if isinstance(value, str) and _NAME_RE.match(value): - continue - errors.append( - CatalogValidationIssue( - "Must start with a letter and contain only letters, numbers, '.', '_', or '-'", - entry_id, - field_name, - ) - ) - - package_name = entry.get("package_name") - if not isinstance(package_name, str) or not _PACKAGE_RE.match(package_name): - errors.append( - CatalogValidationIssue( - "Must contain only letters, numbers, '.', '_', or '-' and start with a letter or number", - entry_id, - "package_name", - ) - ) - elif not package_name.startswith(EXTENSION_NAMING_PREFIX): - warnings.append( - CatalogValidationIssue( - f"Recommended package names start with {EXTENSION_NAMING_PREFIX!r}; discovery still uses entry points", - entry_id, - "package_name", - ) - ) - - -def _validate_urls(entry: Mapping[str, Any], entry_id: str, errors: list[CatalogValidationIssue]) -> None: - for field_name in ("docs_url", "source_url"): - value = entry.get(field_name) - if isinstance(value, str) and _is_web_url(value): - continue - errors.append(CatalogValidationIssue("Must be an https:// or http:// URL", entry_id, field_name)) - - -def _validate_text_metadata(entry: Mapping[str, Any], entry_id: str, errors: list[CatalogValidationIssue]) -> None: - license_value = entry.get("license") - if isinstance(license_value, str) and license_value.strip().lower() in {"unknown", "none", "n/a"}: - errors.append(CatalogValidationIssue("License must identify the extension license", entry_id, "license")) - - maintainer = entry.get("maintainer") - if isinstance(maintainer, dict): - maintainer_name = maintainer.get("name") - maintainer_contact = maintainer.get("email") or maintainer.get("url") - if not maintainer_name or not maintainer_contact: - errors.append( - CatalogValidationIssue( - "Maintainer metadata must include a name and email or URL", - entry_id, - "maintainer", - ) - ) - elif not isinstance(maintainer, str) or not maintainer.strip(): - errors.append(CatalogValidationIssue("Maintainer metadata must be a string or object", entry_id, "maintainer")) - - description = entry.get("description") - if isinstance(description, str) and len(description.strip()) < 20: - errors.append(CatalogValidationIssue("Description must be at least 20 characters", entry_id, "description")) - - -def _validate_curation( - entry: Mapping[str, Any], - entry_id: str, - options: CatalogValidationOptions, - errors: list[CatalogValidationIssue], - warnings: list[CatalogValidationIssue], -) -> None: - curation = entry.get("curation") - if not isinstance(curation, dict): - errors.append(CatalogValidationIssue("Curation metadata must be an object", entry_id, "curation")) - return - - status = str(curation.get("status") or "").strip().lower() - if status not in CATALOG_CURATION_STATUSES: - errors.append( - CatalogValidationIssue( - f"Curation status must be one of {', '.join(CATALOG_CURATION_STATUSES)}", - entry_id, - "curation.status", - ) - ) - return - - private = bool(entry.get("private") or entry.get("internal") or status in {"private", "internal"}) - if private and not options.allow_private: - errors.append( - CatalogValidationIssue( - "Private/internal extensions are supported by direct installation, not the public curated catalog", - entry_id, - "curation.status", - ) - ) - elif private: - warnings.append( - CatalogValidationIssue( - "Private/internal extension metadata is valid locally but does not carry curated status", - entry_id, - "curation.status", - ) - ) - - if status == "curated": - if not curation.get("reviewed_by"): - errors.append(CatalogValidationIssue("Curated entries must record reviewed_by", entry_id, "curation")) - if not curation.get("reviewed_at"): - errors.append(CatalogValidationIssue("Curated entries must record reviewed_at", entry_id, "curation")) - - -def _validate_version_policy( - entry: Mapping[str, Any], - entry_id: str, - options: CatalogValidationOptions, - errors: list[CatalogValidationIssue], -) -> None: - api_version = str(entry.get("api_version") or "") - if _major_version(api_version) != _major_version(EXTENSION_API_VERSION): - errors.append( - CatalogValidationIssue( - f"Extension API version {api_version!r} is not supported by this EEGPrep extension API", - entry_id, - "api_version", - ) - ) - - eegprep_requires = str(entry.get("eegprep_requires") or "") - eegprep_requires_valid = True - for field_name in ("eegprep_requires", "python_requires"): - value = entry.get(field_name) - if not isinstance(value, str) or not extension_version_spec_is_valid(value): - errors.append(CatalogValidationIssue("Must be a simple version specifier", entry_id, field_name)) - if field_name == "eegprep_requires": - eegprep_requires_valid = False - - if eegprep_requires and eegprep_requires_valid: - spec = ExtensionSpec( - name=str(entry.get("extension_name") or entry.get("id") or "catalog_entry"), - version=str(entry.get("version") or ""), - api_version=EXTENSION_API_VERSION, - package_name=str(entry.get("package_name") or ""), - eegprep_requires=eegprep_requires, - dependencies=tuple(_catalog_dependencies(entry)), - ) - compatibility = check_extension_compatibility( - spec, - current_version=options.current_eegprep_version, - version_provider=options.version_provider, - check_dependencies=False, - ) - for message in compatibility.incompatible: - errors.append(CatalogValidationIssue(message, entry_id, "eegprep_requires")) - - python_requires = str(entry.get("python_requires") or "") - python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" - if python_requires and extension_version_spec_is_valid(python_requires): - if not extension_version_satisfies(python_version, python_requires): - errors.append( - CatalogValidationIssue( - f"Extension requires Python {python_requires}; current version is {python_version}", - entry_id, - "python_requires", - ) - ) - - -def _validate_dependency_metadata( - entry: Mapping[str, Any], - entry_id: str, - options: CatalogValidationOptions, - errors: list[CatalogValidationIssue], -) -> None: - if "dependencies" not in entry: - return - dependencies = entry.get("dependencies", ()) - if dependencies in ("", None): - return - if not isinstance(dependencies, list): - errors.append(CatalogValidationIssue("Dependencies must be a list", entry_id, "dependencies")) - return - for index, dependency in enumerate(dependencies): - field = f"dependencies[{index}]" - if not isinstance(dependency, dict): - errors.append(CatalogValidationIssue("Dependency entries must be objects", entry_id, field)) - continue - package = dependency.get("package") - if not isinstance(package, str) or not _PACKAGE_RE.match(package): - errors.append(CatalogValidationIssue("Dependency package name is invalid", entry_id, f"{field}.package")) - version_spec = dependency.get("version_spec", "") - if version_spec and (not isinstance(version_spec, str) or not extension_version_spec_is_valid(version_spec)): - errors.append(CatalogValidationIssue("Dependency version specifier is invalid", entry_id, field)) - if not options.check_installed: - continue - try: - installed_version = options.version_provider(str(package)) - except metadata.PackageNotFoundError: - if not dependency.get("optional", False): - errors.append( - CatalogValidationIssue(f"Required dependency {package!r} is not installed", entry_id, field) - ) - continue - if version_spec and not extension_version_satisfies(installed_version, str(version_spec)): - errors.append( - CatalogValidationIssue( - f"Dependency {package!r} requires {version_spec}; installed version is {installed_version}", - entry_id, - field, - ) - ) - - -def _validate_catalog_conflicts(entries: list[dict[str, Any]], errors: list[CatalogValidationIssue]) -> None: - seen: dict[tuple[str, str], str] = {} - for entry in entries: - entry_id = _entry_id(entry) - for field_name in ("id", "extension_name", "display_name"): - value = entry.get(field_name) - if not isinstance(value, str) or not value: - continue - key = (field_name, _normalize(value)) - owner = seen.get(key) - if owner is not None: - errors.append( - CatalogValidationIssue( - f"Conflicts with catalog entry {owner!r}", - entry_id, - field_name, - ) - ) - seen[key] = entry_id - - package_name = entry.get("package_name") - entry_point = entry.get("entry_point") - if isinstance(package_name, str) and isinstance(entry_point, str): - key = ("package-entry-point", f"{_normalize(package_name)}:{_normalize(entry_point)}") - owner = seen.get(key) - if owner is not None: - errors.append( - CatalogValidationIssue( - f"Conflicts with catalog entry {owner!r}", - entry_id, - "entry_point", - ) - ) - seen[key] = entry_id - - -def _validate_installed_entries( - entries: list[dict[str, Any]], - options: CatalogValidationOptions, - errors: list[CatalogValidationIssue], -) -> None: - selected_entry_points = tuple(_select_entry_points(options.entry_points_provider)) - for entry in entries: - entry_id = _entry_id(entry) - package_name = str(entry.get("package_name") or "") - entry_point_name = str(entry.get("entry_point") or "") - if not package_name or not entry_point_name: - continue - - try: - installed_version = options.version_provider(package_name) - except metadata.PackageNotFoundError: - errors.append( - CatalogValidationIssue(f"Package {package_name!r} is not installed", entry_id, "package_name") - ) - continue - - expected_version = str(entry.get("version") or "") - if expected_version and installed_version != expected_version: - errors.append( - CatalogValidationIssue( - f"Catalog version {expected_version} does not match installed package version {installed_version}", - entry_id, - "version", - ) - ) - - entry_point = _matching_entry_point(selected_entry_points, package_name, entry_point_name) - if entry_point is None: - errors.append( - CatalogValidationIssue( - f"Package {package_name!r} does not expose {entry_point_name!r} in {EXTENSION_ENTRY_POINT_GROUP}", - entry_id, - "entry_point", - ) - ) - continue - - if options.check_import: - _validate_imported_entry(entry, entry_point, options, errors) - - -def _validate_imported_entry( - entry: Mapping[str, Any], - entry_point: Any, - options: CatalogValidationOptions, - errors: list[CatalogValidationIssue], -) -> None: - entry_id = _entry_id(entry) - - def provider(*, group: str) -> tuple[Any, ...]: - if group == EXTENSION_ENTRY_POINT_GROUP: - return (entry_point,) - return () - - registry = ExtensionRegistry( - include_bundled=False, - entry_points_provider=provider, - current_version=options.current_eegprep_version, - version_provider=options.version_provider, - ) - record = registry.discover()[0] - if record.status in {ExtensionStatus.FAILED_IMPORT, ExtensionStatus.INVALID_SPEC, ExtensionStatus.INCOMPATIBLE}: - for message in record.errors: - errors.append(CatalogValidationIssue(message, entry_id, "entry_point")) - return - if record.status == ExtensionStatus.MISSING_DEPENDENCY: - for message in record.errors: - errors.append(CatalogValidationIssue(message, entry_id, "dependencies")) - return - if record.spec is None: - errors.append(CatalogValidationIssue("Entry point did not return an extension spec", entry_id, "entry_point")) - return - - _validate_imported_spec_matches_catalog(entry, record.spec, errors) - - -def _validate_imported_spec_matches_catalog( - entry: Mapping[str, Any], - spec: ExtensionSpec, - errors: list[CatalogValidationIssue], -) -> None: - entry_id = _entry_id(entry) - checks = ( - ("extension_name", spec.name), - ("version", spec.version), - ("api_version", spec.api_version), - ) - for field_name, actual in checks: - expected = str(entry.get(field_name) or "") - if expected and expected != actual: - errors.append( - CatalogValidationIssue( - f"Catalog value {expected!r} does not match imported spec value {actual!r}", - entry_id, - field_name, - ) - ) - - -def _catalog_dependencies(entry: Mapping[str, Any]) -> tuple[ExtensionDependency, ...]: - dependencies = entry.get("dependencies", ()) - if not isinstance(dependencies, list): - return () - parsed: list[ExtensionDependency] = [] - for dependency in dependencies: - if not isinstance(dependency, dict): - continue - package = dependency.get("package") - if not package: - continue - parsed.append( - ExtensionDependency( - package=str(package), - version_spec=str(dependency.get("version_spec") or ""), - optional=bool(dependency.get("optional", False)), - ) - ) - return tuple(parsed) - - -def _select_entry_points(provider: Any) -> tuple[Any, ...]: - try: - selected = provider(group=EXTENSION_ENTRY_POINT_GROUP) - except TypeError: - entry_points = provider() - if hasattr(entry_points, "select"): - selected = entry_points.select(group=EXTENSION_ENTRY_POINT_GROUP) - else: - selected = [ - entry_point - for entry_point in entry_points - if getattr(entry_point, "group", None) == EXTENSION_ENTRY_POINT_GROUP - ] - return tuple(selected or ()) - - -def _matching_entry_point(entry_points: tuple[Any, ...], package_name: str, entry_point_name: str) -> Any | None: - normalized_package = _normalize(package_name) - for entry_point in entry_points: - if getattr(entry_point, "name", None) != entry_point_name: - continue - entry_point_package = _entry_point_package_name(entry_point) - if entry_point_package and _normalize(entry_point_package) != normalized_package: - continue - return entry_point - return None - - -def _entry_point_package_name(entry_point: Any) -> str: - dist = getattr(entry_point, "dist", None) - dist_metadata = getattr(dist, "metadata", None) - if dist_metadata is None: - return "" - try: - return str(dist_metadata.get("Name") or "") - except AttributeError: - return "" - - -def _is_web_url(value: str) -> bool: - parsed = urlparse(value) - return parsed.scheme in {"https", "http"} and bool(parsed.netloc) - - -def _entry_id(entry: Mapping[str, Any]) -> str: - value = entry.get("id") - if isinstance(value, str) and value: - return value - package_name = entry.get("package_name") - if isinstance(package_name, str) and package_name: - return package_name - return "" - - -def _major_version(version: str) -> int: - token = str(version).strip().split(".", 1)[0] - return int(token) if token.isdigit() else -1 - - -def _normalize(value: str) -> str: - return value.strip().lower().replace("_", "-") - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Validate EEGPrep extension catalog metadata.") - parser.add_argument("path", help="Catalog JSON file, entry JSON file, or directory of JSON files") - parser.add_argument( - "--allow-private", action="store_true", help="Allow private/internal metadata for local catalogs" - ) - parser.add_argument( - "--check-installed", action="store_true", help="Verify installed package metadata and entry points" - ) - parser.add_argument("--check-import", action="store_true", help="Import declared entry points through the registry") - parser.add_argument("--json", action="store_true", help="Emit JSON validation output") - return parser - - -def main(argv: list[str] | None = None) -> int: - """Run the catalog validator command-line interface.""" - args = _build_arg_parser().parse_args(argv) - options = CatalogValidationOptions( - allow_private=args.allow_private, - check_installed=args.check_installed or args.check_import, - check_import=args.check_import, - ) - report = validate_catalog_file(args.path, options=options) - if args.json: - print( - json.dumps( - { - "ok": report.ok, - "errors": [issue.format() for issue in report.errors], - "warnings": [issue.format() for issue in report.warnings], - }, - indent=2, - ) - ) - else: - print(report.format()) - return 0 if report.ok else 1 - - __all__ = [ "CATALOG_CURATION_STATUSES", "CATALOG_KIND_CURATION", @@ -1033,7 +367,3 @@ def main(argv: list[str] | None = None) -> int: "validate_catalog_entries", "validate_catalog_file", ] - - -if __name__ == "__main__": # pragma: no cover - raise SystemExit(main()) diff --git a/src/eegprep/extension_catalog_validation.py b/src/eegprep/extension_catalog_validation.py new file mode 100644 index 00000000..331343bc --- /dev/null +++ b/src/eegprep/extension_catalog_validation.py @@ -0,0 +1,699 @@ +"""Submission-curation validator for EEGPrep extension catalog metadata. + +This module powers the ``eegprep-validate-extension-catalog`` curation CI command. +It is separate from the runtime Extension Manager catalog loader in +``eegprep.extension_catalog``; the two share only a handful of catalog +constants and the ``_is_web_url`` helper. +""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field +from importlib import metadata +from pathlib import Path +from typing import Any + +from eegprep.extensions import ( + EXTENSION_API_VERSION, + EXTENSION_ENTRY_POINT_GROUP, + EXTENSION_NAMING_PREFIX, + ExtensionDependency, + ExtensionRegistry, + ExtensionSpec, + ExtensionStatus, + _entry_point_package_name, + _major_version, + _select_entry_points, + check_extension_compatibility, + extension_version_satisfies, + extension_version_spec_is_valid, +) + +CATALOG_CURATION_STATUSES = ("submitted", "curated", "private", "internal") +CATALOG_REQUIRED_FIELDS = ( + "id", + "package_name", + "entry_point", + "extension_name", + "version", + "api_version", + "eegprep_requires", + "python_requires", + "license", + "maintainer", + "docs_url", + "source_url", + "description", + "curation", +) + +_NAME_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_.-]*$") +_PACKAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]*$") + + +@dataclass(frozen=True) +class CatalogValidationIssue: + """One catalog metadata validation issue.""" + + message: str + entry_id: str = "" + field: str = "" + + def format(self) -> str: + """Return a concise human-readable issue line.""" + location = self.entry_id or "" + if self.field: + location = f"{location}.{self.field}" + return f"{location}: {self.message}" + + +@dataclass(frozen=True) +class CatalogValidationReport: + """Catalog validation result with blocking errors and non-blocking warnings.""" + + errors: tuple[CatalogValidationIssue, ...] = field(default_factory=tuple) + warnings: tuple[CatalogValidationIssue, ...] = field(default_factory=tuple) + + @property + def ok(self) -> bool: + """Return whether the catalog metadata has no blocking errors.""" + return not self.errors + + def format(self) -> str: + """Return a readable validation summary.""" + lines: list[str] = [] + if self.errors: + lines.append("Errors:") + lines.extend(f"- {issue.format()}" for issue in self.errors) + if self.warnings: + lines.append("Warnings:") + lines.extend(f"- {issue.format()}" for issue in self.warnings) + if not lines: + return "Catalog metadata is valid." + return "\n".join(lines) + + +@dataclass(frozen=True) +class CatalogValidationOptions: + """Validation switches for local and future catalog-CI checks.""" + + allow_private: bool = False + check_installed: bool = False + check_import: bool = False + current_eegprep_version: str | None = None + version_provider: Any = metadata.version + entry_points_provider: Any = metadata.entry_points + + +def load_catalog_entries(path: str | Path) -> tuple[dict[str, Any], ...]: + """Load catalog entries from a JSON file or a directory of JSON files.""" + catalog_path = Path(path) + if catalog_path.is_dir(): + files = _catalog_files(catalog_path) + entries: list[dict[str, Any]] = [] + for file_path in files: + entries.extend(load_catalog_entries(file_path)) + return tuple(entries) + + with catalog_path.open(encoding="utf-8") as handle: + payload = json.load(handle) + return _entries_from_payload(payload, catalog_path) + + +def validate_catalog_file( + path: str | Path, + *, + options: CatalogValidationOptions | None = None, +) -> CatalogValidationReport: + """Validate catalog metadata loaded from ``path``.""" + try: + entries = load_catalog_entries(path) + except (OSError, ValueError, json.JSONDecodeError) as exc: + return CatalogValidationReport(errors=(CatalogValidationIssue(str(exc)),)) + return validate_catalog_entries(entries, options=options) + + +def validate_catalog_entries( + entries: Iterable[Mapping[str, Any]], + *, + options: CatalogValidationOptions | None = None, +) -> CatalogValidationReport: + """Validate catalog metadata entries. + + Static validation does not require an installed extension package. Set + ``check_installed`` to verify distribution metadata and entry points, and + ``check_import`` to import the declared entry point through the registry. + """ + opts = options or CatalogValidationOptions() + errors: list[CatalogValidationIssue] = [] + warnings: list[CatalogValidationIssue] = [] + normalized_entries: list[dict[str, Any]] = [] + for index, entry in enumerate(entries): + if not isinstance(entry, Mapping): + errors.append(CatalogValidationIssue("Catalog entries must be mapping objects", f"entry[{index}]")) + continue + normalized_entries.append(dict(entry)) + + _validate_static_entries(normalized_entries, opts, errors, warnings) + _validate_catalog_conflicts(normalized_entries, errors) + if opts.check_installed or opts.check_import: + _validate_installed_entries(normalized_entries, opts, errors) + + return CatalogValidationReport(errors=tuple(errors), warnings=tuple(warnings)) + + +def _catalog_files(path: Path) -> tuple[Path, ...]: + catalog_json = path / "catalog.json" + if catalog_json.is_file(): + return (catalog_json,) + return tuple(sorted(candidate for candidate in path.rglob("*.json") if candidate.is_file())) + + +def _entries_from_payload(payload: Any, path: Path) -> tuple[dict[str, Any], ...]: + if isinstance(payload, list): + entries = payload + elif isinstance(payload, dict) and "extensions" in payload: + _validate_schema_version(payload, path) + entries = payload["extensions"] + elif isinstance(payload, dict) and "entries" in payload: + _validate_schema_version(payload, path) + entries = payload["entries"] + elif isinstance(payload, dict) and "id" in payload: + entries = [payload] + else: + raise ValueError(f"{path} must contain an extension entry or an 'extensions' list") + + if not isinstance(entries, list): + raise ValueError(f"{path} extensions payload must be a list") + if not all(isinstance(entry, dict) for entry in entries): + raise ValueError(f"{path} extension entries must be JSON objects") + return tuple(dict(entry) for entry in entries) + + +def _validate_schema_version(payload: Mapping[str, Any], path: Path) -> None: + # Local import breaks the extension_catalog <-> extension_catalog_validation cycle. + from eegprep.extension_catalog import CATALOG_KIND_CURATION, CATALOG_SCHEMA_VERSION + + catalog_kind = payload.get("catalog_kind", CATALOG_KIND_CURATION) + if catalog_kind != CATALOG_KIND_CURATION: + raise ValueError(f"{path} catalog_kind must be {CATALOG_KIND_CURATION!r}; got {catalog_kind!r}") + schema_version = payload.get("schema_version", CATALOG_SCHEMA_VERSION) + if schema_version != CATALOG_SCHEMA_VERSION: + raise ValueError(f"{path} schema_version must be {CATALOG_SCHEMA_VERSION}; got {schema_version!r}") + + +def _validate_static_entries( + entries: list[dict[str, Any]], + options: CatalogValidationOptions, + errors: list[CatalogValidationIssue], + warnings: list[CatalogValidationIssue], +) -> None: + for entry in entries: + entry_id = _entry_id(entry) + _validate_required_fields(entry, entry_id, errors) + _validate_names(entry, entry_id, errors, warnings) + _validate_urls(entry, entry_id, errors) + _validate_text_metadata(entry, entry_id, errors) + _validate_curation(entry, entry_id, options, errors, warnings) + _validate_version_policy(entry, entry_id, options, errors) + _validate_dependency_metadata(entry, entry_id, options, errors) + + +def _validate_required_fields(entry: Mapping[str, Any], entry_id: str, errors: list[CatalogValidationIssue]) -> None: + for field_name in CATALOG_REQUIRED_FIELDS: + if field_name not in entry: + errors.append(CatalogValidationIssue("Required catalog metadata is missing", entry_id, field_name)) + continue + value = entry[field_name] + if value in ("", None, [], {}): + errors.append(CatalogValidationIssue("Required catalog metadata must not be empty", entry_id, field_name)) + + +def _validate_names( + entry: Mapping[str, Any], + entry_id: str, + errors: list[CatalogValidationIssue], + warnings: list[CatalogValidationIssue], +) -> None: + for field_name in ("id", "entry_point", "extension_name"): + value = entry.get(field_name) + if isinstance(value, str) and _NAME_RE.match(value): + continue + errors.append( + CatalogValidationIssue( + "Must start with a letter and contain only letters, numbers, '.', '_', or '-'", + entry_id, + field_name, + ) + ) + + package_name = entry.get("package_name") + if not isinstance(package_name, str) or not _PACKAGE_RE.match(package_name): + errors.append( + CatalogValidationIssue( + "Must contain only letters, numbers, '.', '_', or '-' and start with a letter or number", + entry_id, + "package_name", + ) + ) + elif not package_name.startswith(EXTENSION_NAMING_PREFIX): + warnings.append( + CatalogValidationIssue( + f"Recommended package names start with {EXTENSION_NAMING_PREFIX!r}; discovery still uses entry points", + entry_id, + "package_name", + ) + ) + + +def _validate_urls(entry: Mapping[str, Any], entry_id: str, errors: list[CatalogValidationIssue]) -> None: + # Local import breaks the extension_catalog <-> extension_catalog_validation cycle. + from eegprep.extension_catalog import _is_web_url + + for field_name in ("docs_url", "source_url"): + value = entry.get(field_name) + if isinstance(value, str) and _is_web_url(value): + continue + errors.append(CatalogValidationIssue("Must be an https:// or http:// URL", entry_id, field_name)) + + +def _validate_text_metadata(entry: Mapping[str, Any], entry_id: str, errors: list[CatalogValidationIssue]) -> None: + license_value = entry.get("license") + if isinstance(license_value, str) and license_value.strip().lower() in {"unknown", "none", "n/a"}: + errors.append(CatalogValidationIssue("License must identify the extension license", entry_id, "license")) + + maintainer = entry.get("maintainer") + if isinstance(maintainer, dict): + maintainer_name = maintainer.get("name") + maintainer_contact = maintainer.get("email") or maintainer.get("url") + if not maintainer_name or not maintainer_contact: + errors.append( + CatalogValidationIssue( + "Maintainer metadata must include a name and email or URL", + entry_id, + "maintainer", + ) + ) + elif not isinstance(maintainer, str) or not maintainer.strip(): + errors.append(CatalogValidationIssue("Maintainer metadata must be a string or object", entry_id, "maintainer")) + + description = entry.get("description") + if isinstance(description, str) and len(description.strip()) < 20: + errors.append(CatalogValidationIssue("Description must be at least 20 characters", entry_id, "description")) + + +def _validate_curation( + entry: Mapping[str, Any], + entry_id: str, + options: CatalogValidationOptions, + errors: list[CatalogValidationIssue], + warnings: list[CatalogValidationIssue], +) -> None: + curation = entry.get("curation") + if not isinstance(curation, dict): + errors.append(CatalogValidationIssue("Curation metadata must be an object", entry_id, "curation")) + return + + status = str(curation.get("status") or "").strip().lower() + if status not in CATALOG_CURATION_STATUSES: + errors.append( + CatalogValidationIssue( + f"Curation status must be one of {', '.join(CATALOG_CURATION_STATUSES)}", + entry_id, + "curation.status", + ) + ) + return + + private = bool(entry.get("private") or entry.get("internal") or status in {"private", "internal"}) + if private and not options.allow_private: + errors.append( + CatalogValidationIssue( + "Private/internal extensions are supported by direct installation, not the public curated catalog", + entry_id, + "curation.status", + ) + ) + elif private: + warnings.append( + CatalogValidationIssue( + "Private/internal extension metadata is valid locally but does not carry curated status", + entry_id, + "curation.status", + ) + ) + + if status == "curated": + if not curation.get("reviewed_by"): + errors.append(CatalogValidationIssue("Curated entries must record reviewed_by", entry_id, "curation")) + if not curation.get("reviewed_at"): + errors.append(CatalogValidationIssue("Curated entries must record reviewed_at", entry_id, "curation")) + + +def _validate_version_policy( + entry: Mapping[str, Any], + entry_id: str, + options: CatalogValidationOptions, + errors: list[CatalogValidationIssue], +) -> None: + api_version = str(entry.get("api_version") or "") + if _major_version(api_version) != _major_version(EXTENSION_API_VERSION): + errors.append( + CatalogValidationIssue( + f"Extension API version {api_version!r} is not supported by this EEGPrep extension API", + entry_id, + "api_version", + ) + ) + + eegprep_requires = str(entry.get("eegprep_requires") or "") + eegprep_requires_valid = True + for field_name in ("eegprep_requires", "python_requires"): + value = entry.get(field_name) + if not isinstance(value, str) or not extension_version_spec_is_valid(value): + errors.append(CatalogValidationIssue("Must be a simple version specifier", entry_id, field_name)) + if field_name == "eegprep_requires": + eegprep_requires_valid = False + + if eegprep_requires and eegprep_requires_valid: + spec = ExtensionSpec( + name=str(entry.get("extension_name") or entry.get("id") or "catalog_entry"), + version=str(entry.get("version") or ""), + api_version=EXTENSION_API_VERSION, + package_name=str(entry.get("package_name") or ""), + eegprep_requires=eegprep_requires, + dependencies=tuple(_catalog_dependencies(entry)), + ) + compatibility = check_extension_compatibility( + spec, + current_version=options.current_eegprep_version, + version_provider=options.version_provider, + check_dependencies=False, + ) + for message in compatibility.incompatible: + errors.append(CatalogValidationIssue(message, entry_id, "eegprep_requires")) + + python_requires = str(entry.get("python_requires") or "") + python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + if python_requires and extension_version_spec_is_valid(python_requires): + if not extension_version_satisfies(python_version, python_requires): + errors.append( + CatalogValidationIssue( + f"Extension requires Python {python_requires}; current version is {python_version}", + entry_id, + "python_requires", + ) + ) + + +def _validate_dependency_metadata( + entry: Mapping[str, Any], + entry_id: str, + options: CatalogValidationOptions, + errors: list[CatalogValidationIssue], +) -> None: + if "dependencies" not in entry: + return + dependencies = entry.get("dependencies", ()) + if dependencies in ("", None): + return + if not isinstance(dependencies, list): + errors.append(CatalogValidationIssue("Dependencies must be a list", entry_id, "dependencies")) + return + for index, dependency in enumerate(dependencies): + field = f"dependencies[{index}]" + if not isinstance(dependency, dict): + errors.append(CatalogValidationIssue("Dependency entries must be objects", entry_id, field)) + continue + package = dependency.get("package") + if not isinstance(package, str) or not _PACKAGE_RE.match(package): + errors.append(CatalogValidationIssue("Dependency package name is invalid", entry_id, f"{field}.package")) + version_spec = dependency.get("version_spec", "") + if version_spec and (not isinstance(version_spec, str) or not extension_version_spec_is_valid(version_spec)): + errors.append(CatalogValidationIssue("Dependency version specifier is invalid", entry_id, field)) + if not options.check_installed: + continue + try: + installed_version = options.version_provider(str(package)) + except metadata.PackageNotFoundError: + if not dependency.get("optional", False): + errors.append( + CatalogValidationIssue(f"Required dependency {package!r} is not installed", entry_id, field) + ) + continue + if version_spec and not extension_version_satisfies(installed_version, str(version_spec)): + errors.append( + CatalogValidationIssue( + f"Dependency {package!r} requires {version_spec}; installed version is {installed_version}", + entry_id, + field, + ) + ) + + +def _validate_catalog_conflicts(entries: list[dict[str, Any]], errors: list[CatalogValidationIssue]) -> None: + seen: dict[tuple[str, str], str] = {} + for entry in entries: + entry_id = _entry_id(entry) + for field_name in ("id", "extension_name", "display_name"): + value = entry.get(field_name) + if not isinstance(value, str) or not value: + continue + key = (field_name, _normalize(value)) + owner = seen.get(key) + if owner is not None: + errors.append( + CatalogValidationIssue( + f"Conflicts with catalog entry {owner!r}", + entry_id, + field_name, + ) + ) + seen[key] = entry_id + + package_name = entry.get("package_name") + entry_point = entry.get("entry_point") + if isinstance(package_name, str) and isinstance(entry_point, str): + key = ("package-entry-point", f"{_normalize(package_name)}:{_normalize(entry_point)}") + owner = seen.get(key) + if owner is not None: + errors.append( + CatalogValidationIssue( + f"Conflicts with catalog entry {owner!r}", + entry_id, + "entry_point", + ) + ) + seen[key] = entry_id + + +def _validate_installed_entries( + entries: list[dict[str, Any]], + options: CatalogValidationOptions, + errors: list[CatalogValidationIssue], +) -> None: + selected_entry_points = _select_entry_points(options.entry_points_provider, EXTENSION_ENTRY_POINT_GROUP) + for entry in entries: + entry_id = _entry_id(entry) + package_name = str(entry.get("package_name") or "") + entry_point_name = str(entry.get("entry_point") or "") + if not package_name or not entry_point_name: + continue + + try: + installed_version = options.version_provider(package_name) + except metadata.PackageNotFoundError: + errors.append( + CatalogValidationIssue(f"Package {package_name!r} is not installed", entry_id, "package_name") + ) + continue + + expected_version = str(entry.get("version") or "") + if expected_version and installed_version != expected_version: + errors.append( + CatalogValidationIssue( + f"Catalog version {expected_version} does not match installed package version {installed_version}", + entry_id, + "version", + ) + ) + + entry_point = _matching_entry_point(selected_entry_points, package_name, entry_point_name) + if entry_point is None: + errors.append( + CatalogValidationIssue( + f"Package {package_name!r} does not expose {entry_point_name!r} in {EXTENSION_ENTRY_POINT_GROUP}", + entry_id, + "entry_point", + ) + ) + continue + + if options.check_import: + _validate_imported_entry(entry, entry_point, options, errors) + + +def _validate_imported_entry( + entry: Mapping[str, Any], + entry_point: Any, + options: CatalogValidationOptions, + errors: list[CatalogValidationIssue], +) -> None: + entry_id = _entry_id(entry) + + def provider(*, group: str) -> tuple[Any, ...]: + if group == EXTENSION_ENTRY_POINT_GROUP: + return (entry_point,) + return () + + registry = ExtensionRegistry( + include_bundled=False, + entry_points_provider=provider, + current_version=options.current_eegprep_version, + version_provider=options.version_provider, + ) + record = registry.discover()[0] + if record.status in {ExtensionStatus.FAILED_IMPORT, ExtensionStatus.INVALID_SPEC, ExtensionStatus.INCOMPATIBLE}: + for message in record.errors: + errors.append(CatalogValidationIssue(message, entry_id, "entry_point")) + return + if record.status == ExtensionStatus.MISSING_DEPENDENCY: + for message in record.errors: + errors.append(CatalogValidationIssue(message, entry_id, "dependencies")) + return + if record.spec is None: + errors.append(CatalogValidationIssue("Entry point did not return an extension spec", entry_id, "entry_point")) + return + + _validate_imported_spec_matches_catalog(entry, record.spec, errors) + + +def _validate_imported_spec_matches_catalog( + entry: Mapping[str, Any], + spec: ExtensionSpec, + errors: list[CatalogValidationIssue], +) -> None: + entry_id = _entry_id(entry) + checks = ( + ("extension_name", spec.name), + ("version", spec.version), + ("api_version", spec.api_version), + ) + for field_name, actual in checks: + expected = str(entry.get(field_name) or "") + if expected and expected != actual: + errors.append( + CatalogValidationIssue( + f"Catalog value {expected!r} does not match imported spec value {actual!r}", + entry_id, + field_name, + ) + ) + + +def _catalog_dependencies(entry: Mapping[str, Any]) -> tuple[ExtensionDependency, ...]: + dependencies = entry.get("dependencies", ()) + if not isinstance(dependencies, list): + return () + parsed: list[ExtensionDependency] = [] + for dependency in dependencies: + if not isinstance(dependency, dict): + continue + package = dependency.get("package") + if not package: + continue + parsed.append( + ExtensionDependency( + package=str(package), + version_spec=str(dependency.get("version_spec") or ""), + optional=bool(dependency.get("optional", False)), + ) + ) + return tuple(parsed) + + +def _matching_entry_point(entry_points: tuple[Any, ...], package_name: str, entry_point_name: str) -> Any | None: + normalized_package = _normalize(package_name) + for entry_point in entry_points: + if getattr(entry_point, "name", None) != entry_point_name: + continue + entry_point_package = _entry_point_package_name(entry_point) + if entry_point_package and _normalize(entry_point_package) != normalized_package: + continue + return entry_point + return None + + +def _entry_id(entry: Mapping[str, Any]) -> str: + value = entry.get("id") + if isinstance(value, str) and value: + return value + package_name = entry.get("package_name") + if isinstance(package_name, str) and package_name: + return package_name + return "" + + +def _normalize(value: str) -> str: + return value.strip().lower().replace("_", "-") + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Validate EEGPrep extension catalog metadata.") + parser.add_argument("path", help="Catalog JSON file, entry JSON file, or directory of JSON files") + parser.add_argument( + "--allow-private", action="store_true", help="Allow private/internal metadata for local catalogs" + ) + parser.add_argument( + "--check-installed", action="store_true", help="Verify installed package metadata and entry points" + ) + parser.add_argument("--check-import", action="store_true", help="Import declared entry points through the registry") + parser.add_argument("--json", action="store_true", help="Emit JSON validation output") + return parser + + +def main(argv: list[str] | None = None) -> int: + """Run the catalog validator command-line interface.""" + args = _build_arg_parser().parse_args(argv) + options = CatalogValidationOptions( + allow_private=args.allow_private, + check_installed=args.check_installed or args.check_import, + check_import=args.check_import, + ) + report = validate_catalog_file(args.path, options=options) + if args.json: + print( + json.dumps( + { + "ok": report.ok, + "errors": [issue.format() for issue in report.errors], + "warnings": [issue.format() for issue in report.warnings], + }, + indent=2, + ) + ) + else: + print(report.format()) + return 0 if report.ok else 1 + + +__all__ = [ + "CATALOG_CURATION_STATUSES", + "CATALOG_REQUIRED_FIELDS", + "CatalogValidationIssue", + "CatalogValidationOptions", + "CatalogValidationReport", + "load_catalog_entries", + "main", + "validate_catalog_entries", + "validate_catalog_file", +] + + +if __name__ == "__main__": # pragma: no cover + raise SystemExit(main()) diff --git a/src/eegprep/extensions.py b/src/eegprep/extensions.py index 24ba3edf..6ad3f681 100644 --- a/src/eegprep/extensions.py +++ b/src/eegprep/extensions.py @@ -594,7 +594,7 @@ def _mark_duplicate_contributions(self, records: list[ExtensionRecord]) -> list[ final_records: list[ExtensionRecord] = [] for record in records: - if not _can_contribute(record): + if not record.is_active: final_records.append(record) continue @@ -932,19 +932,6 @@ def _invalid_record(record: ExtensionRecord, errors: tuple[str, ...]) -> Extensi ) -def _can_contribute(record: ExtensionRecord) -> bool: - return ( - record.enabled - and record.spec is not None - and record.status - in { - ExtensionStatus.BUNDLED, - ExtensionStatus.INSTALLED, - ExtensionStatus.CURATED, - } - ) - - def _record_sort_key(record: ExtensionRecord) -> tuple[int, str, str, str]: return ( _source_rank(record.source_type), diff --git a/src/eegprep/functions/adminfunc/console.py b/src/eegprep/functions/adminfunc/console.py index 9f8ecfb4..4efb9101 100644 --- a/src/eegprep/functions/adminfunc/console.py +++ b/src/eegprep/functions/adminfunc/console.py @@ -241,6 +241,8 @@ def __init__(self, bridge: EEGPrepConsoleWorkspace) -> None: def __getattr__(self, name: str) -> Any: if name.startswith("pop_"): return self._bridge.pop_wrapper(name) + if name == "eegh": + return self._bridge.namespace["eegh"] return getattr(eegprep, name) def __dir__(self) -> list[str]: @@ -267,10 +269,14 @@ def __call__(self, command: Any = None, *args: Any) -> str: if history_command and int(command) > 0: self.bridge.execute_history_command(history_command) return history_command - normalized = eegh(command, self.bridge.session.ALLCOM) + session = self.bridge.session + normalized = str(command).strip() + if normalized: + session.add_history(command) + else: + session.LASTCOM = "" if args and isinstance(args[0], dict): eegh(normalized, args[0]) - self.bridge.session.LASTCOM = normalized self.bridge.pull_from_session() return normalized @@ -352,19 +358,18 @@ def after_execute(self, source: str, *, success: bool = True) -> None: history_command = self._history_command_for_source(source, targets) changed = False - if "ALLEEG" in targets: + if "ALLEEG" in targets or "CURRENTSET" in targets: alleeg = self.namespace.get("ALLEEG", []) if not isinstance(alleeg, list): raise ValueError("ALLEEG must be a list of EEG datasets") - self.session.ALLEEG = alleeg - changed = True - - if "CURRENTSET" in targets: - current = _normalize_currentset(self.namespace.get("CURRENTSET")) - if current: - self.session.retrieve(current if len(current) > 1 else current[0]) - else: - self.session.CURRENTSET = [] + current = ( + _normalize_currentset(self.namespace.get("CURRENTSET")) + if "CURRENTSET" in targets + else self.session.CURRENTSET + ) + self.session.apply_workspace_state( + alleeg=alleeg, currentset=current, command="", append_dataset_history=False + ) changed = True if self._namespace_eeg_changed(targets): @@ -380,16 +385,16 @@ def after_execute(self, source: str, *, success: bool = True) -> None: changed = True if "STUDY" in targets: - self.session.STUDY = self.namespace.get("STUDY") - if "CURRENTSTUDY" not in targets: - self.session.CURRENTSTUDY = 1 if self.session.STUDY else 0 + study_kwargs: dict[str, Any] = {"study": self.namespace.get("STUDY"), "command": ""} + if "CURRENTSTUDY" in targets: + study_kwargs["currentstudy"] = self.namespace.get("CURRENTSTUDY") + self.session.apply_workspace_state(**study_kwargs) changed = True - if "CURRENTSTUDY" in targets: - self.session.CURRENTSTUDY = int(self.namespace.get("CURRENTSTUDY") or 0) + elif "CURRENTSTUDY" in targets: + self.session.apply_workspace_state(currentstudy=self.namespace.get("CURRENTSTUDY"), command="") changed = True if changed: - self.session.notify_changed() if history_command and history_command != self.session.LASTCOM: self.session.add_history(history_command) self.pull_from_session() @@ -401,12 +406,13 @@ def accept_pop_result(self, result: Any, args: tuple[Any, ...], kwargs: Mapping[ dataset_state = _extract_pop_dataset_state(result) if dataset_state is not None: alleeg, eeg, currentset, command = dataset_state - self.session.ALLEEG = alleeg - self.session.EEG = eeg - self.session.CURRENTSET = _normalize_currentset(currentset) - if command: - self.session.add_history(command, notify=False) - self.session.notify_changed() + self.session.apply_workspace_state( + alleeg=alleeg, + eeg=eeg, + currentset=currentset, + command=command, + append_dataset_history=False, + ) self._pop_updated_session = True self.pull_from_session() self._refresh() @@ -451,6 +457,8 @@ def _bind_exports(self, exports: Mapping[str, Any] | None) -> None: for name in export_names: if name == "__version__": self.namespace[name] = eegprep.__version__ + elif name == "eegh": + continue elif name.startswith("pop_"): wrapped = ConsolePopFunction(name, self, None if exports is None else exports[name]) self._wrapped_pop_exports[name] = wrapped @@ -926,6 +934,8 @@ def visit_Call(self, node: ast.Call) -> ast.AST: return node if node.func.id == "pop_reref": self._convert_pop_reref(node) + elif node.func.id == "pop_select": + self._convert_pop_select(node) return node def _convert_pop_reref(self, node: ast.Call) -> None: @@ -936,6 +946,16 @@ def _convert_pop_reref(self, node: ast.Call) -> None: if key in {"exclude", "interpchan"}: node.args[index + 1] = self._zero_base_channel_arg(node.args[index + 1]) + def _convert_pop_select(self, node: ast.Call) -> None: + # pop_select history is name/value pairs after EEG (no positional + # selection arg). Channel selections are 1-based in EEGLAB history but + # the Python API is 0-based, so zero-base the numeric channel lists to + # match pop_reref; channel-by-name and chantype selections pass through. + for index in range(1, len(node.args) - 1, 2): + key = self._string_constant(node.args[index]) + if key in {"channel", "nochannel", "rmchannel"}: + node.args[index + 1] = self._zero_base_channel_arg(node.args[index + 1]) + def _zero_base_channel_arg(self, node: ast.AST) -> ast.AST: if isinstance(node, ast.List): values = [_numeric_ast_constant_value(item) for item in node.elts] diff --git a/src/eegprep/functions/adminfunc/eeg_checkset.py b/src/eegprep/functions/adminfunc/eeg_checkset.py index b94b5c58..3e04c78f 100644 --- a/src/eegprep/functions/adminfunc/eeg_checkset.py +++ b/src/eegprep/functions/adminfunc/eeg_checkset.py @@ -545,16 +545,3 @@ def eeg_checkset(EEG, *checks, load_data=True): ) return EEG - - -def test_eeg_checkset(): - from eegprep.functions.popfunc.pop_loadset import pop_loadset - - eeglab_file_path = './sample_data/eeglab_data_with_ica_tmp_out2.set' - EEG = pop_loadset(eeglab_file_path) - EEG = eeg_checkset(EEG) - logger.info('Checkset done') - - -if __name__ == '__main__': - test_eeg_checkset() diff --git a/src/eegprep/functions/adminfunc/eeglabcompat.py b/src/eegprep/functions/adminfunc/eeglabcompat.py index f7b9c532..a9147b77 100644 --- a/src/eegprep/functions/adminfunc/eeglabcompat.py +++ b/src/eegprep/functions/adminfunc/eeglabcompat.py @@ -22,16 +22,21 @@ # can be either 'OCT' (for Oct2Py) or 'MAT' (MATLAB engine) default_runtime = 'MAT' -# directory where temporary .set files are written -# use environment variable if it exists -if 'TEMP_DIR' in os.environ: - temp_dir = os.environ['TEMP_DIR'] -elif 'TMPDIR' in os.environ: - temp_dir = os.environ['TMPDIR'] -else: + +def _temp_dir() -> str: + """Return the directory for temporary .set files used by MATLAB roundtrips. + + Resolved lazily so that importing this module has no filesystem side + effect; the fallback ``temp/`` directory is created only when a MATLAB + roundtrip actually runs. + """ + if 'TEMP_DIR' in os.environ: + return os.environ['TEMP_DIR'] + if 'TMPDIR' in os.environ: + return os.environ['TMPDIR'] temp_dir = str(REPO_ROOT / 'temp') - if not os.path.exists(temp_dir): - os.makedirs(temp_dir, exist_ok=True) + os.makedirs(temp_dir, exist_ok=True) + return temp_dir def _prepare_matlab_arg(arg: Any) -> Any: @@ -161,6 +166,7 @@ def wrapper(*args, **kwargs): for i, arg in enumerate(new_args): new_args[i] = _prepare_matlab_arg(arg) + temp_dir = _temp_dir() try: # temporary files with tempfile.NamedTemporaryFile(dir=temp_dir, suffix='.set', delete=False) as temp_file1: @@ -169,9 +175,9 @@ def wrapper(*args, **kwargs): temp_filename2 = temp_file2.name result_filename = temp_filename1 + '.result.set' result_extra_filename = temp_filename1 + '.result.mat' - print(f"temp_filename1: {temp_filename1}") - print(f"temp_filename2: {temp_filename2}") - print(f"result_filename: {result_filename}") + logger.debug("MATLAB roundtrip input set path: %s", temp_filename1) + logger.debug("MATLAB roundtrip args path: %s", temp_filename2) + logger.debug("MATLAB roundtrip result set path: %s", result_filename) # save all parameters in the temp_filename which is a .mat file if len(new_args) > 0: @@ -192,7 +198,7 @@ def wrapper(*args, **kwargs): pop_saveset(args[0], temp_filename1) self.engine.eval(f"EEG = pop_loadset('{temp_filename1}');", nargout=0) - print(f"Running in MATLAB/Octave: {eval_str}") + logger.debug("Running in MATLAB/Octave: %s", eval_str) self.engine.eval(eval_str, nargout=0) # output @@ -280,12 +286,12 @@ def get_eeglab(runtime: str = default_runtime, *, auto_file_roundtrip: bool = Tr try: engine = _cache[rt] except KeyError: - print(f"Loading {runtime} runtime...", end='', flush=True) + logger.info("Loading %s runtime...", runtime) # On the command line, type "octave-8.4.0" OCTAVE_EXECUTABLE or OCTAVE var path2eeglab = str(_resolve_eeglab_root()) matlab_test_dir = REPO_ROOT / 'tests' / 'matlab' scripts_dir = str(REPO_ROOT / 'scripts') - print("This is the path2eeglab: ", path2eeglab) + logger.debug("EEGLAB reference path: %s", path2eeglab) # not yet loaded, do so now if rt == 'oct': @@ -344,7 +350,7 @@ def get_eeglab(runtime: str = default_runtime, *, auto_file_roundtrip: bool = Tr engine.logger.setLevel(logging.INFO) _cache[rt] = engine - print('done.') + logger.info("Loaded %s runtime.", runtime) # optionally wrap the engine in a file-roundtripping wrapper if auto_file_roundtrip: @@ -496,66 +502,29 @@ def clean_artifacts( else: BurstRejection = 'on' - pop_saveset(EEG, './tmp.set') # 0.8 seconds - EEG2 = eeglab.pop_loadset('./tmp.set') # 2 seconds - EEG3 = eeglab.clean_artifacts( - EEG2, - 'ChannelCriterion', - ChannelCriterion, - 'LineNoiseCriterion', - LineNoiseCriterion, - 'FlatlineCriterion', - FlatlineCriterion, - 'BurstCriterion', - BurstCriterion, - 'BurstRejection', - BurstRejection, - 'WindowCriterion', - WindowCriterion, - 'Highpass', - Highpass, - 'WindowCriterionTolerances', - WindowCriterionTolerances, - ) - eeglab.pop_saveset(EEG3, './tmp2.set') # 2.4 seconds - EEG4 = pop_loadset('./tmp2.set') # 0.2 seconds - - # delete temporary files - os.remove('./tmp.set') - os.remove('./tmp2.set') - return EEG4 - - -# sys.exit() -def test_eeglab_compat(): - """Test EEGLAB compatibility.""" - eeglab_file_path = '/System/Volumes/Data/data/matlab/eeglab/sample_data/eeglab_data_epochs_ica.set' - - EEG = pop_loadset(eeglab_file_path) - EEG = pop_eegfiltnew(EEG, locutoff=5, hicutoff=25, revfilt=True, plotfreqz=False) - EEG = clean_artifacts( - EEG, - FlatlineCriterion=5, - ChannelCriterion=0.87, - LineNoiseCriterion=4, - Highpass=False, - BurstCriterion=20, - WindowCriterion=0.25, - BurstRejection=False, - WindowCriterionTolerances=[float('-inf'), 7], - ) - - # EEG = eeglab.pop_loadset(eeglab_file_path) - # TMPEEG = eeglab.pop_eegfiltnew(EEG, 'locutoff',5,'hicutoff',25,'revfilt',1,'plotfreqz',0) - # CLEANEDEEG = eeglab.clean_artifacts(TMPEEG, 'ChannelCriterion', 'off', - # 'LineNoiseCriterion', 'off', - # 'FlatlineCriterion', 'off', - # 'BurstCriterion', 'off', - # 'WindowCriterion', 0, - # 'Highpass',[0.25, 0.75], - # 'WindowCriterionTolerances', [-10000000, 8]) - - # clean_artifacts( EEG, ChannelCriterion='on' ) - - -# test_eeglab_compat() + with tempfile.TemporaryDirectory(prefix="eegprep_clean_artifacts_") as workdir: + input_path = Path(workdir) / "input.set" + output_path = Path(workdir) / "output.set" + pop_saveset(EEG, input_path) + EEG2 = eeglab.pop_loadset(str(input_path)) + EEG3 = eeglab.clean_artifacts( + EEG2, + 'ChannelCriterion', + ChannelCriterion, + 'LineNoiseCriterion', + LineNoiseCriterion, + 'FlatlineCriterion', + FlatlineCriterion, + 'BurstCriterion', + BurstCriterion, + 'BurstRejection', + BurstRejection, + 'WindowCriterion', + WindowCriterion, + 'Highpass', + Highpass, + 'WindowCriterionTolerances', + WindowCriterionTolerances, + ) + eeglab.pop_saveset(EEG3, str(output_path)) + return pop_loadset(output_path) diff --git a/src/eegprep/functions/adminfunc/plugin_menu.py b/src/eegprep/functions/adminfunc/plugin_menu.py index 15f214cb..a6ca9c5b 100644 --- a/src/eegprep/functions/adminfunc/plugin_menu.py +++ b/src/eegprep/functions/adminfunc/plugin_menu.py @@ -71,78 +71,41 @@ ExtensionStatus.UNKNOWN.value: "#eeeeee", } -_BUNDLED_PLUGINS: tuple[dict[str, Any], ...] = ( - { - "plugin": "clean_rawdata", - "name": "clean_rawdata", - "version": "bundled", - "foldername": "clean_rawdata", - "funcname": "pop_clean_rawdata", - "status": "ok", - "installed": True, - "source": "bundled", - "menu": "Tools > Reject data using Clean Rawdata and ASR", - "description": "Artifact Subspace Reconstruction and related channel/window cleaning workflows.", - "tags": ("artifact", "preprocessing"), - }, - { - "plugin": "ICLabel", - "name": "ICLabel", - "version": "bundled", - "foldername": "ICLabel", - "funcname": "pop_iclabel", - "status": "ok", - "installed": True, - "source": "bundled", - "menu": "Tools > Classify components using ICLabel", - "description": "Independent-component classification, flagging, and extended component properties.", - "tags": ("ica", "classification"), - }, - { - "plugin": "firfilt", - "name": "firfilt", - "version": "bundled", - "foldername": "firfilt", - "funcname": "pop_eegfiltnew", - "status": "ok", - "installed": True, - "source": "bundled", - "menu": "Tools > Filter the data", - "description": "Windowed-sinc, Parks-McClellan, moving-average, and new default FIR filtering.", - "tags": ("filter", "preprocessing"), - }, - { - "plugin": "dipfit", - "name": "DIPFIT", - "version": "bundled", - "foldername": "dipfit", - "funcname": "pop_dipfit_settings", - "status": "ok", - "installed": True, - "source": "bundled", - "menu": "Tools > Source localization using DIPFIT", - "description": "Source-localization menu surfaces and FieldTrip-backed DIPFIT workflows.", - "tags": ("source", "localization"), - }, - { - "plugin": "EEG_BIDS", - "name": "EEG-BIDS", - "version": "bundled", - "foldername": "EEG_BIDS", - "funcname": "pop_importbids", - "status": "ok", - "installed": True, - "source": "bundled", - "menu": "File > Import data / Export / BIDS tools", - "description": "BIDS import, export, validation, and metadata helpers for EEG datasets.", - "tags": ("import", "export", "bids", "study"), - }, -) +# EEGLAB-style top-level menu labels for the bundled plugin ports. Names, +# versions, descriptions, capabilities, and pop functions are derived from the +# extension registry so they cannot drift from the live discovery path. +_BUNDLED_PLUGIN_MENUS: dict[str, str] = { + "clean_rawdata": "Tools > Reject data using Clean Rawdata and ASR", + "ICLabel": "Tools > Classify components using ICLabel", + "firfilt": "Tools > Filter the data", + "dipfit": "Tools > Source localization using DIPFIT", + "EEG_BIDS": "File > Import data / Export / BIDS tools", +} def bundled_plugins() -> tuple[dict[str, Any], ...]: """Return metadata for EEGPrep extensions bundled in the installed package.""" - return tuple(deepcopy(plugin) for plugin in _BUNDLED_PLUGINS) + records = ExtensionRegistry(include_entry_points=False).discover() + plugins: list[dict[str, Any]] = [] + for record in records: + spec = record.spec + package_name = record.package_name or (spec.package_name if spec is not None else "") + plugins.append( + { + "plugin": record.name, + "name": (spec.display_name if spec is not None else "") or record.name, + "version": (spec.version if spec is not None else "") or "bundled", + "foldername": _folder_name(record, package_name), + "funcname": _first_pop_function(record), + "status": "ok", + "installed": True, + "source": record.source_type.value, + "menu": _BUNDLED_PLUGIN_MENUS[record.name], + "description": spec.description if spec is not None else "", + "tags": tuple(spec.capabilities if spec is not None else ()), + } + ) + return tuple(plugins) def plugin_status( diff --git a/src/eegprep/functions/adminfunc/pymat.py b/src/eegprep/functions/adminfunc/pymat.py index d1bf6152..48f334ae 100644 --- a/src/eegprep/functions/adminfunc/pymat.py +++ b/src/eegprep/functions/adminfunc/pymat.py @@ -266,133 +266,3 @@ def mat2py(obj): else: # Fallback: return the object as-is if no conversion rule applies return obj - - -def test_py2mat(): - """Test the py2mat and mat2py conversion functions with various data structures.""" - import scipy.io - - # Test basic functionality - print("=== Basic Test ===") - dicts = [{'a': 'adsaf1', 'b': 2.0}, {'a': 'adsaf', 'b': 4.0}, {'a': 'adsaf33', 'b': 7.0}] - struct_array = py2mat(dicts) - print("Original: ", dicts) - - mat2py(struct_array) - scipy.io.savemat('test1.mat', {'struct_array': struct_array}) - struct_array2 = scipy.io.loadmat('test1.mat') - struct_array2 = struct_array2['struct_array'][0] - dicts3 = mat2py(struct_array2) - print("Converted: ", dicts3) - - # Test nested dictionaries - print("\n=== Nested Dictionary Test ===") - nested_dicts = [ - {'name': 'item1', 'value': 10.5, 'config': {'enabled': True, 'threshold': 0.8}, 'tags': ['tag1', 'tag2']}, - {'name': 'item2', 'value': 20.3, 'config': {'enabled': False, 'threshold': 0.9}, 'tags': ['tag3']}, - ] - nested_struct = py2mat(nested_dicts) - print("Original: ", nested_dicts) - - nested_dict2 = mat2py(nested_struct) - print("Converted back (not fully compatible): ", nested_dict2) - - scipy.io.savemat('test2.mat', {'nested_struct': nested_struct}) - nested_struct2 = scipy.io.loadmat('test2.mat') - nested_struct2 = nested_struct2['nested_struct'][0] - nested_dict3 = mat2py(nested_struct2) - print("Converted: ", nested_dict3) - - # Test list of dictionaries as values - print("\n=== List of Dictionaries Test ===") - list_dict_data = [ - {'id': 1, 'measurements': [{'sensor': 'A', 'reading': 1.2}, {'sensor': 'B', 'reading': 2.3}]}, - { - 'id': 2, - 'measurements': [ - {'sensor': 'A', 'reading': 3.4}, - {'sensor': 'B', 'reading': 4.5}, - {'sensor': 'C', 'reading': 5.6}, - ], - }, - ] - list_dict_struct = py2mat(list_dict_data) - scipy.io.savemat('test3.mat', {'list_dict_struct': list_dict_struct}) - list_dict_struct2 = scipy.io.loadmat('test3.mat') - list_dict_struct2 = list_dict_struct2['list_dict_struct'][0] - list_dict_data3 = mat2py(list_dict_struct2) - print("Original: ", list_dict_data) - print("Converted: ", list_dict_data3) - - # Test single dictionary input - print("\n=== Single Dictionary Test ===") - single_dict = {'x': 1, 'y': 2, 'nested': {'a': 'hello', 'b': 'world'}} - single_struct = py2mat(single_dict) - scipy.io.savemat('test4.mat', {'single_struct': single_struct}) - single_struct2 = scipy.io.loadmat('test4.mat') - single_struct2 = single_struct2['single_struct'][0] - single_dict2 = mat2py(single_struct2) - print("Original: ", single_dict) - print("Converted: ", single_dict2) - - # Test numpy array of dictionaries - print("\n=== NumPy Array of Dictionaries Test ===") - dict_array = np.array( - [{'name': 'sensor1', 'value': 1.1}, {'name': 'sensor2', 'value': 2.2}, {'name': 'sensor3', 'value': 3.3}], - dtype=object, - ) - - array_dict_data = [ - {'id': 'device1', 'sensors': dict_array}, - { - 'id': 'device2', - 'sensors': np.array([{'name': 'sensorA', 'value': 4.4}, {'name': 'sensorB', 'value': 5.5}], dtype=object), - }, - ] - - array_dict_struct = py2mat(array_dict_data) - scipy.io.savemat('test5.mat', {'array_dict_struct': array_dict_struct}) - array_dict_struct2 = scipy.io.loadmat('test5.mat') - array_dict_struct2 = array_dict_struct2['array_dict_struct'][0] - array_dict_data2 = mat2py(array_dict_struct2) - print("Original: ", array_dict_data) - print("Converted: ", array_dict_data2) # Numpy array gets converted to a list of dicts - - params = [np.vstack([np.arange(1, 21), np.arange(101, 121)]), [[5, 8]], 10.0, [{'latency': 5.0}, {'latency': 10.0}]] - params_struct = py2mat(params) - scipy.io.savemat('test6.mat', {'params_struct': params_struct}) - params_struct2 = scipy.io.loadmat('test6.mat') - params_struct2 = params_struct2['params_struct'][0] - params_data2 = mat2py(params_struct2) - print("Original: ", params) - print("Converted: ", params_data2) - - # EEGLAB dataset - eeglab_file_path = '/System/Volumes/Data/data/matlab/eeglab/sample_data/eeglab_data_epochs_ica.set' - from eegprep.functions.popfunc.pop_loadset import pop_loadset - - pop_loadset(eeglab_file_path) - - # pop_loadset wihtout index adjustment - EEG_LOADMAT = scipy.io.loadmat(eeglab_file_path) - EEG_LOADMAT = mat2py(EEG_LOADMAT['EEG'][0]) - - # pop_saveset without index adjustment - EEG_TMP = EEG_LOADMAT.copy() - EEG_TMP = py2mat(EEG_TMP) - scipy.io.savemat('test7.set', {'EEG': EEG_TMP}) - - # load again - EEG_LOADMAT2 = scipy.io.loadmat('test7.set') - EEG_LOADMAT2 = mat2py(EEG_LOADMAT2['EEG'][0]) - - # Limitations - print("\n=== Limitations ===") - print( - "- Conversion back: py2mat then mat2py does not always work for nested structures (works when the file is saved as a .mat file)" - ) - print("- Numpy arrays of dicts are converted to lists of dicts (an intented feature)") - - -if __name__ == "__main__": - test_py2mat() diff --git a/src/eegprep/functions/eegobj/eegobj.py b/src/eegprep/functions/eegobj/eegobj.py index ad060af8..c37c53cf 100644 --- a/src/eegprep/functions/eegobj/eegobj.py +++ b/src/eegprep/functions/eegobj/eegobj.py @@ -26,6 +26,36 @@ def _tolist_if_available(value: Any) -> Any: return tolist() if callable(tolist) else value +def _resolve_eegprep_function(name): + """Resolve an eegprep function by name, or return None if unknown.""" + # Try globals first (for imported functions) + cand = globals().get(name) + if callable(cand): + return cand + # Try public package exports before probing EEGLAB-style modules. + try: + import eegprep as eegpkg + + cand = getattr(eegpkg, name, None) + if callable(cand): + return cand + if isinstance(cand, types.ModuleType): + sub = getattr(cand, name, None) + if callable(sub): + return sub + except Exception: + pass + for prefix in _EEGPREP_FUNCTION_MODULE_PREFIXES: + try: + mod = importlib.import_module(f"{prefix}.{name}") + cand = getattr(mod, name, None) + if callable(cand): + return cand + except Exception: + pass + return None + + class EEGobj: """Wrapper class for EEG datasets stored as dictionaries. @@ -48,35 +78,7 @@ def __init__(self, EEG_or_path): # Internal helper to resolve and call an eegprep function name def _call_eegprep(self, fname, *args, **kwargs): - def _resolve(n): - # Try globals first (for imported functions) - cand = globals().get(n) - if callable(cand): - return cand - # Try public package exports before probing EEGLAB-style modules. - try: - import eegprep as eegpkg - - cand = getattr(eegpkg, n, None) - if callable(cand): - return cand - if isinstance(cand, types.ModuleType): - sub = getattr(cand, n, None) - if callable(sub): - return sub - except Exception: - pass - for prefix in _EEGPREP_FUNCTION_MODULE_PREFIXES: - try: - mod = importlib.import_module(f"{prefix}.{n}") - cand = getattr(mod, n, None) - if callable(cand): - return cand - except Exception: - pass - return None - - func = _resolve(fname) + func = _resolve_eegprep_function(fname) if func is None: raise AttributeError(fname) @@ -116,14 +118,19 @@ def __getattr__(self, name): """Access EEG fields or eegprep functions. - If 'name' is a key in EEG, return EEG[name] (convenience). - - If 'name' is a function in eegprep, return a wrapper that: + - If 'name' resolves to a function in eegprep, return a wrapper that: self.EEG = func(deepcopy(self.EEG), ...) and returns updated EEG for convenience. + - Otherwise raise AttributeError so field-name typos fail fast instead + of silently returning a no-op callable. """ eeg = object.__getattribute__(self, 'EEG') if isinstance(eeg, dict) and name in eeg: return eeg[name] + if _resolve_eegprep_function(name) is None: + raise AttributeError(name) + def wrapper(*args, **kwargs): return self._call_eegprep(name, *args, **kwargs) @@ -255,8 +262,3 @@ def _safe(val, default=''): return '\n'.join(lines) __str__ = __repr__ - - -if __name__ == '__main__': - obj = EEGobj('sample_data/eeglab_data.set') - print(obj) diff --git a/src/eegprep/functions/guifunc/long_task.py b/src/eegprep/functions/guifunc/long_task.py new file mode 100644 index 00000000..a7aa9889 --- /dev/null +++ b/src/eegprep/functions/guifunc/long_task.py @@ -0,0 +1,163 @@ +"""Qt helper for running long GUI actions off the UI thread.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +import logging +import threading +from typing import Any + +try: # pragma: no cover - depends on optional GUI dependency + from PySide6 import QtCore, QtWidgets +except ImportError: # pragma: no cover - depends on optional GUI dependency + QtCore = None + QtWidgets = None + + +@dataclass +class LongTaskHandle: + """Keep Qt task objects alive until their worker thread finishes.""" + + thread: Any + worker: Any + dialog: Any + receiver: Any | None = None + + +_LOGGER_LOCK = threading.Lock() +_LOGGER_DEPTH = 0 +_LOGGER_OLD_LEVEL: int | None = None + + +def run_long_task( + *, + parent: Any | None, + title: str, + label: str, + task: Callable[[], Any], + on_success: Callable[[Any], None], + on_error: Callable[[Exception], None], + on_finished: Callable[[LongTaskHandle], None] | None = None, +) -> LongTaskHandle: + """Run ``task`` in a Qt worker thread with an indeterminate progress dialog.""" + qt_core, qt_widgets = _require_qt() + + progress = qt_widgets.QProgressDialog(label, None, 0, 0, parent) + progress.setWindowTitle(title) + progress.setCancelButton(None) + progress.setAutoClose(False) + progress.setAutoReset(False) + progress.setMinimumDuration(0) + progress.setWindowModality(qt_core.Qt.WindowModal) + + class Worker(qt_core.QObject): + succeeded = qt_core.Signal(object) + failed = qt_core.Signal(object) + message = qt_core.Signal(str) + finished = qt_core.Signal() + + def run(self) -> None: + handler = _SignalLogHandler(self.message) + handler.setFormatter(logging.Formatter("%(message)s")) + with _ForwardEegprepLogs(handler): + self.message.emit(label) + try: + result = task() + self.succeeded.emit(result) + except Exception as exc: # noqa: BLE001 - forwarded to GUI error handler. + self.failed.emit(exc) + finally: + self.finished.emit() + + thread = qt_core.QThread() + worker = Worker() + + class Receiver(qt_core.QObject): + @qt_core.Slot(str) + def handle_message(self, message: str) -> None: + _update_progress_label(progress, label, message) + + @qt_core.Slot(object) + def handle_success(self, result: Any) -> None: + on_success(result) + + @qt_core.Slot(object) + def handle_error(self, exc: Exception) -> None: + on_error(exc) + + @qt_core.Slot() + def handle_finished(self) -> None: + progress.close() + if on_finished is not None: + on_finished(handle) + + receiver = Receiver() + handle = LongTaskHandle(thread=thread, worker=worker, dialog=progress, receiver=receiver) + + worker.moveToThread(thread) + thread.started.connect(worker.run) + worker.message.connect(receiver.handle_message) + worker.succeeded.connect(receiver.handle_success) + worker.failed.connect(receiver.handle_error) + worker.finished.connect(thread.quit) + worker.finished.connect(worker.deleteLater) + thread.finished.connect(thread.deleteLater) + thread.finished.connect(receiver.handle_finished) + thread.finished.connect(receiver.deleteLater) + + progress._eegprep_long_task = handle + progress.show() + qt_core.QTimer.singleShot(0, thread.start) + return handle + + +class _SignalLogHandler(logging.Handler): + def __init__(self, signal: Any): + super().__init__(level=logging.INFO) + self.signal = signal + + def emit(self, record: logging.LogRecord) -> None: + try: + self.signal.emit(self.format(record)) + except Exception: + self.handleError(record) + + +class _ForwardEegprepLogs: + def __init__(self, handler: logging.Handler) -> None: + self.handler = handler + self.logger = logging.getLogger("eegprep") + + def __enter__(self) -> None: + global _LOGGER_DEPTH, _LOGGER_OLD_LEVEL + with _LOGGER_LOCK: + if _LOGGER_DEPTH == 0: + _LOGGER_OLD_LEVEL = self.logger.level + if self.logger.level == logging.NOTSET or self.logger.level > logging.INFO: + self.logger.setLevel(logging.INFO) + _LOGGER_DEPTH += 1 + self.logger.addHandler(self.handler) + + def __exit__(self, _exc_type: Any, _exc: Any, _tb: Any) -> None: + global _LOGGER_DEPTH, _LOGGER_OLD_LEVEL + with _LOGGER_LOCK: + self.logger.removeHandler(self.handler) + _LOGGER_DEPTH -= 1 + if _LOGGER_DEPTH == 0: + self.logger.setLevel(logging.NOTSET if _LOGGER_OLD_LEVEL is None else _LOGGER_OLD_LEVEL) + _LOGGER_OLD_LEVEL = None + + +def _update_progress_label(progress: Any, label: str, message: str) -> None: + message = str(message).strip() + progress.setLabelText(label if not message or message == label else f"{label}\n{message}") + + +def _require_qt() -> tuple[Any, Any]: + if QtCore is None or QtWidgets is None: + raise RuntimeError( + "PySide6 is required for EEGPrep GUI progress dialogs. Install it with " + "`pip install -e .[gui]` or `pip install eegprep[gui]`." + ) + return QtCore, QtWidgets diff --git a/src/eegprep/functions/guifunc/menu_actions.py b/src/eegprep/functions/guifunc/menu_actions.py index a5254755..71f3cc84 100644 --- a/src/eegprep/functions/guifunc/menu_actions.py +++ b/src/eegprep/functions/guifunc/menu_actions.py @@ -5,11 +5,12 @@ import inspect import logging import webbrowser -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Callable, Mapping from pathlib import Path from typing import Any from eegprep.extension_runtime import ExtensionRuntime +from eegprep.functions.guifunc.long_task import LongTaskHandle, run_long_task from eegprep.functions.guifunc.menu_placeholders import PLACEHOLDER_ACTIONS, is_placeholder_action, placeholder_message from eegprep.functions.guifunc.pophelp import pophelp from eegprep.functions.guifunc.session import EEGPrepSession, has_eeg_data @@ -23,7 +24,6 @@ # Action handlers import user-facing pop/plugin modules lazily so launching the # main GUI does not eagerly load heavier signal-processing and optional stacks. _EXTENSION_FILE_PARAMETERS = ("filename", "filepath", "path") -_EEG_CORE_FIELDS = ("nbchan", "srate", "pnts", "trials") IMPLEMENTED_ACTIONS = { "clear_study", @@ -215,6 +215,36 @@ "pop_subcomp", } +# Actions dispatched straight to _run_pop_function by name, with no variant; that +# method owns all per-action wiring. Variant-aware actions (pop_eegplot, the +# rejection group) are dispatched separately below. +_SIMPLE_POP_ACTIONS = { + "pop_adjustevents", + "pop_comments", + "pop_editset", + "pop_editeventfield", + "pop_editeventvals", + "pop_chanedit", + "pop_clean_rawdata", + "pop_eegfilt", + "pop_eegfiltnew", + "pop_epoch", + "pop_firma", + "pop_firpm", + "pop_firws", + "pop_reref", + "pop_interp", + "pop_resample", + "pop_rmbase", + "pop_rmdat", + "pop_runica", + "pop_select", + "pop_selectevent", + "pop_iclabel", + "pop_icflag", + "pop_subcomp", +} + class MenuActionDispatcher: """Dispatch menu action identifiers to real functions or placeholders.""" @@ -231,6 +261,7 @@ def __init__( self.refresh = refresh self.native_file_dialogs = native_file_dialogs self.extension_runtime = extension_runtime or ExtensionRuntime.empty() + self._long_tasks: list[LongTaskHandle] = [] def dispatch_gui(self, action: str, parent: Any | None = None) -> None: """Run a menu action from Qt and show user-facing errors.""" @@ -351,77 +382,8 @@ def dispatch(self, action: str, parent: Any | None = None) -> None: if base in {"pop_taskinfo", "pop_participantinfo", "pop_eventinfo", "validate_bids"}: self._bids_tool_action(base, parent) return - if base == "pop_adjustevents": - self._run_pop_function("pop_adjustevents", parent) - return - if base == "pop_comments": - self._run_pop_function("pop_comments", parent) - return - if base == "pop_editset": - self._run_pop_function("pop_editset", parent) - return - if base == "pop_editeventfield": - self._run_pop_function("pop_editeventfield", parent) - return - if base == "pop_editeventvals": - self._run_pop_function("pop_editeventvals", parent) - return - if base == "pop_chanedit": - self._run_pop_function("pop_chanedit", parent) - return - if base == "pop_clean_rawdata": - self._run_pop_function("pop_clean_rawdata", parent) - return - if base == "pop_eegfilt": - self._run_pop_function("pop_eegfilt", parent) - return - if base == "pop_eegfiltnew": - self._run_pop_function("pop_eegfiltnew", parent) - return - if base == "pop_epoch": - self._run_pop_function("pop_epoch", parent) - return - if base == "pop_firma": - self._run_pop_function("pop_firma", parent) - return - if base == "pop_firpm": - self._run_pop_function("pop_firpm", parent) - return - if base == "pop_firws": - self._run_pop_function("pop_firws", parent) - return - if base == "pop_reref": - self._run_pop_function("pop_reref", parent) - return - if base == "pop_interp": - self._run_pop_function("pop_interp", parent) - return - if base == "pop_resample": - self._run_pop_function("pop_resample", parent) - return - if base == "pop_rmbase": - self._run_pop_function("pop_rmbase", parent) - return - if base == "pop_rmdat": - self._run_pop_function("pop_rmdat", parent) - return - if base == "pop_runica": - self._run_pop_function("pop_runica", parent) - return - if base == "pop_select": - self._run_pop_function("pop_select", parent) - return - if base == "pop_selectevent": - self._run_pop_function("pop_selectevent", parent) - return - if base == "pop_iclabel": - self._run_pop_function("pop_iclabel", parent) - return - if base == "pop_icflag": - self._run_pop_function("pop_icflag", parent) - return - if base == "pop_subcomp": - self._run_pop_function("pop_subcomp", parent) + if base in _SIMPLE_POP_ACTIONS: + self._run_pop_function(base, parent) return if base == "pop_eegplot": self._run_pop_function("pop_eegplot", parent, variant=variant) @@ -903,14 +865,26 @@ def _run_script(self, parent: Any | None) -> None: "EEG": self.session.EEG, "ALLEEG": self.session.ALLEEG, "CURRENTSET": self.session.current_set_value(), + "ALLCOM": list(self.session.ALLCOM), + "LASTCOM": self.session.LASTCOM, "STUDY": self.session.STUDY, + "CURRENTSTUDY": self.session.CURRENTSTUDY, } command = pop_runscript(filename, namespace) - self.session.EEG = namespace.get("EEG", self.session.EEG) - self.session.ALLEEG = namespace.get("ALLEEG", self.session.ALLEEG) - self.session.CURRENTSET = _currentset_list(namespace.get("CURRENTSET", self.session.current_set_value())) - self.session.STUDY = namespace.get("STUDY", self.session.STUDY) - self._add_history_from_gui(command) + self.session.echo_command(command) + state = { + "alleeg": namespace.get("ALLEEG", self.session.ALLEEG), + "currentset": namespace.get("CURRENTSET", self.session.current_set_value()), + "allcom": namespace.get("ALLCOM", self.session.ALLCOM), + "lastcom": namespace.get("LASTCOM", self.session.LASTCOM), + "study": namespace.get("STUDY", self.session.STUDY), + "currentstudy": namespace.get("CURRENTSTUDY", self.session.CURRENTSTUDY), + "command": command, + } + script_eeg = namespace.get("EEG", self.session.EEG) + if script_eeg is not self.session.EEG: + state["eeg"] = script_eeg + self.session.apply_workspace_state(**state) self._refresh() def _bids_tool_action(self, action: str, parent: Any | None) -> None: @@ -942,8 +916,8 @@ def _bids_tool_action(self, action: str, parent: Any | None) -> None: updated, command = getattr(bids_tools, action)(target, **metadata) if self.session.CURRENTSTUDY == 1 and self.session.STUDY: - self.session.STUDY = updated - self._add_history_from_gui(command) + self.session.echo_command(command) + self.session.set_study(updated, command=command) else: self._store_current_from_gui(updated, command=command) self._refresh() @@ -1042,17 +1016,18 @@ def _ask_extension_filename( return filename def _apply_extension_result(self, result: Any) -> None: - dataset_state = _extension_dataset_state(result) + # Imported here to break the menu_actions <-> console import cycle; the + # console owns the canonical pop-result interpretation contract. + from eegprep.functions.adminfunc.console import _extract_pop_dataset_state, _extract_pop_eeg_and_command + + dataset_state = _extract_pop_dataset_state(result) if dataset_state is not None: alleeg, eeg_out, currentset, command = dataset_state - self.session.ALLEEG = alleeg - self.session.EEG = eeg_out - self.session.CURRENTSET = _currentset_list(currentset) - self._add_history_from_gui(command) - self.session.notify_changed() + self.session.echo_command(command) + self.session.apply_workspace_state(alleeg=alleeg, eeg=eeg_out, currentset=currentset, command=command) self._refresh() return - eeg_out, command = _extension_eeg_and_command(result) + eeg_out, command = _extract_pop_eeg_and_command(result) if eeg_out is not None: self._store_current_from_gui(eeg_out, command=command) self._refresh() @@ -1147,8 +1122,14 @@ def _run_pop_function(self, name: str, parent: Any | None, *, variant: str = "") out = pop_rmdat(selection, return_com=True) elif name == "pop_runica": - from eegprep.functions.popfunc.pop_runica import pop_runica + from eegprep.functions.popfunc.pop_runica import pop_runica, pop_runica_gui_options + if parent is not None: + gui_options = pop_runica_gui_options(selection) + if gui_options is None: + return + self._run_pop_runica_long_task(selection, gui_options, parent) + return out = pop_runica(selection, return_com=True) elif name == "pop_select": from eegprep.functions.popfunc.pop_select import pop_select @@ -1291,6 +1272,59 @@ def commit_component_rejection(eeg_out: Any, _states: dict[int, bool]) -> None: self._store_current_from_gui(eeg_out, command=command) self._refresh() + def _run_pop_runica_long_task( + self, + selection: Any, + gui_options: Mapping[str, Any], + parent: Any, + ) -> None: + from eegprep.functions.popfunc.pop_runica import pop_runica + + self.session.begin_gui_action("pop_runica") + + def task() -> Any: + return pop_runica(selection, gui=False, return_com=True, **dict(gui_options)) + + def on_success(out: Any) -> None: + try: + if isinstance(out, tuple): + eeg_out, command = out[0], out[1] if len(out) > 1 else "" + else: + eeg_out, command = out, "" + if command: + self._store_current_from_gui(eeg_out, command=command) + self._refresh() + except Exception as exc: + logger.exception("EEGPrep GUI menu action failed: pop_runica") + self._warn(parent, str(exc)) + + def on_error(exc: Exception) -> None: + logger.error( + "EEGPrep GUI menu action failed: pop_runica", + exc_info=(type(exc), exc, exc.__traceback__), + ) + self._warn(parent, str(exc)) + + def on_finished(handle: LongTaskHandle) -> None: + if handle in self._long_tasks: + self._long_tasks.remove(handle) + self.session.end_gui_action("pop_runica") + + try: + handle = run_long_task( + parent=parent, + title="Running ICA decomposition", + label="Running ICA decomposition. This may take several minutes.", + task=task, + on_success=on_success, + on_error=on_error, + on_finished=on_finished, + ) + except Exception: + self.session.end_gui_action("pop_runica") + raise + self._long_tasks.append(handle) + def _run_browser_accept_pop_action( self, name: str, @@ -1360,11 +1394,8 @@ def _copy_current_dataset(self, parent: Any | None) -> None: alleeg, eeg_out, current_set, command = pop_copyset(self.session.ALLEEG, set_in, gui=True, return_com=True) if not command: return - self.session.ALLEEG = alleeg - self.session.EEG = eeg_out - self.session.CURRENTSET = _currentset_list(current_set) - self._add_history_from_gui(command) - self.session.notify_changed() + self.session.echo_command(command) + self.session.apply_workspace_state(alleeg=alleeg, eeg=eeg_out, currentset=current_set, command=command) self._refresh() def _merge_datasets(self, parent: Any | None) -> None: @@ -1417,34 +1448,36 @@ def _run_dipfit_function(self, name: str, parent: Any | None) -> None: if name == "pop_dipfit_headmodel": from eegprep.plugins.dipfit.pop_dipfit_headmodel import pop_dipfit_headmodel - pop_dipfit_headmodel(selection, return_com=True) - return - if name == "pop_dipfit_gridsearch": + out = pop_dipfit_headmodel(selection, return_com=True) + elif name == "pop_dipfit_gridsearch": from eegprep.plugins.dipfit.pop_dipfit_gridsearch import pop_dipfit_gridsearch - pop_dipfit_gridsearch(selection, return_com=True) - return - if name == "pop_dipfit_nonlinear": + out = pop_dipfit_gridsearch(selection, return_com=True) + elif name == "pop_dipfit_nonlinear": from eegprep.plugins.dipfit.pop_dipfit_nonlinear import pop_dipfit_nonlinear - pop_dipfit_nonlinear(selection, return_com=True) - return - if name == "pop_multifit": + out = pop_dipfit_nonlinear(selection, return_com=True) + elif name == "pop_multifit": from eegprep.plugins.dipfit.pop_multifit import pop_multifit - pop_multifit(selection, return_com=True) - return - if name == "pop_leadfield": + out = pop_multifit(selection, return_com=True) + elif name == "pop_leadfield": from eegprep.plugins.dipfit.pop_leadfield import pop_leadfield - pop_leadfield(selection, return_com=True) - return - if name == "pop_dipfit_loreta": + out = pop_leadfield(selection, return_com=True) + elif name == "pop_dipfit_loreta": from eegprep.plugins.dipfit.pop_dipfit_loreta import pop_dipfit_loreta - pop_dipfit_loreta(selection, return_com=True) + out = pop_dipfit_loreta(selection, return_com=True) + else: + self.show_coming_soon(name, parent) return - self.show_coming_soon(name, parent) + if not isinstance(out, tuple): + return + eeg_out, command = out[0], out[1] if len(out) > 1 else "" + if command: + self._store_current_from_gui(eeg_out, command=command) + self._refresh() def _plot_channel_locations(self, variant: str, parent: Any | None) -> None: selection = self._current_selection_or_warn(parent) @@ -1528,7 +1561,12 @@ def _run_plot_function(self, name: str, variant: str, parent: Any | None) -> Non else: self.show_coming_soon(name, parent) return - self._add_history_from_gui(command) + if name == "pop_headplot": + # pop_headplot attaches the spline file in place, so commit the edited + # dataset through the session instead of only recording history. + self._store_current_from_gui(selection, command=command) + else: + self._add_history_from_gui(command) self._refresh() def _run_chanplot(self, parent: Any | None) -> None: @@ -1540,8 +1578,8 @@ def _run_chanplot(self, parent: Any | None) -> None: study, command, _figure = pop_chanplot(self.session.STUDY, self.session.ALLEEG, gui=True, return_com=True) if not command: return - self.session.STUDY = study - self._add_history_from_gui(command) + self.session.echo_command(command) + self.session.set_study(study, command=command) self._refresh() def _store_current_from_gui(self, eeg: Any, **kwargs: Any) -> Any: @@ -1558,9 +1596,6 @@ def _commit_processed_dataset_from_gui(self, eeg: Any, *, command: str, parent: if isinstance(eeg, list): for offset, _dataset in enumerate(eeg, start=1): logger.info("Processing group dataset %s of %s.", offset, len(eeg)) - for dataset in eeg if isinstance(eeg, list) else [eeg]: - if isinstance(dataset, dict): - eegh(command, dataset) alleeg, current, current_set, newset_command = pop_newset( self.session.ALLEEG, eeg, @@ -1574,14 +1609,13 @@ def _commit_processed_dataset_from_gui(self, eeg: Any, *, command: str, parent: if old_selection: self.session.retrieve(old_selection if len(old_selection) > 1 else old_selection[0]) return + for dataset in current if isinstance(current, list) else [current]: + if isinstance(dataset, dict): + eegh(command, dataset) self.session.echo_command(command) self.session.add_history(command, notify=False) self.session.echo_command(newset_command) - self.session.ALLEEG = alleeg - self.session.EEG = current - self.session.CURRENTSET = _currentset_list(current_set) - self.session.add_history(newset_command, notify=False) - self.session.notify_changed() + self.session.apply_workspace_state(alleeg=alleeg, eeg=current, currentset=current_set, command=newset_command) def _add_history_from_gui(self, command: str | None) -> None: self.session.echo_command(command) @@ -1592,7 +1626,7 @@ def _retrieve_dataset(self, index: int) -> None: self.session.retrieve(index) command = f"[ALLEEG EEG CURRENTSET] = pop_newset(ALLEEG, EEG, CURRENTSET, 'retrieve', {index});" if was_study: - self.session.CURRENTSTUDY = 0 + self.session.apply_workspace_state(currentstudy=0) command = f"CURRENTSTUDY = 0;{command}" self._add_history_from_gui(command) self._refresh() @@ -1713,59 +1747,6 @@ def _icacomp_from_variant(variant: str) -> int: return 0 if variant == "ica" else 1 -def _currentset_list(value: Any) -> list[int]: - if value is None: - return [] - if isinstance(value, str): - if value == "": - return [] - current = int(value) - return [current] if current > 0 else [] - if isinstance(value, bool): - return [1] if value else [] - if isinstance(value, (int, float)): - current = int(value) - return [current] if current > 0 else [] - if isinstance(value, Iterable) and not isinstance(value, (str, bytes)): - return [int(item) for item in value if int(item) > 0] - current = int(value) - return [current] if current > 0 else [] - - -def _extension_dataset_state(result: Any) -> tuple[list[dict[str, Any]], Any, Any, str] | None: - if not isinstance(result, tuple) or len(result) < 4: - return None - alleeg, eeg, currentset, command = result[0], result[1], result[2], result[3] - if not isinstance(alleeg, list) or not _is_eeg_selection(eeg) or not isinstance(command, str): - return None - return alleeg, eeg, currentset, command.strip() - - -def _extension_eeg_and_command(result: Any) -> tuple[Any | None, str]: - if isinstance(result, tuple): - command = str(result[1]).strip() if len(result) > 1 and isinstance(result[1], str) else "" - if _history_only_command(command): - return None, command - if result and _is_eeg_selection(result[0]): - return result[0], command - return None, command - if _is_eeg_selection(result): - return result, "" - if isinstance(result, str): - return None, result.strip() - return None, "" - - -def _is_eeg_selection(value: Any) -> bool: - if isinstance(value, dict): - return "data" in value and any(key in value for key in _EEG_CORE_FIELDS) - return isinstance(value, list) and bool(value) and all(_is_eeg_selection(item) for item in value) - - -def _history_only_command(command: str) -> bool: - return command.lstrip().startswith("LASTCOM") - - def _extension_file_parameter(parameters: Mapping[str, inspect.Parameter]) -> str | None: for name in _EXTENSION_FILE_PARAMETERS: parameter = parameters.get(name) diff --git a/src/eegprep/functions/guifunc/qt.py b/src/eegprep/functions/guifunc/qt.py index 3dc2ac8b..cecf1d49 100644 --- a/src/eegprep/functions/guifunc/qt.py +++ b/src/eegprep/functions/guifunc/qt.py @@ -470,6 +470,11 @@ def _connect_callback(self, callback: CallbackSpec | None, widgets: dict[str, An source = widgets.get(params["button"]) if source is not None: source.clicked.connect(lambda: self._show_callback_message(source, params)) + elif callback.name == "edit_text": + source = widgets.get(params["button"]) + target = widgets.get(params.get("target", params["button"])) + if source is not None and target is not None: + source.clicked.connect(lambda: self._edit_text(source, target, params)) elif callback.name == "open_eegplot": source = widgets.get(params["button"]) if source is not None: @@ -1114,6 +1119,20 @@ def _show_callback_message(parent: Any, params: Mapping[str, Any]) -> None: _qt_core, qt_widgets = _require_qt() qt_widgets.QMessageBox.information(parent, str(params.get("title", "EEGPrep")), str(params.get("message", ""))) + @staticmethod + def _edit_text(parent: Any, target: Any, params: Mapping[str, Any]) -> None: + _qt_core, qt_widgets = _require_qt() + stored_value = target.property(_VALUE_PROPERTY) + current = stored_value if stored_value is not None else params.get("value", "") + value, accepted = qt_widgets.QInputDialog.getMultiLineText( + parent, + str(params.get("title", "Edit text")), + str(params.get("label", "Text")), + str(current), + ) + if accepted: + target.setProperty(_VALUE_PROPERTY, str(value)) + @staticmethod def _select_interp_channels(button: Any, target: Any, params: Mapping[str, Any]) -> None: source = str(params.get("source", "")).lower() diff --git a/src/eegprep/functions/guifunc/session.py b/src/eegprep/functions/guifunc/session.py index f2028ceb..951081ab 100644 --- a/src/eegprep/functions/guifunc/session.py +++ b/src/eegprep/functions/guifunc/session.py @@ -18,6 +18,9 @@ from eegprep.functions.popfunc.eeg_emptyset import eeg_emptyset +_UNSET = object() + + def has_eeg_data(eeg: Any) -> bool: """Return whether an EEG-like object contains non-empty data.""" if not isinstance(eeg, dict): @@ -223,6 +226,65 @@ def retrieve(self, indices: int | list[int]) -> dict[str, Any] | list[dict[str, self.notify_changed() return eeg + def apply_workspace_state( + self, + *, + eeg: Any = _UNSET, + alleeg: Any = _UNSET, + currentset: Any = _UNSET, + allcom: Any = _UNSET, + lastcom: Any = _UNSET, + study: Any = _UNSET, + currentstudy: Any = _UNSET, + command: str = "", + append_dataset_history: bool = False, + ) -> None: + """Apply a GUI/console workspace update as one session transaction.""" + dataset_changed = eeg is not _UNSET or alleeg is not _UNSET or currentset is not _UNSET + if dataset_changed: + resolved_alleeg = self.ALLEEG if alleeg is _UNSET else alleeg + if not isinstance(resolved_alleeg, list): + raise ValueError("ALLEEG must be a list of EEG datasets") + resolved_currentset = ( + list(self.CURRENTSET) + if currentset is _UNSET + else normalize_dataset_indices(currentset, allow_empty=True) + ) + if resolved_currentset and max(resolved_currentset) > len(resolved_alleeg): + raise ValueError("CURRENTSET contains indices outside ALLEEG") + resolved_eeg = self._resolve_workspace_eeg(eeg, resolved_alleeg, resolved_currentset) + current = resolved_eeg if isinstance(resolved_eeg, list) else [resolved_eeg] + if resolved_currentset and len(current) != len(resolved_currentset): + raise ValueError("EEG selection length must match CURRENTSET") + self.ALLEEG = resolved_alleeg + self.EEG = resolved_eeg + self.CURRENTSET = resolved_currentset + self._mirror_current_eeg_into_alleeg() + if append_dataset_history: + self._append_current_dataset_history(command) + offload_storedisk_datasets(self.ALLEEG, set(self.CURRENTSET)) + + if allcom is not _UNSET: + if not isinstance(allcom, list): + raise ValueError("ALLCOM must be a list of command strings") + self.ALLCOM = [str(item) for item in allcom if str(item).strip()] + self.LASTCOM = self.ALLCOM[-1] if self.ALLCOM else "" + if lastcom is not _UNSET: + last_command = str(lastcom or "").strip() + if last_command and (not self.ALLCOM or self.ALLCOM[-1] != last_command): + self.ALLCOM.append(last_command) + self.LASTCOM = last_command + + if study is not _UNSET: + self.STUDY = study + if currentstudy is _UNSET: + self.CURRENTSTUDY = 1 if study else 0 + if currentstudy is not _UNSET: + self.CURRENTSTUDY = int(currentstudy or 0) + + self.add_history(command, notify=False) + self.notify_changed() + def delete_current(self) -> None: """Delete the current dataset selection from memory.""" if not self.CURRENTSET: @@ -271,6 +333,29 @@ def set_study( self.add_history(command, notify=False) self.notify_changed() + def _resolve_workspace_eeg( + self, + eeg: Any, + alleeg: list[dict[str, Any]], + currentset: list[int], + ) -> dict[str, Any] | list[dict[str, Any]]: + if eeg is not _UNSET: + return eeg + if not currentset: + return eeg_emptyset() + selected = [alleeg[index - 1] for index in currentset] + return selected if len(selected) > 1 else selected[0] + + def _mirror_current_eeg_into_alleeg(self) -> None: + if not self.CURRENTSET: + return + current = self.EEG if isinstance(self.EEG, list) else [self.EEG] + if len(current) != len(self.CURRENTSET): + raise ValueError("EEG selection length must match CURRENTSET") + for index, eeg in zip(self.CURRENTSET, current): + if 1 <= index <= len(self.ALLEEG): + self.ALLEEG[index - 1] = eeg + def select_study(self, *, command: str = "CURRENTSTUDY = 1") -> None: """Select the current STUDY set in the shared workspace.""" if not self.STUDY: diff --git a/src/eegprep/functions/miscfunc/eeg_eeg2mne.py b/src/eegprep/functions/miscfunc/eeg_eeg2mne.py index 574353c2..0d026766 100644 --- a/src/eegprep/functions/miscfunc/eeg_eeg2mne.py +++ b/src/eegprep/functions/miscfunc/eeg_eeg2mne.py @@ -1,13 +1,13 @@ """EEG to MNE conversion functions.""" -from ..popfunc.pop_loadset import pop_loadset -import mne +from pathlib import Path import tempfile -import os + +import mne + from ..popfunc.pop_saveset import pop_saveset # in development -# write a funtion that converts a MNE raw object to an EEGLAB set file def eeg_eeg2mne(EEG): """Convert EEG data structure to MNE Raw object. @@ -21,34 +21,10 @@ def eeg_eeg2mne(EEG): raw : mne.io.Raw MNE Raw object """ - # Generate a temporary file name - with tempfile.NamedTemporaryFile(delete=False) as temp_file: - temp_file_path = temp_file.name - - base, _ = os.path.splitext(temp_file_path) - new_temp_file_path = base + ".set" - - # save the raw file as a new EEGLAB .set file using MNE EEGLAB writer - pop_saveset(EEG, new_temp_file_path) - - # load the EEGLAB set file - if EEG['trials'] > 1: - raw = mne.io.read_epochs_eeglab(new_temp_file_path) - else: - raw = mne.io.read_raw_eeglab(new_temp_file_path, preload=True) - - return raw - - -def test_eeg_eeg2mne(): - """Test the eeg_eeg2mne function.""" - eeglab_file_path = './eeglab_data_with_ica_tmp.set' - eeglab_file_path = '/System/Volumes/Data/data/matlab/eeglab/sample_data/eeglab_data_epochs_ica.set' - EEG = pop_loadset(eeglab_file_path) - raw = eeg_eeg2mne(EEG) - - # print the keys of the EEG dictionary - print(raw.info) - + with tempfile.TemporaryDirectory(prefix="eegprep-eeg2mne-") as temp_dir: + set_path = Path(temp_dir) / "bridge.set" + pop_saveset(EEG, str(set_path)) -# test_eeg_eeg2mne() + if EEG['trials'] > 1: + return mne.io.read_epochs_eeglab(str(set_path)) + return mne.io.read_raw_eeglab(str(set_path), preload=True) diff --git a/src/eegprep/functions/miscfunc/eeg_mne2eeg.py b/src/eegprep/functions/miscfunc/eeg_mne2eeg.py index 0e14c253..6f0eb1a2 100644 --- a/src/eegprep/functions/miscfunc/eeg_mne2eeg.py +++ b/src/eegprep/functions/miscfunc/eeg_mne2eeg.py @@ -1,11 +1,12 @@ """MNE to EEG conversion functions.""" -from ..popfunc.pop_loadset import pop_loadset -import mne +from pathlib import Path import tempfile -import os + +import mne from mne.export import export_raw, export_epochs -import numpy as np + +from ..popfunc.pop_loadset import pop_loadset def _mne_events_to_eeglab_events(raw_or_epochs): @@ -38,7 +39,6 @@ def _mne_events_to_eeglab_events(raw_or_epochs): return events -# write a funtion that converts a MNE raw object to an EEGLAB set file def eeg_mne2eeg(raw): """Convert MNE Raw object to EEG data structure. @@ -54,49 +54,16 @@ def eeg_mne2eeg(raw): """ raw_or_epochs = raw - # Generate a temporary file name - with tempfile.NamedTemporaryFile(delete=False) as temp_file: - temp_file_path = temp_file.name - - base, _ = os.path.splitext(temp_file_path) - new_temp_file_path = base + ".set" - - # save the raw/epochs file as a new EEGLAB .set file using MNE EEGLAB writer - if isinstance(raw_or_epochs, mne.BaseEpochs): - export_epochs(new_temp_file_path, raw_or_epochs, fmt='eeglab') - else: - export_raw(new_temp_file_path, raw_or_epochs, fmt='eeglab') - - # load the EEGLAB set file - EEG = pop_loadset(new_temp_file_path) + with tempfile.TemporaryDirectory(prefix="eegprep-mne2eeg-") as temp_dir: + set_path = Path(temp_dir) / "bridge.set" + if isinstance(raw_or_epochs, mne.BaseEpochs): + export_epochs(str(set_path), raw_or_epochs, fmt='eeglab') + else: + export_raw(str(set_path), raw_or_epochs, fmt='eeglab') + EEG = pop_loadset(str(set_path)) # Inject events/annotations from MNE object into EEGLAB structure eeglab_events = _mne_events_to_eeglab_events(raw_or_epochs) if eeglab_events: EEG['event'] = eeglab_events - return EEG - - -def test_eeg_mne2eeg(): - """Test the eeg_mne2eeg function.""" - eeglab_file_path = './eeglab_data_with_ica_tmp.set' - eeglab_file_path = '/System/Volumes/Data/data/matlab/eeglab/sample_data/eeglab_data_epochs_ica.set' - EEG = pop_loadset(eeglab_file_path) - - # create MNE info structure - info = mne.create_info(ch_names=[x['labels'] for x in EEG['chanlocs']], sfreq=EEG['srate'], ch_types='eeg') - if EEG['trials'] > 1: - events = np.array([[i, 0, 1] for i in range(EEG['trials'])]) # NOT CORRECT CONVERTION JUST FOR TESTING - event_id = dict(dummy=1) - raw = mne.EpochsArray(EEG['data'].transpose(2, 0, 1), info, events, tmin=0, event_id=event_id) - else: - raw = mne.io.RawArray(EEG['data'], info) - - EEG2 = eeg_mne2eeg(raw) - - # print the keys of the EEG dictionary - print(EEG2.keys()) - - -# test_eeg_mne2eeg() diff --git a/src/eegprep/functions/miscfunc/eeg_mne2eeg_epochs.py b/src/eegprep/functions/miscfunc/eeg_mne2eeg_epochs.py index 876dda17..808f8481 100644 --- a/src/eegprep/functions/miscfunc/eeg_mne2eeg_epochs.py +++ b/src/eegprep/functions/miscfunc/eeg_mne2eeg_epochs.py @@ -1,17 +1,15 @@ """MNE epochs to EEGLAB dataset conversion utilities.""" -# Example to export MNE epochs to EEGLAB dataset -# Events are not handled correctly in this example but it works - -import mne -from mne.preprocessing import ICA +import logging import math import numpy as np -from scipy.io import savemat + +from eegprep.functions.miscfunc.misc import finite_matmul, finite_pinv + +logger = logging.getLogger(__name__) -# Load example data def eeg_mne2eeg_epochs(epochs, ica): """Convert MNE epochs with ICA to EEGLAB dataset format. @@ -27,29 +25,27 @@ def eeg_mne2eeg_epochs(epochs, ica): dict EEGLAB-compatible dataset dictionary. """ - # export to EEGLAB dataset - data = epochs.get_data() # Get the data from the epochs - n_epochs, n_channels, n_times = data.shape - ica_weights = ica.get_components() # ICA weights (n_components x n_channels) - - # create identity matrix of size n_channels x n_channels - ica_sphere = np.eye(n_channels) # ICA sphere (n_channels x n_channels) - - # Compute the mixing matrix (inverse weights) - ica_inverse_weights = np.linalg.pinv(ica_weights) # Shape: (n_channels, n_components) + mne_data = epochs.get_data(copy=True) + n_epochs, n_channels, n_times = mne_data.shape + data = np.transpose(mne_data, (1, 2, 0)) ica_channels = ica.info['ch_names'] raw_channels = epochs.info['ch_names'] # Assuming you have the raw object ica_channel_indices = [raw_channels.index(ch) for ch in ica_channels] ica_channel_indices = np.array(ica_channel_indices) - ica_act = ica.get_sources(epochs).get_data(copy=True).transpose(1, 2, 0) # Get the ICA activations + ica_weights, ica_sphere, ica_inverse_weights, ica_act = _mne_ica_to_eeglab_fields( + ica, + data[ica_channel_indices], + n_times, + n_epochs, + ) - print('Reference conversion may not be accurate...') if 'custom_ref_applied' in epochs.info and epochs.info['custom_ref_applied']: ref = 'common' # Custom reference was applied else: ref = 'average' # Default to average reference + logger.info("MNE reference metadata converted to EEGPrep ref=%s.", ref) eeglab_dict = { 'setname': '', @@ -70,8 +66,8 @@ def eeg_mne2eeg_epochs(epochs, ica): 'data': data, 'icaact': ica_act, 'icawinv': ica_inverse_weights, - 'icasphere': ica_weights, - 'icaweights': ica_sphere, + 'icasphere': ica_sphere, + 'icaweights': ica_weights, 'icachansind': ica_channel_indices, 'chanlocs': np.array([]), 'urchanlocs': np.array([]), @@ -111,21 +107,27 @@ def eeg_mne2eeg_epochs(epochs, ica): Y_all = [] Z_all = [] for ch in ch_locs: - if 'loc' in ch and ch['loc'] is not None: - X_all.append(ch['loc'][1] * 1000) - Y_all.append(-ch['loc'][0] * 1000) - Z_all.append(ch['loc'][2] * 1000) - hypotxy = math.hypot(X_all[-1], Y_all[-1]) - sph_radius_all.append(math.hypot(hypotxy, Z_all[-1])) - - az = math.atan2(Y_all[-1], X_all[-1]) / math.pi * 180 - horiz = math.atan2(Z_all[-1], hypotxy) / math.pi * 180 - - sph_theta_all.append(az) - sph_phi_all.append(horiz) - - theta_all.append(-az) # warning inverse notation compared to MATLAB to match - radius_all.append(0.5 - horiz / 180) # warning inverse notation compared to MATLAB to match + loc = ch.get('loc') if isinstance(ch, dict) else None + if loc is None or len(loc) < 3: + x = y = z = 0.0 + else: + x = float(loc[1]) * 1000 + y = -float(loc[0]) * 1000 + z = float(loc[2]) * 1000 + X_all.append(x) + Y_all.append(y) + Z_all.append(z) + hypotxy = math.hypot(x, y) + sph_radius_all.append(math.hypot(hypotxy, z)) + + az = math.atan2(y, x) / math.pi * 180 + horiz = math.atan2(z, hypotxy) / math.pi * 180 + + sph_theta_all.append(az) + sph_phi_all.append(horiz) + + theta_all.append(-az) # warning inverse notation compared to MATLAB to match + radius_all.append(0.5 - horiz / 180) # warning inverse notation compared to MATLAB to match d_list = [ { @@ -166,43 +168,31 @@ def eeg_mne2eeg_epochs(epochs, ica): d_list = np.array(d_list) eeglab_dict['chanlocs'] = d_list - # # Step 4: Save the EEGLAB dataset as a .mat file return eeglab_dict - # print("EEGLAB dataset saved successfully!") - - -def test_eeg_mne2eeg_epochs(): - """Test the eeg_mne2eeg_epochs function with sample MNE data.""" - sample_data_folder = mne.datasets.sample.data_path() - sample_data_raw_file = sample_data_folder / "MEG" / "sample" / "sample_audvis_filt-0-40_raw.fif" - - raw = mne.io.read_raw_fif(sample_data_raw_file) - - # extract data epochs - events = mne.find_events(raw, stim_channel="STI 014") - event_dict = { - "auditory/left": 1, - "auditory/right": 2, - "visual/left": 3, - "visual/right": 4, - "smiley": 5, - "buttonpress": 32, - } - epochs = mne.Epochs( - raw, - events, - event_id=event_dict, - tmin=-0.2, - tmax=0.5, - preload=True, - ) - - ica = ICA(n_components=15, random_state=97, max_iter=800) - ica.fit(raw) - - EEG = eeg_mne2eeg_epochs(epochs, ica) - savemat('output_file.mat', EEG) # use pop_saveset - -# test_eeg_mne2eeg_epochs() +def _mne_ica_to_eeglab_fields(ica, data, n_times, n_epochs): + n_components = int(ica.n_components_) + n_ica_channels = data.shape[0] + prewhitener = _prewhitener_matrix(ica, n_ica_channels) + pca_unmixing = finite_matmul(np.asarray(ica.unmixing_matrix_), np.asarray(ica.pca_components_)[:n_components]) + unmixing = finite_matmul(pca_unmixing, prewhitener) + sphere = np.eye(n_ica_channels) + inverse_weights = finite_pinv(unmixing) + activations_2d = finite_matmul(unmixing, data.reshape(n_ica_channels, -1, order="F")) + activations = activations_2d.reshape(n_components, n_times, n_epochs, order="F") + return unmixing, sphere, inverse_weights, activations + + +def _prewhitener_matrix(ica, n_channels): + prewhitener = np.asarray(ica.pre_whitener_) + if ica.noise_cov is not None: + if prewhitener.shape != (n_channels, n_channels): + raise ValueError("MNE ICA pre-whitener has incompatible shape") + return prewhitener + values = prewhitener.reshape(-1) + if values.size == 1: + return np.eye(n_channels) / float(values[0]) + if values.size != n_channels: + raise ValueError("MNE ICA pre-whitener has incompatible shape") + return np.diag(1.0 / values) diff --git a/src/eegprep/functions/miscfunc/save_struct_as_hdf5.py b/src/eegprep/functions/miscfunc/save_struct_as_hdf5.py index 7567697c..749b703d 100644 --- a/src/eegprep/functions/miscfunc/save_struct_as_hdf5.py +++ b/src/eegprep/functions/miscfunc/save_struct_as_hdf5.py @@ -52,21 +52,3 @@ def save_dict_to_hdf5(data, filename, dataset_name): # Save to HDF5 with h5py.File(filename, 'w') as hdf: hdf.create_dataset(dataset_name, data=structured_data) - - -if __name__ == '__main__': - data = { - 'labels': 'FPz', - 'theta': np.array([0, 1, 2, 3]), - 'radius': 0.5066888888888889, - 'X': 84.98123361344625, - 'Y': 0, - 'Z': -1.7860385037488253, - 'sph_theta': 0, - 'sph_phi': -1.203999999999994, - 'sph_radius': 85, - 'type': 'EEG', - 'urchan': 1, - 'ref': None, - } - save_dict_to_hdf5(data, 'data.h5', 'dataset_name') diff --git a/src/eegprep/functions/popfunc/_eegplot_rejection.py b/src/eegprep/functions/popfunc/_eegplot_rejection.py index c1e4ca4c..a4a9e129 100644 --- a/src/eegprep/functions/popfunc/_eegplot_rejection.py +++ b/src/eegprep/functions/popfunc/_eegplot_rejection.py @@ -2,22 +2,114 @@ from __future__ import annotations -from typing import Any +from typing import Any, Callable import numpy as np from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._rejection import copy_eeg, reject_field_names, update_reject_fields -from eegprep.functions.popfunc.pop_eegplot import DEFAULT_REJECTION_COLORS, MANUAL_REJECTION_COLOR +from eegprep.functions.popfunc._rejection import ( + copy_eeg, + one_based_indices, + parse_numeric_sequence, + reject_field_names, + rejection_data, + update_reject_fields, +) +from eegprep.functions.popfunc.pop_eegplot import ( + DEFAULT_REJECTION_COLORS, + MANUAL_REJECTION_COLOR, + REJECTION_FAMILIES, + pad_rejection_rows, + rejection_row_count, +) from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch from eegprep.functions.sigprocfunc.eegplot import eegplot, eegplot2trial, trial2eegplot -DISPLAY_REJECTION_FAMILIES = ("manual", "thresh", "const", "jp", "kurt", "freq") # Autorej shares manual-color browser marks and is not superposed as a separate EEGLAB family. _AUTO_REJECTION_COLOR = MANUAL_REJECTION_COLOR +def run_epoched_rejection( + EEG: dict[str, Any], + icacomp: int | bool, + elecrange: Any, + locthresh: Any, + globthresh: Any, + superpose: int | bool, + reject: int | bool, + vistype: int, + *, + marks_fn: Callable[..., tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]], + kind: str, + stats_local_field: str, + stats_global_field: str, + stats_local_field_ica: str, + stats_global_field_ica: str, + error_message: str, + command_fn: Callable[[list[int]], str], + display: bool = False, + command_callback: Any | None = None, + show: bool = True, +) -> tuple[dict[str, Any], list[float], list[float], list[int], str]: + """Run a local/global epoched-rejection method and store marks and stats. + + Shared scaffold for the per-method ``pop_*`` epoched rejection wrappers + (kurtosis, joint probability, ...). The method-specific pieces are supplied + by ``marks_fn``, ``kind``, the stats field names, the error message, and + ``command_fn``, which builds the history string from the normalized + 1-based electrode/component range. + """ + out = copy_eeg(EEG) + data, row_count = rejection_data(out, icacomp) + if int(out.get("trials", data.shape[2]) or data.shape[2]) <= 1: + raise ValueError(error_message) + elecrange = one_based_indices(elecrange, limit=row_count, default_all=True) + marks, marks_e, local_scores, global_scores = marks_fn(data, elecrange, locthresh, globthresh) + out.setdefault("stats", {}) + if int(bool(icacomp)): + out["stats"][stats_local_field] = local_scores + out["stats"][stats_global_field] = global_scores + else: + out["stats"][stats_local_field_ica] = local_scores + out["stats"][stats_global_field_ica] = global_scores + update_reject_fields(out, icacomp=icacomp, kind=kind, reject=marks, reject_e=marks_e) + rejected = (np.flatnonzero(marks) + 1).tolist() + command = command_fn(elecrange) + if display: + open_epoched_rejection_browser( + out, + data=data, + icacomp=icacomp, + elecrange=elecrange, + kind=kind, + superpose=superpose, + reject=reject, + command=command, + command_callback=command_callback, + show=show, + ) + elif int(bool(reject)) and rejected: + out = pop_rejepoch(out, rejected, 0) + return ( + out, + parse_numeric_sequence(locthresh, dtype=float), + parse_numeric_sequence(globthresh, dtype=float), + rejected, + command, + ) + + +def vistype_from_gui(value: Any) -> int: + """Map an EEGLAB visualization-mode popup value to a vistype flag.""" + if isinstance(value, str): + return 0 if value.strip().lower() in {"rejecttrials", "reject trials", "0"} else 1 + try: + return 0 if int(value) == 1 else 1 + except (TypeError, ValueError): + return 1 + + def open_epoched_rejection_browser( EEG: dict[str, Any], *, @@ -165,7 +257,7 @@ def apply_epoched_rejection_browser( trials = int(out.get("trials", 1) or 1) pnts = int(out.get("pnts", np.asarray(out.get("data")).shape[1])) if row_count is None: - row_count = _row_count(out, icacomp) + row_count = rejection_row_count(out, icacomp) selected_rows = _selected_row_indices(elecrange, row_count) rows = _as_winrej_rows(winrej) if int(superpose) == 2: @@ -243,7 +335,7 @@ def _family_rows( ) -> np.ndarray: field, field_e = reject_field_names(icacomp, kind) marks = _trial_marks(reject.get(field), trials) - row_marks = _row_marks(reject.get(field_e), row_count, trials)[selected_rows, :] + row_marks = pad_rejection_rows(reject.get(field_e), row_count, trials)[selected_rows, :] return trial2eegplot(marks, row_marks, pnts, color) @@ -257,7 +349,7 @@ def _store_manual_marks( field, field_e = reject_field_names(icacomp, "rejmanual") reject = EEG.setdefault("reject", {}) current = _trial_marks(reject.get(field), np.asarray(trial_marks).size) - current_e = _row_marks(reject.get(field_e), row_marks.shape[0], np.asarray(trial_marks).size) + current_e = pad_rejection_rows(reject.get(field_e), row_marks.shape[0], np.asarray(trial_marks).size) reject[field] = current | np.asarray(trial_marks, dtype=bool) reject[field_e] = current_e | np.asarray(row_marks, dtype=bool) reject.setdefault("rejmanualcol", np.asarray(MANUAL_REJECTION_COLOR, dtype=float)) @@ -297,32 +389,11 @@ def _trial_marks(value: Any, trials: int) -> np.ndarray: return out -def _row_marks(value: Any, row_count: int, trials: int) -> np.ndarray: - out = np.zeros((row_count, trials), dtype=bool) - marks = np.asarray(value if value is not None else [], dtype=bool) - if marks.ndim == 1 and marks.size: - marks = marks.reshape(1, -1) - if marks.ndim == 2: - rows = min(row_count, marks.shape[0]) - cols = min(trials, marks.shape[1]) - out[:rows, :cols] = marks[:rows, :cols] - return out - - -def _row_count(EEG: dict[str, Any], icacomp: int | bool) -> int: - if int(bool(icacomp)): - return int(EEG.get("nbchan", np.asarray(EEG.get("data")).shape[0]) or 0) - weights = np.asarray(EEG.get("icaweights", [])) - return int(weights.shape[0]) if weights.ndim == 2 else 0 - - def _displayed_families(reject: dict[str, Any]) -> tuple[str, ...]: disprej = reject.get("disprej") if disprej is not None and np.asarray(disprej, dtype=object).size: - return tuple( - str(item) for item in np.asarray(disprej, dtype=object).ravel() if str(item) in DISPLAY_REJECTION_FAMILIES - ) - return tuple(family for family in DISPLAY_REJECTION_FAMILIES if _has_family_marks(reject, family)) + return tuple(str(item) for item in np.asarray(disprej, dtype=object).ravel() if str(item) in REJECTION_FAMILIES) + return tuple(family for family in REJECTION_FAMILIES if _has_family_marks(reject, family)) def _has_family_marks(reject: dict[str, Any], family: str) -> bool: diff --git a/src/eegprep/functions/popfunc/_file_io.py b/src/eegprep/functions/popfunc/_file_io.py index 8d9129a4..a2db4a2f 100644 --- a/src/eegprep/functions/popfunc/_file_io.py +++ b/src/eegprep/functions/popfunc/_file_io.py @@ -120,7 +120,10 @@ def eeg_from_data( else: raise ValueError("nbchan does not match imported data") if not nbchan and array.ndim == 2 and array.shape[0] > array.shape[1]: - array = array.T + raise ValueError( + "Cannot determine channel-major orientation for a 2-D array with more " + "rows than columns; pass nbchan to specify the number of channels." + ) if pnts and array.ndim == 2: pnts = int(pnts) if pnts > 0 and array.shape[1] % pnts == 0: diff --git a/src/eegprep/functions/popfunc/_ica_utils.py b/src/eegprep/functions/popfunc/_ica_utils.py index 779a9bff..f0738664 100644 --- a/src/eegprep/functions/popfunc/_ica_utils.py +++ b/src/eegprep/functions/popfunc/_ica_utils.py @@ -11,3 +11,45 @@ def reshape_ica_activations(data, pnts, trials): """Reshape 2-D ICA activations back to EEGLAB's channel x point x trial form.""" array = np.asarray(data) return array.reshape(array.shape[0], int(pnts), int(trials), order="F") + + +def finalize_ica_fields(EEG, *, sortcomps='off', posact='off'): + """Apply optional component sorting and sign normalization to ICA fields. + + Operates in place on ``EEG['icaact']``, ``EEG['icaweights']``, and + ``EEG['icawinv']`` and returns ``EEG``. Shared by the runica, AMICA, and + Picard backends so the post-decomposition behavior stays identical. + """ + # Optionally sort components by mean descending activation variance + if sortcomps in ('on', True): + # Flatten icaact to 2D for variance computation + icaact_2d = flatten_ica_data(EEG['icaact']) + # Compute variance metric: sum(icawinv^2) .* sum(icaact^2) + variance_metric = np.sum(EEG['icawinv'] ** 2, axis=0) * np.sum(icaact_2d**2, axis=1) + # Sort indices in descending order + windex = np.argsort(variance_metric)[::-1] + # Reorder components + EEG['icaact'] = EEG['icaact'][windex, :, :] + EEG['icaweights'] = EEG['icaweights'][windex, :] + EEG['icawinv'] = EEG['icawinv'][:, windex] + + # Optionally normalize components using the same rule as runica() + if posact in ('on', True): + # Flatten icaact to 2D for finding max abs values + icaact_2d = flatten_ica_data(EEG['icaact']) + # Find indices of max absolute values for each component + ix = np.argmax(np.abs(icaact_2d), axis=1) + ncomps = EEG['icaact'].shape[0] + + for r in range(ncomps): + if np.sign(icaact_2d[r, ix[r]]) < 0: + # A sign flip commutes through the factorization, so negate the + # matching row of icaweights and column of icawinv directly. This + # preserves the invariants icawinv == pinv(icaweights @ icasphere) + # and icaact == icaweights @ icasphere @ data, leaving icasphere + # untouched. + EEG['icaact'][r, :, :] = -EEG['icaact'][r, :, :] + EEG['icawinv'][:, r] = -EEG['icawinv'][:, r] + EEG['icaweights'][r, :] = -EEG['icaweights'][r, :] + + return EEG diff --git a/src/eegprep/functions/popfunc/_plot_utils.py b/src/eegprep/functions/popfunc/_plot_utils.py index 9e612aa7..0733017e 100644 --- a/src/eegprep/functions/popfunc/_plot_utils.py +++ b/src/eegprep/functions/popfunc/_plot_utils.py @@ -6,6 +6,7 @@ import numpy as np +from eegprep.functions.miscfunc.misc import finite_matmul from eegprep.functions.popfunc._chanutils import chanlocs_as_list from eegprep.functions.popfunc._pop_utils import is_empty_value, parse_numeric_sequence @@ -94,10 +95,15 @@ def selected_indices(values: Any, maximum: int, *, default_all: bool = True) -> return indices - 1 -def component_activations(EEG: dict[str, Any]) -> np.ndarray: - """Return ICA activations as ``components x points x trials``.""" +def component_activations(EEG: dict[str, Any], *, use_stored: bool = True) -> np.ndarray: + """Return ICA activations as ``components x points x trials``. + + When ``use_stored`` is true, a non-empty ``EEG['icaact']`` is returned as-is. + Rejection scoring passes ``use_stored=False`` to always recompute from the + weight and sphere matrices, matching EEGLAB's rejection path. + """ icaact = EEG.get("icaact") - if icaact is not None and np.asarray(icaact).size: + if use_stored and icaact is not None and np.asarray(icaact).size: data = np.asarray(icaact, dtype=float) if data.ndim == 2: return data[:, :, np.newaxis] @@ -110,10 +116,10 @@ def component_activations(EEG: dict[str, Any]) -> np.ndarray: data = eeg_epoch_data(EEG) icachansind = component_channel_indices(EEG, data.shape[0]) flat = data[icachansind, :, :].reshape(icachansind.size, -1) - unmixing = weights @ sphere + unmixing = finite_matmul(weights, sphere) if unmixing.shape[1] != flat.shape[0]: raise ValueError("ICA weights do not match EEG.icachansind channel count") - acts = unmixing @ flat + acts = finite_matmul(unmixing, flat) return acts.reshape(weights.shape[0], data.shape[1], data.shape[2]) diff --git a/src/eegprep/functions/popfunc/_rejection.py b/src/eegprep/functions/popfunc/_rejection.py index 04c68a8d..cbbc3b49 100644 --- a/src/eegprep/functions/popfunc/_rejection.py +++ b/src/eegprep/functions/popfunc/_rejection.py @@ -9,7 +9,7 @@ import numpy as np from scipy import signal, stats -from eegprep.functions.miscfunc.misc import finite_matmul +from eegprep.functions.popfunc._plot_utils import component_activations as plot_component_activations from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence @@ -109,16 +109,16 @@ def data_3d(EEG: dict[str, Any]) -> np.ndarray: def component_activations(EEG: dict[str, Any]) -> np.ndarray: - """Return ICA activations as components x points x trials.""" + """Return ICA activations as components x points x trials. + + Rejection always recomputes from the weight and sphere matrices (ignoring a + stored ``EEG['icaact']``) so scoring matches EEGLAB's rejection path. + """ weights = np.asarray(EEG.get("icaweights", []), dtype=float) sphere = np.asarray(EEG.get("icasphere", []), dtype=float) if weights.size == 0 or sphere.size == 0: raise ValueError("ICA decomposition is required") - data = data_3d(EEG) - icachansind = np.asarray(EEG.get("icachansind", np.arange(data.shape[0])), dtype=int).ravel() - data_2d = data.reshape(data.shape[0], -1, order="F") - activations = finite_matmul(finite_matmul(weights, sphere), data_2d[icachansind]) - return activations.reshape(weights.shape[0], data.shape[1], data.shape[2], order="F") + return plot_component_activations(EEG, use_stored=False) def rejection_data(EEG: dict[str, Any], icacomp: int | bool) -> tuple[np.ndarray, int]: diff --git a/src/eegprep/functions/popfunc/eeg_amica.py b/src/eegprep/functions/popfunc/eeg_amica.py index bf0ca7ff..b4ab5399 100644 --- a/src/eegprep/functions/popfunc/eeg_amica.py +++ b/src/eegprep/functions/popfunc/eeg_amica.py @@ -1,7 +1,9 @@ """Perform ICA decomposition using the AMICA (Adaptive Mixture ICA) algorithm.""" +import copy + import numpy as np -from ..miscfunc.pinv import pinv +from ._ica_utils import finalize_ica_fields, flatten_ica_data, reshape_ica_activations from ..sigprocfunc.runamica import runamica @@ -62,9 +64,10 @@ def eeg_amica( dict The updated EEG structure with ICA fields. """ + EEG = copy.deepcopy(EEG) + # Extract data and reshape from 3D to 2D - data = EEG['data'].astype('float64') - data = data.reshape(data.shape[0], -1) + data = flatten_ica_data(EEG['data'].astype('float64')) # Run AMICA weights, sphere, mods = runamica( @@ -88,7 +91,7 @@ def eeg_amica( # Compute ICA activations EEG['icaact'] = (EEG['icaweights'] @ EEG['icasphere']) @ data # Reshape icaact back to 3D - EEG['icaact'] = EEG['icaact'].reshape(EEG['icaact'].shape[0], EEG['pnts'], EEG['trials']) + EEG['icaact'] = reshape_ica_activations(EEG['icaact'], EEG['pnts'], EEG['trials']) EEG['icachansind'] = np.arange(EEG['nbchan']) # Store full multi-model results @@ -96,41 +99,7 @@ def eeg_amica( EEG['etc'] = {} EEG['etc']['amica'] = mods - # Optionally sort components by mean descending activation variance - if sortcomps in ('on', True): - # Flatten icaact to 2D for variance computation - icaact_2d = EEG['icaact'].reshape(EEG['icaact'].shape[0], -1) - # Compute variance metric: sum(icawinv^2) .* sum(icaact^2) - variance_metric = np.sum(EEG['icawinv'] ** 2, axis=0) * np.sum(icaact_2d**2, axis=1) - # Sort indices in descending order - windex = np.argsort(variance_metric)[::-1] - # Reorder components - EEG['icaact'] = EEG['icaact'][windex, :, :] - EEG['icaweights'] = EEG['icaweights'][windex, :] - EEG['icawinv'] = EEG['icawinv'][:, windex] - - # Optionally normalize components using the same rule as runica() - if posact in ('on', True): - # Flatten icaact to 2D for finding max abs values - icaact_2d = EEG['icaact'].reshape(EEG['icaact'].shape[0], -1) - # Find indices of max absolute values for each component - ix = np.argmax(np.abs(icaact_2d), axis=1) - had_flips = False - ncomps = EEG['icaact'].shape[0] - - for r in range(ncomps): - if np.sign(icaact_2d[r, ix[r]]) < 0: - # Flip the activations - EEG['icaact'][r, :, :] = -EEG['icaact'][r, :, :] - # Flip the corresponding column of the mixing matrix - EEG['icawinv'][:, r] = -EEG['icawinv'][:, r] - had_flips = True - - if had_flips: - # Recompute unmixing matrix - EEG['icaweights'] = pinv(EEG['icawinv']) - - return EEG + return finalize_ica_fields(EEG, sortcomps=sortcomps, posact=posact) def load_amica_model(EEG, mods, model_num=0): @@ -160,13 +129,15 @@ def load_amica_model(EEG, mods, model_num=0): if model_num < 0 or model_num >= num_models: raise ValueError(f"model_num={model_num} out of range for {num_models} models") + EEG = copy.deepcopy(EEG) + EEG['icaweights'] = mods['W'][:, :, model_num] EEG['icasphere'] = mods['S'][:num_pcs, :] EEG['icawinv'] = mods['A'][:, :, model_num] # Recompute activations - data = EEG['data'].astype('float64').reshape(EEG['data'].shape[0], -1) + data = flatten_ica_data(EEG['data'].astype('float64')) EEG['icaact'] = (EEG['icaweights'] @ EEG['icasphere']) @ data - EEG['icaact'] = EEG['icaact'].reshape(EEG['icaact'].shape[0], EEG['pnts'], EEG['trials']) + EEG['icaact'] = reshape_ica_activations(EEG['icaact'], EEG['pnts'], EEG['trials']) return EEG diff --git a/src/eegprep/functions/popfunc/eeg_compare.py b/src/eegprep/functions/popfunc/eeg_compare.py index 8cd52c48..a10009af 100644 --- a/src/eegprep/functions/popfunc/eeg_compare.py +++ b/src/eegprep/functions/popfunc/eeg_compare.py @@ -4,30 +4,35 @@ differences between them. """ -import sys +import logging import math from collections.abc import Sequence import numpy as np +logger = logging.getLogger(__name__) + def eeg_compare(eeg1, eeg2, verbose_level=0, trigger_error=False): - """Compare two EEG-like structures, reporting differences to stderr. + """Compare two EEG-like structures (or arrays) and return a difference summary. + + Per-field findings are emitted through this module's logger; the returned value is the + human-readable summary string callers print or store. Parameters ---------- - eeg1 : dict or object - First EEG structure to compare. - eeg2 : dict or object - Second EEG structure to compare. + eeg1 : dict, object, or numpy.ndarray + First EEG structure (or array) to compare. + eeg2 : dict, object, or numpy.ndarray + Second EEG structure (or array) to compare. verbose_level : int, optional - Level of verbosity for output. Default 0. + Level of verbosity for logged output. Default 0. trigger_error : bool, optional - Whether to raise an error if differences are found. Default False. + Whether to raise a ``ValueError`` if differences are found. Default False. Returns ------- - bool - True if comparison completed (differences may still exist). + str + A summary describing the differences found, or that all fields match. """ summary_parts = [] @@ -48,24 +53,18 @@ def isequaln(a, b): if isinstance(a, np.ndarray) or isinstance(b, np.ndarray): try: return bool(np.array_equal(np.array(a), np.array(b), equal_nan=True)) - except Exception: - pass - # Handle numpy arrays in general comparison - if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): - try: - return bool(np.array_equal(a, b, equal_nan=True)) - except Exception: + except (TypeError, ValueError): pass # Handle scalar vs array comparisons if isinstance(a, np.ndarray) and np.isscalar(b): try: return bool(np.all(a == b)) - except Exception: + except (TypeError, ValueError): pass if isinstance(b, np.ndarray) and np.isscalar(a): try: return bool(np.all(b == a)) - except Exception: + except (TypeError, ValueError): pass # Final comparison - ensure we return a boolean try: @@ -73,7 +72,7 @@ def isequaln(a, b): if isinstance(result, np.ndarray): return bool(result.all()) return bool(result) - except Exception: + except (TypeError, ValueError): return False def _numeric_distance(a, b): @@ -93,7 +92,7 @@ def _numeric_distance(a, b): return np.inf return float(np.max(np.abs(arr_a - arr_b))) - print('\nField analysis: (no entries means OK)') + logger.info('Field analysis: (no entries means OK)') # Collect differences for error reporting differences = [] @@ -101,7 +100,7 @@ def _numeric_distance(a, b): if isinstance(eeg1, np.ndarray) and isinstance(eeg2, np.ndarray): if eeg1.shape != eeg2.shape: summary = f"Array shape mismatch: {eeg1.shape} vs {eeg2.shape}" - print(summary, file=sys.stderr) + logger.warning(summary) if trigger_error: raise ValueError(summary) return summary @@ -184,7 +183,7 @@ def get_val2(f): for field in fields1: if not has_field2(field): error_msg = f'Field {field} missing in second dataset' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) else: v1 = get_val1(field) @@ -192,16 +191,16 @@ def get_val2(f): if not isequaln(v1, v2): name = field.lower() if any(sub in name for sub in ('filename', 'datfile')): - print(f' Field {field} differs (ok, supposed to differ)') + logger.info(' Field %s differs (ok, supposed to differ)', field) elif any(sub in name for sub in ('subject', 'session', 'run', 'task')): error_msg = f'Field {field} differs ("{v1}" vs "{v2}")' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) elif any(sub in name for sub in ('eventdescription')): n1 = len(v1) if isinstance(v1, Sequence) else 1 n2 = len(v2) if isinstance(v2, Sequence) else 1 error_msg = f'Field {field} differs (n={n1} vs n={n2})' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) elif any(sub in name for sub in ('chanlocs', 'event', 'reject')): pass @@ -224,19 +223,19 @@ def get_val2(f): max_rel_diff = np.max(rel_diff) if rel_diff.size else 0.0 error_msg = f'Field {field} differs (max_abs={max_abs_diff:.6e}, mean_abs={mean_abs_diff:.6e}, rms={rms_diff:.6e}, max_rel={max_rel_diff:.6e})' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) summary_parts.append( f" {field}: max_abs={max_abs_diff:.6e}, mean_abs={mean_abs_diff:.6e}, rms={rms_diff:.6e}" ) else: error_msg = f'Field {field} differs (shape mismatch: {v1.shape} vs {v2.shape})' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) summary_parts.append(f" {field}: shape mismatch") else: error_msg = f'Field {field} differs' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) # compare xmin/xmax for attr in ('xmin', 'xmax'): @@ -245,11 +244,11 @@ def get_val2(f): if not isequaln(x1, x2): diff = (x1 or 0) - (x2 or 0) error_msg = f'Difference between {attr} is {diff:1.6f} sec' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) # channel locations - print('Chanlocs analysis:') + logger.info('Chanlocs analysis:') chans1 = get_val1('chanlocs') if chans1 is None or (isinstance(chans1, np.ndarray) and len(chans1) == 0): chans1 = [] @@ -266,26 +265,26 @@ def get_val2(f): if c1['labels'] != c2['labels']: label_diff += 1 if verbose_level > 0: - print(f' Channel {c1["labels"]} differs from {c2["labels"]}', file=sys.stderr) + logger.warning(' Channel %s differs from %s', c1["labels"], c2["labels"]) if coord_diff: error_msg = f'{coord_diff} channel coordinates differ' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) else: - print(' All channel coordinates are OK') + logger.info(' All channel coordinates are OK') if label_diff: error_msg = f'{label_diff} channel label(s) differ' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) else: - print(' All channel labels are OK') + logger.info(' All channel labels are OK') else: error_msg = 'Different numbers of channels' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) # events - print('Event analysis:') + logger.info('Event analysis:') ev1 = get_val1('event') if ev1 is None or (isinstance(ev1, np.ndarray) and len(ev1) == 0): ev1 = [] @@ -294,23 +293,23 @@ def get_val2(f): ev2 = [] if len(ev1) != len(ev2): error_msg = f'Different numbers of events {len(ev1)} vs {len(ev2)}' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) # print the first event of each if verbose_level > 0: if len(ev1) > 0: - print(f' First event of first dataset: {ev1[0]}', file=sys.stderr) + logger.warning(' First event of first dataset: %s', ev1[0]) if len(ev2) > 0: - print(f' First event of second dataset: {ev2[0]}', file=sys.stderr) + logger.warning(' First event of second dataset: %s', ev2[0]) else: if len(ev1) == 0: - print(' All events OK (empty)') + logger.info(' All events OK (empty)') else: f1 = set(ev1[0].keys()) f2 = set(ev2[0].keys()) if f1 != f2: error_msg = 'Not the same number of event fields' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) for fld in f1: diffs = [] @@ -321,38 +320,16 @@ def get_val2(f): pct = len(nonzero) / len(diffs) * 100 avg = sum(abs(d) for d in nonzero) / len(nonzero) error_msg = f'Event latency ({pct:2.1f} %) not OK (abs diff {avg:1.4f} samples)' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) - # print(' ******** (see plot)') - # import matplotlib.pyplot as plt - # plt.plot(diffs) - # plt.show() else: diffs = [not isequaln(e1.get(fld, None), e2.get(fld, None)) for e1, e2 in zip(ev1, ev2)] if any(diffs): pct = sum(diffs) / len(diffs) * 100 error_msg = f'Event fields "{fld}" are NOT OK ({pct:2.1f} % of them)' - print(f' {error_msg}', file=sys.stderr) + logger.warning(' %s', error_msg) differences.append(error_msg) - print(' All other events OK') - - # epochs - # if 'epoch' in eeg1: - # print('Epoch analysis:') - # ep1, ep2 = eeg1['epoch'], eeg2['epoch'] - # if len(ep1) != len(ep2): - # print(' Different numbers of epochs', file=sys.stderr) - # else: - # fields = ep1[0].keys() - # all_ok = True - # for fld in fields: - # diffs = [not isequaln(getattr(e1, fld, None), getattr(e2, fld, None)) for e1, e2 in zip(ep1, ep2)] - # if any(diffs): - # pct = sum(diffs) / len(diffs) * 100 - # print(f' Epoch fields "{fld}" are NOT OK ({pct:2.1f} % of them)', file=sys.stderr) - # all_ok = False - # if all_ok: - # print(' All epoch and all epoch fields are OK') + logger.info(' All other events OK') # Build final summary if summary_parts: @@ -388,16 +365,3 @@ def get_val2(f): raise ValueError(error_message) return summary - - -# add test data and compare with it - -# load test data -if __name__ == '__main__': - from eegprep import pop_loadset - - eeg1 = pop_loadset('../../sample_data/eeglab_data_tmp.set') - eeg2 = pop_loadset('../../sample_data/eeglab_data_tmp.set') - - # compare - eeg_compare(eeg1, eeg2) diff --git a/src/eegprep/functions/popfunc/eeg_eegrej.py b/src/eegprep/functions/popfunc/eeg_eegrej.py index a3e01616..3e841a0f 100644 --- a/src/eegprep/functions/popfunc/eeg_eegrej.py +++ b/src/eegprep/functions/popfunc/eeg_eegrej.py @@ -1,11 +1,15 @@ """EEG data rejection functions.""" +import logging from typing import List, Dict, Optional, Tuple import numpy as np from copy import deepcopy from ..miscfunc.misc import round_mat +logger = logging.getLogger(__name__) + + def _is_boundary_event(event: Dict) -> bool: t = event.get("type") if isinstance(t, str): @@ -324,10 +328,15 @@ def eeg_eegrej(EEG, regions): if len(EEG["event"]) > 1 and EEG["event"][-1].get("latency", 0) - 0.5 > EEG["pnts"] and EEG.get("trials", 1) == 1: EEG["event"].pop() - # light duplicate cleanup mirroring MATLAB edge cases - if len(EEG["event"]) > 1 and EEG["event"][0].get("latency") == 0: + # light duplicate cleanup mirroring MATLAB edge cases: only drop boundary + # events sitting at the very first/last sample, never genuine stimulus events + if len(EEG["event"]) > 1 and EEG["event"][0].get("latency") == 0 and _is_boundary_event(EEG["event"][0]): EEG["event"] = EEG["event"][1:] - if len(EEG["event"]) > 1 and EEG["event"][-1].get("latency") == EEG["pnts"]: + if ( + len(EEG["event"]) > 1 + and EEG["event"][-1].get("latency") == EEG["pnts"] + and _is_boundary_event(EEG["event"][-1]) + ): EEG["event"] = EEG["event"][:-1] if len(EEG["event"]) > 2: if EEG["event"][-1].get("latency") == EEG["event"][-2].get("latency"): @@ -359,7 +368,7 @@ def _combine_regions(regs): merged.append([beg, end]) newregs = np.asarray(merged, dtype=np.int64) if newregs.shape[0] != regs.shape[0]: - print("Warning: overlapping regions detected and fixed in eeg_eegrej") + logger.warning("Overlapping regions detected and fixed in eeg_eegrej") return newregs diff --git a/src/eegprep/functions/popfunc/eeg_interp.py b/src/eegprep/functions/popfunc/eeg_interp.py index 57f62b2a..8ddf8eb8 100644 --- a/src/eegprep/functions/popfunc/eeg_interp.py +++ b/src/eegprep/functions/popfunc/eeg_interp.py @@ -4,21 +4,12 @@ methods including spherical spline interpolation. """ -# to do, look at line 83 and 84 and try to see if the MATLAB array output match. Run code side by side. - -# EEG = pop_loadset('sample_data/eeglab_data_tmp.set'); -# EEG = eeg_interp(EEG, [1, 2, 3], 'spherical'); % or EEG = eeg_interp(EEG, {'Fp1' 'Fp2' 'F7'}, 'spherical'); -# pop_save(EEG, 'sample_data/eeglab_data_tmp_out_matlab.set'); - import numpy as np from scipy.linalg import pinv from scipy.interpolate import RBFInterpolator, griddata from scipy.special import lpmv from copy import deepcopy -# absolute path for all files in data folder -data_path = '/Users/arno/Python/eegprep/sample_data/' # os.path.abspath('sample_data/') - def eeg_interp(EEG, bad_chans, method='spherical', t_range=None, params=None, dtype='float32'): """Interpolate missing or bad EEG channels using spherical spline. @@ -525,153 +516,3 @@ def computeg(x, y, z, xelec, yelec, zelec, params): g += ((2 * n + 1) / (n**m * (n + 1) ** m)) * Pn return g / (4 * np.pi) - - -# Test functions moved to tests/test_eeg_interp.py - - -def test_chanloc_interpolation(): - """Example usage of the new chanloc interpolation functionality. - - This demonstrates the three different cases. - """ - # Create a sample EEG structure - EEG = { - 'data': np.random.randn(4, 100, 1), # 4 channels, 100 time points, 1 trial - 'nbchan': 4, - 'pnts': 100, - 'trials': 1, - 'srate': 500, - 'xmin': 0, - 'xmax': 0.2, - 'chanlocs': [ - {'labels': 'Fp1', 'X': 0.1, 'Y': 0.8, 'Z': 0.6}, - {'labels': 'Fp2', 'X': -0.1, 'Y': 0.8, 'Z': 0.6}, - {'labels': 'F3', 'X': 0.4, 'Y': 0.6, 'Z': 0.7}, - {'labels': 'F4', 'X': -0.4, 'Y': 0.6, 'Z': 0.7}, - ], - } - - print("Original EEG structure:") - print(f"Data shape: {EEG['data'].shape}") - print(f"Number of channels: {EEG['nbchan']}") - print(f"Channel labels: {[ch['labels'] for ch in EEG['chanlocs']]}") - - # Case 1: Identical chanlocs (should return unchanged) - identical_chanlocs = EEG['chanlocs'].copy() - result1 = eeg_interp(EEG.copy(), identical_chanlocs) - print("\nCase 1 - Identical chanlocs:") - print(f"Data shape unchanged: {result1['data'].shape == EEG['data'].shape}") - print(f"Data is identical: {np.array_equal(result1['data'], EEG['data'])}") - - # Case 2: No overlap (should append new channels) - new_chanlocs = [ - {'labels': 'T7', 'X': 0.8, 'Y': 0.0, 'Z': 0.6}, - {'labels': 'T8', 'X': -0.8, 'Y': 0.0, 'Z': 0.6}, - ] - result2 = eeg_interp(EEG.copy(), new_chanlocs) - print("\nCase 2 - No overlap (append new channels):") - print(f"Original channels: {EEG['nbchan']}, After: {result2['nbchan']}") - print(f"Data shape: {EEG['data'].shape} -> {result2['data'].shape}") - print(f"New channel labels: {[ch['labels'] for ch in result2['chanlocs']]}") - - # Case 3: Existing channels are proper subset (should remap to new structure) - superset_chanlocs = [ - {'labels': 'Fp1', 'X': 0.1, 'Y': 0.8, 'Z': 0.6}, - {'labels': 'Fp2', 'X': -0.1, 'Y': 0.8, 'Z': 0.6}, - {'labels': 'F3', 'X': 0.4, 'Y': 0.6, 'Z': 0.7}, - {'labels': 'F4', 'X': -0.4, 'Y': 0.6, 'Z': 0.7}, - {'labels': 'C3', 'X': 0.6, 'Y': 0.0, 'Z': 0.8}, - {'labels': 'C4', 'X': -0.6, 'Y': 0.0, 'Z': 0.8}, - ] - result3 = eeg_interp(EEG.copy(), superset_chanlocs) - print("\nCase 3 - Existing subset of new structure:") - print(f"Original channels: {EEG['nbchan']}, After: {result3['nbchan']}") - print(f"Data shape: {EEG['data'].shape} -> {result3['data'].shape}") - print(f"Final channel labels: {[ch['labels'] for ch in result3['chanlocs']]}") - - return result1, result2, result3 - - -def test_ica_indices_update(): - """Test that ICA channel indices are properly updated when channels are. - - reordered. - - Test that ICA channel indices are properly updated when channels are - - reordered during interpolation with chanloc structures. - """ - # Create a sample EEG structure with ICA data - EEG = { - 'data': np.random.randn(4, 100, 1), # 4 channels, 100 time points, 1 trial - 'nbchan': 4, - 'pnts': 100, - 'trials': 1, - 'srate': 500, - 'xmin': 0, - 'xmax': 0.2, - 'chanlocs': [ - {'labels': 'Fp1', 'X': 0.1, 'Y': 0.8, 'Z': 0.6}, - {'labels': 'Fp2', 'X': -0.1, 'Y': 0.8, 'Z': 0.6}, - {'labels': 'F3', 'X': 0.4, 'Y': 0.6, 'Z': 0.7}, - {'labels': 'F4', 'X': -0.4, 'Y': 0.6, 'Z': 0.7}, - ], - # Add ICA fields - 'icasphere': np.eye(4), # 4x4 identity matrix (not empty) - 'icaweights': np.random.randn(4, 4), - 'icawinv': np.random.randn(4, 4), - 'icachansind': [0, 1, 2, 3], # All channels used for ICA (0-based) - 'chaninfo': { - 'icachansind': [0, 1, 2, 3], - }, - } - - print("Original EEG structure with ICA:") - print(f"Data shape: {EEG['data'].shape}") - print(f"Number of channels: {EEG['nbchan']}") - print(f"Channel labels: {[ch['labels'] for ch in EEG['chanlocs']]}") - print(f"ICA channel indices: {EEG['icachansind']}") - print(f"Chaninfo ICA indices: {EEG['chaninfo']['icachansind']}") - - # Test Case: Subset interpolation that causes channel reordering - # Create a superset where the existing channels appear in different order - superset_chanlocs = [ - {'labels': 'F3', 'X': 0.4, 'Y': 0.6, 'Z': 0.7}, # was index 2, now 0 - {'labels': 'Fp1', 'X': 0.1, 'Y': 0.8, 'Z': 0.6}, # was index 0, now 1 - {'labels': 'C3', 'X': 0.6, 'Y': 0.0, 'Z': 0.8}, # new channel, index 2 - {'labels': 'Fp2', 'X': -0.1, 'Y': 0.8, 'Z': 0.6}, # was index 1, now 3 - {'labels': 'F4', 'X': -0.4, 'Y': 0.6, 'Z': 0.7}, # was index 3, now 4 - {'labels': 'C4', 'X': -0.6, 'Y': 0.0, 'Z': 0.8}, # new channel, index 5 - ] - - result = eeg_interp(EEG.copy(), superset_chanlocs) - - print("\nAfter interpolation with reordering:") - print(f"Data shape: {EEG['data'].shape} -> {result['data'].shape}") - print(f"Number of channels: {EEG['nbchan']} -> {result['nbchan']}") - print(f"Channel labels: {[ch['labels'] for ch in result['chanlocs']]}") - print(f"ICA channel indices: {EEG['icachansind']} -> {result['icachansind']}") - - # Verify the mapping is correct: - # Original: Fp1=0, Fp2=1, F3=2, F4=3 - # New: F3=0, Fp1=1, C3=2, Fp2=3, F4=4, C4=5 - # So ICA indices should be updated: [0,1,2,3] -> [1,3,0,4] - expected_indices = [1, 3, 0, 4] # New positions of Fp1, Fp2, F3, F4 - - print(f"Expected ICA indices: {expected_indices}") - print(f"Actual ICA indices: {result['icachansind']}") - print(f"Mapping correct: {result['icachansind'] == expected_indices}") - - # Also verify chaninfo is updated - if 'chaninfo' in result and 'icachansind' in result['chaninfo']: - print(f"Chaninfo ICA indices: {result['chaninfo']['icachansind']}") - print(f"Chaninfo mapping correct: {result['chaninfo']['icachansind'] == expected_indices}") - - return result - - -# Uncomment to run the tests -# if __name__ == '__main__': -# test_chanloc_interpolation() -# test_ica_indices_update() diff --git a/src/eegprep/functions/popfunc/eeg_lat2point.py b/src/eegprep/functions/popfunc/eeg_lat2point.py index fd98161c..49c813d1 100644 --- a/src/eegprep/functions/popfunc/eeg_lat2point.py +++ b/src/eegprep/functions/popfunc/eeg_lat2point.py @@ -1,8 +1,13 @@ """EEG latency to point conversion utilities.""" +import logging + import numpy as np +logger = logging.getLogger(__name__) + + def eeg_lat2point(lat_array, epoch_array, srate, timewin, timeunit=1.0, **kwargs): """Convert latencies in time units (relative to per-epoch time 0) to latencies in data points assuming concatenated epochs (EEGLAB style). @@ -70,7 +75,7 @@ def eeg_lat2point(lat_array, epoch_array, srate, timewin, timeunit=1.0, **kwargs newlat[idx] = max_valid flag = 1 # mirror MATLAB's informational message - print('eeg_lat2point(): Points out of range detected. Points replaced with maximum value') + logger.warning("Points out of range detected. Points replaced with maximum value") else: raise ValueError('Error in eeg_lat2point(): Points out of range detected') diff --git a/src/eegprep/functions/popfunc/eeg_picard.py b/src/eegprep/functions/popfunc/eeg_picard.py index a5fef76b..6b526d16 100644 --- a/src/eegprep/functions/popfunc/eeg_picard.py +++ b/src/eegprep/functions/popfunc/eeg_picard.py @@ -1,7 +1,10 @@ """Module for performing ICA decomposition using the Picard algorithm.""" +import copy + from picard import picard import numpy as np +from ._ica_utils import finalize_ica_fields, flatten_ica_data, reshape_ica_activations from ..miscfunc.pinv import pinv @@ -29,12 +32,14 @@ def eeg_picard(EEG, engine=None, posact='off', sortcomps='off', **kwargs): dict The updated EEG structure with ICA fields. """ + EEG = copy.deepcopy(EEG) + if engine is None: # Assuming EEG['data'] contains the EEG data as a numpy array of shape (channels, timepoints) data = EEG['data'].astype('float64') # reshape from 3D to 2D - data = data.reshape(data.shape[0], -1) + data = flatten_ica_data(data) # Parameters to match MATLAB picard defaults for reproducible parity # Using identity w_init ensures deterministic results matching MATLAB @@ -63,7 +68,7 @@ def eeg_picard(EEG, engine=None, posact='off', sortcomps='off', **kwargs): EEG['icaact'] = sources # reshape EEG['icaact'] back to 3D as EEG['data'] - EEG['icaact'] = EEG['icaact'].reshape(EEG['icaact'].shape[0], EEG['pnts'], EEG['trials']) + EEG['icaact'] = reshape_ica_activations(EEG['icaact'], EEG['pnts'], EEG['trials']) EEG['icachansind'] = np.arange(EEG['nbchan']) else: @@ -72,38 +77,4 @@ def eeg_picard(EEG, engine=None, posact='off', sortcomps='off', **kwargs): # sorting/normalization options) EEG = engine.eeg_picard(EEG, **kwargs) - # optionally sort components by mean descending activation variance - if sortcomps in ('on', True): - # Flatten icaact to 2D for variance computation - icaact_2d = EEG['icaact'].reshape(EEG['icaact'].shape[0], -1) - # Compute variance metric: sum(icawinv^2) .* sum(icaact^2) - variance_metric = np.sum(EEG['icawinv'] ** 2, axis=0) * np.sum(icaact_2d**2, axis=1) - # Sort indices in descending order - windex = np.argsort(variance_metric)[::-1] - # Reorder components - EEG['icaact'] = EEG['icaact'][windex, :, :] - EEG['icaweights'] = EEG['icaweights'][windex, :] - EEG['icawinv'] = EEG['icawinv'][:, windex] - - # optionally normalize components using the same rule as runica() - if posact in ('on', True): - # Flatten icaact to 2D for finding max abs values - icaact_2d = EEG['icaact'].reshape(EEG['icaact'].shape[0], -1) - # Find indices of max absolute values for each component - ix = np.argmax(np.abs(icaact_2d), axis=1) - had_flips = False - ncomps = EEG['icaact'].shape[0] - - for r in range(ncomps): - if np.sign(icaact_2d[r, ix[r]]) < 0: - # Flip the activations - EEG['icaact'][r, :, :] = -EEG['icaact'][r, :, :] - # Flip the corresponding column of the mixing matrix - EEG['icawinv'][:, r] = -EEG['icawinv'][:, r] - had_flips = True - - if had_flips: - # Recompute unmixing matrix - EEG['icaweights'] = pinv(EEG['icawinv']) - - return EEG + return finalize_ica_fields(EEG, sortcomps=sortcomps, posact=posact) diff --git a/src/eegprep/functions/popfunc/eeg_runica.py b/src/eegprep/functions/popfunc/eeg_runica.py index 313c9afa..cd8e47db 100644 --- a/src/eegprep/functions/popfunc/eeg_runica.py +++ b/src/eegprep/functions/popfunc/eeg_runica.py @@ -1,5 +1,7 @@ +import copy + import numpy as np -from ._ica_utils import flatten_ica_data, reshape_ica_activations +from ._ica_utils import finalize_ica_fields, flatten_ica_data, reshape_ica_activations from ..miscfunc.misc import finite_matmul, finite_pinv from ..miscfunc.pinv import pinv from ..sigprocfunc.runica import runica @@ -25,6 +27,8 @@ def eeg_runica(EEG, posact='off', sortcomps='off', **kwargs): dict The updated EEG structure with ICA fields. """ + EEG = copy.deepcopy(EEG) + # Extract data and reshape from 3D to 2D data = flatten_ica_data(EEG['data'].astype('float64')) @@ -49,38 +53,4 @@ def eeg_runica(EEG, posact='off', sortcomps='off', **kwargs): EEG['icaact'] = reshape_ica_activations(EEG['icaact'], EEG['pnts'], EEG['trials']) EEG['icachansind'] = np.arange(EEG['nbchan']) - # Optionally sort components by mean descending activation variance - if sortcomps in ('on', True): - # Flatten icaact to 2D for variance computation - icaact_2d = flatten_ica_data(EEG['icaact']) - # Compute variance metric: sum(icawinv^2) .* sum(icaact^2) - variance_metric = np.sum(EEG['icawinv'] ** 2, axis=0) * np.sum(icaact_2d**2, axis=1) - # Sort indices in descending order - windex = np.argsort(variance_metric)[::-1] - # Reorder components - EEG['icaact'] = EEG['icaact'][windex, :, :] - EEG['icaweights'] = EEG['icaweights'][windex, :] - EEG['icawinv'] = EEG['icawinv'][:, windex] - - # Optionally normalize components using the same rule as runica() - if posact in ('on', True): - # Flatten icaact to 2D for finding max abs values - icaact_2d = flatten_ica_data(EEG['icaact']) - # Find indices of max absolute values for each component - ix = np.argmax(np.abs(icaact_2d), axis=1) - had_flips = False - ncomps = EEG['icaact'].shape[0] - - for r in range(ncomps): - if np.sign(icaact_2d[r, ix[r]]) < 0: - # Flip the activations - EEG['icaact'][r, :, :] = -EEG['icaact'][r, :, :] - # Flip the corresponding column of the mixing matrix - EEG['icawinv'][:, r] = -EEG['icawinv'][:, r] - had_flips = True - - if had_flips: - # Recompute unmixing matrix - EEG['icaweights'] = finite_pinv(EEG['icawinv'], solver=pinv) - - return EEG + return finalize_ica_fields(EEG, sortcomps=sortcomps, posact=posact) diff --git a/src/eegprep/functions/popfunc/pop_chansel.py b/src/eegprep/functions/popfunc/pop_chansel.py index 92757891..9e871d03 100644 --- a/src/eegprep/functions/popfunc/pop_chansel.py +++ b/src/eegprep/functions/popfunc/pop_chansel.py @@ -39,7 +39,7 @@ def pop_chansel( return [], "", [] allchanstr = [channel_values[index - 1] for index in chanlist] - chanliststr = _selected_string(allchanstr, withindex_value) + chanliststr = _selected_string(allchanstr) if handle is not None and hasattr(handle, "setText"): handle.setText(chanliststr) return chanlist, chanliststr, allchanstr @@ -62,12 +62,22 @@ def pop_chansel_selected_string( select: Any, *, field: str = "labels", - withindex: str = "off", ) -> str: """Return EEGLAB's selected channel string without opening the dialog.""" channel_values = _channel_values(chans, field) selected = _selection_to_indices(select, channel_values) - return _selected_string([channel_values[index - 1] for index in selected], withindex) + return _selected_string([channel_values[index - 1] for index in selected]) + + +def pop_chansel_resolve( + chans: Any, + select: Any, + *, + field: str = "labels", +) -> tuple[list[str], list[int]]: + """Resolve a channel selection to its labels and 1-based indices without a dialog.""" + channel_values = _channel_values(chans, field) + return channel_values, _selection_to_indices(select, channel_values) def _channel_values(chans: Any, field: str) -> list[str]: @@ -136,7 +146,7 @@ def _parse_text(text: str) -> list[str]: return [next(part for part in token if part) for token in tokens] -def _selected_string(values: list[str], withindex: str) -> str: +def _selected_string(values: list[str]) -> str: if not values: return "" space_present = any(" " in value or "\t" in value for value in values) diff --git a/src/eegprep/functions/popfunc/pop_comperp.py b/src/eegprep/functions/popfunc/pop_comperp.py index 66660bc9..e9464cf7 100644 --- a/src/eegprep/functions/popfunc/pop_comperp.py +++ b/src/eegprep/functions/popfunc/pop_comperp.py @@ -19,6 +19,7 @@ history_command, numeric_vector, ) +from eegprep.functions.popfunc._pop_utils import is_on def pop_comperp( @@ -64,22 +65,18 @@ def pop_comperp( pvalues = _significance(add_stack, sub_stack, options.get("alpha")) lowpass = numeric_vector(options.get("lowpass", [])) if lowpass.size: - erp1 = _lowpass_erp(erp1, float(lowpass[0]), float(datasets[int(add_indices[0])].get("srate", 1) or 1)) + cutoff = float(lowpass[0]) + srate = float(datasets[int(add_indices[0])].get("srate", 1) or 1) + erp1 = _lowpass(erp1, cutoff, srate, axis=1) if erp2 is not None: - erp2 = _lowpass_erp(erp2, float(lowpass[0]), float(datasets[int(add_indices[0])].get("srate", 1) or 1)) + erp2 = _lowpass(erp2, cutoff, srate, axis=1) if erpsub is not None: - erpsub = _lowpass_erp(erpsub, float(lowpass[0]), float(datasets[int(add_indices[0])].get("srate", 1) or 1)) - add_stack = _lowpass_stack( - add_stack, float(lowpass[0]), float(datasets[int(add_indices[0])].get("srate", 1) or 1) - ) + erpsub = _lowpass(erpsub, cutoff, srate, axis=1) + add_stack = _lowpass(add_stack, cutoff, srate, axis=2) if sub_stack is not None: - sub_stack = _lowpass_stack( - sub_stack, float(lowpass[0]), float(datasets[int(add_indices[0])].get("srate", 1) or 1) - ) + sub_stack = _lowpass(sub_stack, cutoff, srate, axis=2) if diff_stack is not None: - diff_stack = _lowpass_stack( - diff_stack, float(lowpass[0]), float(datasets[int(add_indices[0])].get("srate", 1) or 1) - ) + diff_stack = _lowpass(diff_stack, cutoff, srate, axis=2) times = eeg_times_ms(datasets[int(add_indices[0])]) figure = _plot_comperp( erp1, @@ -246,23 +243,23 @@ def _plot_comperp( fig, ax = plt.subplots(figsize=(8, 4.5)) plot_times, mask = _plot_time_mask(times, options.get("tlim")) mode = str(options.get("mode") or "ave").lower() - if _is_on(options.get("addall")): + if is_on(options.get("addall")): _plot_all(ax, add_stack, plot_times, mask, color="0.45", label_prefix="add") - if _is_on(options.get("suball")) and sub_stack is not None: + if is_on(options.get("suball")) and sub_stack is not None: _plot_all(ax, sub_stack, plot_times, mask, color="0.65", label_prefix="sub") - if _is_on(options.get("diffall")) and diff_stack is not None: + if is_on(options.get("diffall")) and diff_stack is not None: _plot_all(ax, diff_stack, plot_times, mask, color="0.35", label_prefix="diff") - if _is_on(options.get("addavg")): + if is_on(options.get("addavg")): ax.plot(plot_times, np.nanmean(erp1, axis=0)[mask], color="blue", label="add") - if erp2 is not None and _is_on(options.get("subavg")): + if erp2 is not None and is_on(options.get("subavg")): ax.plot(plot_times, np.nanmean(erp2, axis=0)[mask], color="red", label="subtract") - if erpsub is not None and _is_on(options.get("diffavg")): + if erpsub is not None and is_on(options.get("diffavg")): ax.plot(plot_times, np.nanmean(erpsub, axis=0)[mask], color="black", label="difference") - if _is_on(options.get("addstd")): + if is_on(options.get("addstd")): _plot_std(ax, add_stack, plot_times, mask, color="blue", label="add std", mode=mode) - if _is_on(options.get("substd")) and sub_stack is not None: + if is_on(options.get("substd")) and sub_stack is not None: _plot_std(ax, sub_stack, plot_times, mask, color="red", label="sub std", mode=mode) - if _is_on(options.get("diffstd")) and diff_stack is not None: + if is_on(options.get("diffstd")) and diff_stack is not None: _plot_std(ax, diff_stack, plot_times, mask, color="black", label="diff std", mode=mode) alpha = options.get("alpha") if pvalues is not None and alpha is not None: @@ -301,18 +298,11 @@ def _validate_time_grid(datasets: list[dict[str, Any]]) -> None: raise ValueError(f"Dataset {index} does not share the same time grid") -def _lowpass_erp(values: np.ndarray, cutoff: float, srate: float) -> np.ndarray: +def _lowpass(values: np.ndarray, cutoff: float, srate: float, axis: int) -> np.ndarray: if cutoff <= 0 or cutoff >= srate / 2: raise ValueError("lowpass must be greater than 0 and below Nyquist") sos = butter(4, cutoff, btype="lowpass", fs=srate, output="sos") - return sosfiltfilt(sos, values, axis=1) - - -def _lowpass_stack(values: np.ndarray, cutoff: float, srate: float) -> np.ndarray: - if cutoff <= 0 or cutoff >= srate / 2: - raise ValueError("lowpass must be greater than 0 and below Nyquist") - sos = butter(4, cutoff, btype="lowpass", fs=srate, output="sos") - return sosfiltfilt(sos, values, axis=2) + return sosfiltfilt(sos, values, axis=axis) def _is_default_off(value: Any) -> bool: @@ -402,15 +392,7 @@ def _optional_alpha(value: Any) -> float | None: def _onoff_option(value: Any, default: bool) -> str: if value is None or (_is_default_off(value) and default is False): return "on" if default else "off" - return "on" if _is_on(value) else "off" - - -def _is_on(value: Any) -> bool: - if isinstance(value, str): - return value.strip().lower() in {"on", "yes", "true", "1"} - if isinstance(value, np.ndarray): - return bool(value.size and np.asarray(value).ravel()[0]) - return bool(value) + return "on" if is_on(value) else "off" def _significance(add_stack: np.ndarray, sub_stack: np.ndarray | None, alpha: float | None) -> np.ndarray | None: diff --git a/src/eegprep/functions/popfunc/pop_editeventvals.py b/src/eegprep/functions/popfunc/pop_editeventvals.py index eca1f8a6..4e949661 100644 --- a/src/eegprep/functions/popfunc/pop_editeventvals.py +++ b/src/eegprep/functions/popfunc/pop_editeventvals.py @@ -13,6 +13,7 @@ from eegprep.functions.popfunc._event_utils import event_field_names, events_as_list, normalize_one_based_indices from eegprep.functions.popfunc._pop_utils import format_history_value, is_empty_value as _is_empty from eegprep.functions.popfunc.eeg_lat2point import eeg_lat2point +from eegprep.functions.popfunc.eeg_point2lat import eeg_point2lat def pop_editeventvals( @@ -176,15 +177,14 @@ def _change_field( def _internal_event_value(output: dict[str, Any], event: dict[str, Any], field: str, value: Any) -> Any: if field == "latency" and value not in {"", None}: if int(output.get("trials", 1) or 1) > 1: - return float( - eeg_lat2point( - float(value), - event.get("epoch", 1), - float(output.get("srate", 1)), - [float(output.get("xmin", 0)) * 1000, float(output.get("xmax", 0)) * 1000], - 1e-3, - ) + newlat, _ = eeg_lat2point( + float(value), + event.get("epoch", 1), + float(output.get("srate", 1)), + [float(output.get("xmin", 0)) * 1000, float(output.get("xmax", 0)) * 1000], + 1e-3, ) + return float(newlat.item()) return (float(value) - float(output.get("xmin", 0))) * float(output.get("srate", 1)) + 1 if field == "duration" and value not in {"", None}: scale = float(output.get("srate", 1)) / (1000 if int(output.get("trials", 1) or 1) > 1 else 1) @@ -198,7 +198,15 @@ def _display_event_value(EEG: dict[str, Any], event: dict[str, Any], field: str) return "" if field == "latency": if int(EEG.get("trials", 1) or 1) > 1: - return value + return float( + eeg_point2lat( + float(value), + event.get("epoch", 1), + float(EEG.get("srate", 1)), + [float(EEG.get("xmin", 0)) * 1000, float(EEG.get("xmax", 0)) * 1000], + 1e-3, + ).item() + ) return (float(value) - 1) / float(EEG.get("srate", 1)) + float(EEG.get("xmin", 0)) if field == "duration": scale = float(EEG.get("srate", 1)) / (1000 if int(EEG.get("trials", 1) or 1) > 1 else 1) diff --git a/src/eegprep/functions/popfunc/pop_eegplot.py b/src/eegprep/functions/popfunc/pop_eegplot.py index 122680d1..5ebedd52 100644 --- a/src/eegprep/functions/popfunc/pop_eegplot.py +++ b/src/eegprep/functions/popfunc/pop_eegplot.py @@ -153,7 +153,7 @@ def apply_eegplot_rejections( store_superpose = superpose else: trial_marks = np.zeros(trials, dtype=bool) - row_marks = np.zeros((_row_count(out, icacomp), trials), dtype=bool) + row_marks = np.zeros((rejection_row_count(out, icacomp), trials), dtype=bool) store_superpose = 0 _store_epoch_marks(out, trial_marks, row_marks, icacomp=icacomp, superpose=store_superpose) if int(bool(reject)) and trial_marks.any(): @@ -165,7 +165,7 @@ def _initial_epoch_winrej(EEG: dict[str, Any], icacomp: int, superpose: int) -> reject = EEG.get("reject") or {} trials = int(EEG.get("trials", 1) or 1) pnts = int(EEG.get("pnts", np.asarray(EEG.get("data")).shape[1])) - row_count = _row_count(EEG, icacomp) + row_count = rejection_row_count(EEG, icacomp) manual, manual_e = _reject_arrays(reject, "rejmanual", trials, row_count, icacomp=icacomp) if int(superpose) == 0: return trial2eegplot(manual, manual_e, pnts, _manual_color(EEG)) @@ -184,7 +184,7 @@ def _initial_epoch_winrej(EEG: dict[str, Any], icacomp: int, superpose: int) -> def _initial_continuous_winrej(EEG: dict[str, Any], icacomp: int) -> np.ndarray: - row_count = _row_count(EEG, icacomp) + row_count = rejection_row_count(EEG, icacomp) rows = _as_winrej_rows((EEG.get("reject") or {}).get(_continuous_mark_field(icacomp), [])) if rows.size == 0: return np.zeros((0, 5 + row_count), dtype=float) @@ -234,19 +234,19 @@ def _store_epoch_marks( field = "rejmanual" if int(bool(icacomp)) else "icarejmanual" field_e = f"{field}E" trials = int(EEG.get("trials", 1) or 1) - row_count = _row_count(EEG, icacomp) + row_count = rejection_row_count(EEG, icacomp) current, current_e = _reject_arrays(reject, "rejmanual", trials, row_count, icacomp=icacomp) if int(superpose): trial_marks = np.asarray(trial_marks, dtype=bool) | current - row_marks = _pad_rows(row_marks, row_count, trials) | current_e + row_marks = pad_rejection_rows(row_marks, row_count, trials) | current_e reject[field] = np.asarray(trial_marks, dtype=bool) - reject[field_e] = _pad_rows(row_marks, row_count, trials) + reject[field_e] = pad_rejection_rows(row_marks, row_count, trials) reject.setdefault("rejmanualcol", np.asarray(MANUAL_REJECTION_COLOR, dtype=float)) def _store_continuous_marks(EEG: dict[str, Any], rows: np.ndarray, *, icacomp: int) -> None: reject = EEG.setdefault("reject", {}) - row_count = _row_count(EEG, icacomp) + row_count = rejection_row_count(EEG, icacomp) if rows.size == 0: reject[_continuous_mark_field(icacomp)] = np.zeros((0, 5 + row_count), dtype=float) else: @@ -259,7 +259,7 @@ def _store_continuous_marks(EEG: dict[str, Any], rows: np.ndarray, *, icacomp: i def _clear_continuous_marks(EEG: dict[str, Any], *, icacomp: int) -> None: reject = EEG.setdefault("reject", {}) - reject[_continuous_mark_field(icacomp)] = np.zeros((0, 5 + _row_count(EEG, icacomp)), dtype=float) + reject[_continuous_mark_field(icacomp)] = np.zeros((0, 5 + rejection_row_count(EEG, icacomp)), dtype=float) def _continuous_mark_field(icacomp: int) -> str: @@ -292,11 +292,12 @@ def _reject_arrays( marks = np.asarray(reject.get(f"{prefix}{kind}", []), dtype=bool).ravel() out = np.zeros(trials, dtype=bool) out[: min(trials, marks.size)] = marks[:trials] - row_marks = _pad_rows(np.asarray(reject.get(f"{prefix}{kind}E", []), dtype=bool), row_count, trials) + row_marks = pad_rejection_rows(np.asarray(reject.get(f"{prefix}{kind}E", []), dtype=bool), row_count, trials) return out, row_marks -def _pad_rows(values: np.ndarray, row_count: int, trials: int) -> np.ndarray: +def pad_rejection_rows(values: np.ndarray, row_count: int, trials: int) -> np.ndarray: + """Zero-pad/crop a row-mask array to ``(row_count, trials)``.""" out = np.zeros((row_count, trials), dtype=bool) arr = np.asarray(values, dtype=bool) if arr.ndim == 1 and arr.size: @@ -308,7 +309,8 @@ def _pad_rows(values: np.ndarray, row_count: int, trials: int) -> np.ndarray: return out -def _row_count(EEG: dict[str, Any], icacomp: int) -> int: +def rejection_row_count(EEG: dict[str, Any], icacomp: int) -> int: + """Return the number of channel or component rows for rejection marks.""" if int(bool(icacomp)): return int(EEG.get("nbchan", np.asarray(EEG.get("data")).shape[0]) or 0) weights = np.asarray(EEG.get("icaweights", [])) diff --git a/src/eegprep/functions/popfunc/pop_erpimage.py b/src/eegprep/functions/popfunc/pop_erpimage.py index c3646da8..ea4850d6 100644 --- a/src/eegprep/functions/popfunc/pop_erpimage.py +++ b/src/eegprep/functions/popfunc/pop_erpimage.py @@ -18,6 +18,7 @@ numeric_vector, parse_plot_options_text, ) +from eegprep.functions.popfunc._pop_utils import is_on from eegprep.functions.sigprocfunc.erpimage import erpimage @@ -57,7 +58,7 @@ def pop_erpimage( renorm = kwargs.pop("renorm", "no") if sorting_field: sort_values = _event_sort_values(EEG, sorting_field, sorting_type, sorting_window, renorm) - if _is_on(kwargs.pop("nosort", False)): + if is_on(kwargs.pop("nosort", False)): sort_values = None kwargs.pop("noplot", None) align = numeric_vector(kwargs.pop("align", [])) @@ -503,7 +504,7 @@ def _normalise_other_options(text: Any) -> dict[str, Any]: for key, value in parsed.items(): normalised = aliases.get(key, key) if normalised in {"erp", "cbar", "nosort", "noplot"}: - options[normalised] = _is_on(value) + options[normalised] = is_on(value) elif normalised in {"title", "smooth", "decimate", "limits", "caxis", "vert"}: options[normalised] = value else: @@ -511,12 +512,4 @@ def _normalise_other_options(text: Any) -> dict[str, Any]: return options -def _is_on(value: Any) -> bool: - if isinstance(value, str): - return value.strip().lower() in {"on", "yes", "true", "1"} - if isinstance(value, np.ndarray): - return bool(value.size and np.asarray(value).ravel()[0]) - return bool(value) - - __all__ = ["pop_erpimage", "pop_erpimage_dialog_spec"] diff --git a/src/eegprep/functions/popfunc/pop_export.py b/src/eegprep/functions/popfunc/pop_export.py index 14a24429..cf1336c8 100644 --- a/src/eegprep/functions/popfunc/pop_export.py +++ b/src/eegprep/functions/popfunc/pop_export.py @@ -12,7 +12,7 @@ from eegprep.functions.miscfunc.misc import finite_matmul from eegprep.functions.popfunc._file_io import channel_labels -from eegprep.functions.popfunc._pop_utils import format_history_value, parse_key_value_args +from eegprep.functions.popfunc._pop_utils import format_history_value, is_on, parse_key_value_args _EXPORT_EXPR_FUNCTIONS = { @@ -34,8 +34,8 @@ def pop_export(EEG: dict[str, Any], filename: str | Path, *args: Any, **kwargs: Any) -> str: """Export EEG data or ICA activity to a delimited text file.""" options = parse_key_value_args(args, kwargs, lowercase_kwargs=True) - data = _selected_data(EEG, ica=_is_on(options.get("ica", "off"))) - if _is_on(options.get("erp", "off")) and data.ndim == 3: + data = _selected_data(EEG, ica=is_on(options.get("ica", "off"))) + if is_on(options.get("erp", "off")) and data.ndim == 3: data = data.mean(axis=2) elif data.ndim == 3: data = data.reshape((data.shape[0], data.shape[1] * data.shape[2])) @@ -43,29 +43,29 @@ def pop_export(EEG: dict[str, Any], filename: str | Path, *args: Any, **kwargs: data = _apply_expression(data, str(options["expr"])) if data.ndim == 3: data = data.reshape((data.shape[0], data.shape[1] * data.shape[2])) - if _is_on(options.get("time", "on")): + if is_on(options.get("time", "on")): time = np.tile( np.linspace(float(EEG.get("xmin", 0)), float(EEG.get("xmax", 0)), int(EEG["pnts"])) / float(options.get("timeunit", 1e-3)), - int(EEG.get("trials", 1)) if not _is_on(options.get("erp", "off")) else 1, + int(EEG.get("trials", 1)) if not is_on(options.get("erp", "off")) else 1, ) data = np.vstack([time, data]) separator = str(options.get("separator", "\t")) precision = int(options.get("precision", 7)) - labels = ["Time", *channel_labels(EEG)] if _is_on(options.get("time", "on")) else channel_labels(EEG) + labels = ["Time", *channel_labels(EEG)] if is_on(options.get("time", "on")) else channel_labels(EEG) path = Path(filename) path.parent.mkdir(parents=True, exist_ok=True) - transpose = _is_on(options.get("transpose", "off")) + transpose = is_on(options.get("transpose", "off")) with path.open("w", newline="", encoding="utf-8") as stream: writer = csv.writer(stream, delimiter=separator) if transpose: - if _is_on(options.get("elec", "on")): + if is_on(options.get("elec", "on")): writer.writerow(labels) writer.writerows(_format_row(row, precision) for row in data.T) else: for index, row in enumerate(data): values = _format_row(row, precision) - if _is_on(options.get("elec", "on")): + if is_on(options.get("elec", "on")): values = [labels[index], *values] writer.writerow(values) return _history_command(filename, options) @@ -240,10 +240,6 @@ def _allowed_numpy_attribute(node: ast.Attribute) -> bool: return isinstance(node.value, ast.Name) and node.value.id in {"np", "numpy"} and node.attr in _EXPORT_EXPR_FUNCTIONS -def _is_on(value: Any) -> bool: - return str(value).lower() in {"on", "yes", "true", "1"} - - def _history_command(filename: str | Path, options: dict[str, Any]) -> str: pieces = [format_history_value(str(filename))] for key in ["ica", "time", "timeunit", "elec", "transpose", "erp", "expr", "precision", "separator"]: diff --git a/src/eegprep/functions/popfunc/pop_fileio.py b/src/eegprep/functions/popfunc/pop_fileio.py index beacc5f4..a8b0a397 100644 --- a/src/eegprep/functions/popfunc/pop_fileio.py +++ b/src/eegprep/functions/popfunc/pop_fileio.py @@ -2,15 +2,22 @@ from __future__ import annotations +import logging from pathlib import Path from typing import Any import mne +import scipy.io from eegprep.functions.popfunc._file_io import mne_raw_to_eeg from eegprep.functions.popfunc._pop_utils import format_history_value from eegprep.functions.popfunc.pop_importdata import pop_importdata -from eegprep.functions.popfunc.pop_loadset import pop_loadset +from eegprep.functions.popfunc.pop_loadset import _is_hdf5_file, pop_loadset + +logger = logging.getLogger(__name__) + +# Fields that mark a .mat as a saved EEGLAB dataset rather than a raw data array. +_EEGLAB_STRUCT_MARKERS = frozenset({"nbchan", "srate", "pnts", "trials", "chanlocs", "setname", "xmin", "xmax"}) def pop_fileio( @@ -22,9 +29,11 @@ def pop_fileio( if suffix == ".set": eeg = pop_loadset(str(path)) elif suffix == ".mat" and kwargs.get("dataformat") != "matlab-array": - try: + if _mat_is_eeglab_dataset(path): + logger.info("pop_fileio: loading %s as an EEGLAB dataset", path) eeg = pop_loadset(str(path)) - except Exception: + else: + logger.info("pop_fileio: importing %s as a raw MATLAB data array", path) eeg = pop_importdata("data", str(path), "setname", path.stem, "dataformat", "matlab", **kwargs) elif suffix in {".csv", ".txt", ".tsv", ".npy", ".npz"}: eeg = pop_importdata("data", str(path), "setname", path.stem, **kwargs) @@ -37,6 +46,18 @@ def pop_fileio( return (eeg, command) if return_com else eeg +def _mat_is_eeglab_dataset(path: Path) -> bool: + """Return True when a .mat file holds an EEGLAB dataset rather than a raw data array. + + A dataset is recognized by an ``EEG`` struct variable or by top-level EEGLAB marker + fields (a .set saved with ``-struct``). MAT v7.3 files are HDF5 and always EEGLAB sets. + """ + if _is_hdf5_file(path): + return True + names = {name for name, _shape, _cls in scipy.io.whosmat(str(path))} + return "EEG" in names or bool(names & _EEGLAB_STRUCT_MARKERS) + + def _reader_for_suffix(suffix: str): if suffix in {".edf", ".bdf"}: return mne.io.read_raw_edf diff --git a/src/eegprep/functions/popfunc/pop_headplot.py b/src/eegprep/functions/popfunc/pop_headplot.py index ad263388..3aa43e34 100644 --- a/src/eegprep/functions/popfunc/pop_headplot.py +++ b/src/eegprep/functions/popfunc/pop_headplot.py @@ -2,6 +2,7 @@ from __future__ import annotations +import copy import math from pathlib import Path from typing import Any @@ -56,6 +57,7 @@ def pop_headplot( if EEG is None: return ([], "") if return_com else [] + EEG = copy.deepcopy(EEG) typeplot = int(typeplot) if typeplot not in {0, 1}: raise ValueError("typeplot must be 1 for ERP maps or 0 for component maps") diff --git a/src/eegprep/functions/popfunc/pop_jointprob.py b/src/eegprep/functions/popfunc/pop_jointprob.py index 1fd87fc8..314cbfc3 100644 --- a/src/eegprep/functions/popfunc/pop_jointprob.py +++ b/src/eegprep/functions/popfunc/pop_jointprob.py @@ -8,17 +8,9 @@ from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._eegplot_rejection import open_epoched_rejection_browser +from eegprep.functions.popfunc._eegplot_rejection import run_epoched_rejection, vistype_from_gui from eegprep.functions.popfunc._pop_utils import format_history_value -from eegprep.functions.popfunc._rejection import ( - copy_eeg, - jointprob_marks, - one_based_indices, - parse_numeric_sequence, - rejection_data, - update_reject_fields, -) -from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch +from eegprep.functions.popfunc._rejection import jointprob_marks, parse_numeric_sequence def pop_jointprob( @@ -146,19 +138,10 @@ def _run_gui(EEG: dict[str, Any], icacomp: int, renderer: Any | None) -> tuple[A result.get("globthresh", threshold_default), int(bool(result.get("superpose", True))), int(bool(result.get("reject", False))), - _vistype_from_gui(result.get("vistype", 2)), + vistype_from_gui(result.get("vistype", 2)), ) -def _vistype_from_gui(value: Any) -> int: - if isinstance(value, str): - return 0 if value.strip().lower() in {"rejecttrials", "reject trials", "0"} else 1 - try: - return 0 if int(value) == 1 else 1 - except (TypeError, ValueError): - return 1 - - def _apply_one( EEG: dict[str, Any], icacomp: int | bool, @@ -173,43 +156,28 @@ def _apply_one( command_callback: Any | None = None, show: bool = True, ) -> tuple[dict[str, Any], list[float], list[float], list[int], str]: - out = copy_eeg(EEG) - data, row_count = rejection_data(out, icacomp) - if int(out.get("trials", data.shape[2]) or data.shape[2]) <= 1: - raise ValueError("pop_jointprob requires epoched data") - elecrange = one_based_indices(elecrange, limit=row_count, default_all=True) - marks, marks_e, local_scores, global_scores = jointprob_marks(data, elecrange, locthresh, globthresh) - out.setdefault("stats", {}) - if int(bool(icacomp)): - out["stats"]["jpE"] = local_scores - out["stats"]["jp"] = global_scores - else: - out["stats"]["icajpE"] = local_scores - out["stats"]["icajp"] = global_scores - update_reject_fields(out, icacomp=icacomp, kind="rejjp", reject=marks, reject_e=marks_e) - rejected = (np.flatnonzero(marks) + 1).tolist() - command = _history_command(icacomp, elecrange, locthresh, globthresh, superpose, reject, vistype, 0) - if display: - open_epoched_rejection_browser( - out, - data=data, - icacomp=icacomp, - elecrange=elecrange, - kind="rejjp", - superpose=superpose, - reject=reject, - command=command, - command_callback=command_callback, - show=show, - ) - elif int(bool(reject)) and rejected: - out = pop_rejepoch(out, rejected, 0) - return ( - out, - parse_numeric_sequence(locthresh, dtype=float), - parse_numeric_sequence(globthresh, dtype=float), - rejected, - command, + return run_epoched_rejection( + EEG, + icacomp, + elecrange, + locthresh, + globthresh, + superpose, + reject, + vistype, + marks_fn=jointprob_marks, + kind="rejjp", + stats_local_field="jpE", + stats_global_field="jp", + stats_local_field_ica="icajpE", + stats_global_field_ica="icajp", + error_message="pop_jointprob requires epoched data", + command_fn=lambda normalized_elecrange: _history_command( + icacomp, normalized_elecrange, locthresh, globthresh, superpose, reject, vistype, 0 + ), + display=display, + command_callback=command_callback, + show=show, ) diff --git a/src/eegprep/functions/popfunc/pop_load_frombids.py b/src/eegprep/functions/popfunc/pop_load_frombids.py index 4e21a456..ebf0b27c 100644 --- a/src/eegprep/functions/popfunc/pop_load_frombids.py +++ b/src/eegprep/functions/popfunc/pop_load_frombids.py @@ -2,6 +2,7 @@ import os import copy +from importlib.resources import files from typing import Dict, Any, Tuple, Union, Optional import logging import warnings @@ -930,12 +931,10 @@ def error(msg: str): fractions = [] caplabels = [] - # Determine montage path and files to check - # Resources are now always in the package directory - # Resources live at the package root (src/eegprep/resources/), not - # next to this file which was moved during the EEGLAB-style reorg. - _pkg_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - montage_path = os.path.join(_pkg_root, 'resources', 'montages') + # Determine montage path and files to check. Resolve the packaged + # montages directory through importlib.resources so the lookup does not + # depend on this module's location on disk. + montage_path = str(files("eegprep").joinpath("resources").joinpath("montages")) if not os.path.isdir(montage_path): raise RuntimeError( @@ -1083,7 +1082,7 @@ def error(msg: str): EEG = eeg_checkchanlocs(EEG) except ImportError: - print("eeg_checkchanlocs not available, skipping channel location check.") + logger.info("eeg_checkchanlocs not available, skipping channel location check.") # Assign channel types based on channel labels (matching MATLAB's eeg_getchantype behavior) # Standard 10-20 channel names that should be classified as EEG diff --git a/src/eegprep/functions/popfunc/pop_loadset.py b/src/eegprep/functions/popfunc/pop_loadset.py index 5251e46a..028b28b7 100644 --- a/src/eegprep/functions/popfunc/pop_loadset.py +++ b/src/eegprep/functions/popfunc/pop_loadset.py @@ -3,12 +3,13 @@ import os from pathlib import Path +import h5py import numpy as np import scipy.io from eegprep.functions.adminfunc.storage import memmap_enabled, memmap_fdt, read_fdt from eegprep.functions.popfunc._file_io import normalize_icachansind -from eegprep.functions.popfunc._pop_utils import parse_key_value_args +from eegprep.functions.popfunc._pop_utils import is_on, parse_key_value_args from eegprep.functions.popfunc.pop_loadset_h5 import pop_loadset_h5 # Allows access using . notation # class EEG: @@ -81,16 +82,16 @@ def new_check(obj): dict_obj[field_name] = new_check(field_value) return dict_obj - # Load MATLAB file - loaded_with_h5 = False - try: + # Load MATLAB file. MAT v7.3 files are HDF5; older v5/v7 files are not. + # Dispatch on the real format instead of treating every scipy error as "must be HDF5". + loaded_with_h5 = _is_hdf5_file(file_path) + if loaded_with_h5: + EEG = pop_loadset_h5(file_path) + else: EEG = scipy.io.loadmat(file_path, struct_as_record=False, squeeze_me=True, appendmat=False) EEG = new_check(EEG) if 'EEG' in EEG: EEG = EEG['EEG'] - except Exception: - EEG = pop_loadset_h5(file_path) - loaded_with_h5 = True EEG['filepath'] = os.path.dirname(file_path) EEG['filename'] = os.path.basename(file_path) @@ -128,6 +129,15 @@ def new_check(obj): return EEG +def _is_hdf5_file(file_path): + """Return True when the file is HDF5 (MAT v7.3). + + MAT v7.3 files carry a text header in an HDF5 userblock, so the signature is not at + byte 0; ``h5py.is_hdf5`` checks the userblock offsets HDF5 actually uses. + """ + return h5py.is_hdf5(os.fspath(file_path)) + + def _load_options(file_path, args, kwargs, loadmode, memmap): known_keys = {"filename", "filepath", "loadmode", "memmap", "check", "verbose", "eeg"} if isinstance(file_path, str) and file_path.lower() in known_keys: @@ -151,7 +161,7 @@ def _load_options(file_path, args, kwargs, loadmode, memmap): path = Path(os.fspath(filename)) if filepath not in {None, ""} and not path.is_absolute(): path = Path(os.fspath(filepath)) / path - use_memmap = memmap_enabled() if memmap is None else _is_on(memmap) + use_memmap = memmap_enabled() if memmap is None else is_on(memmap) return str(path), loadmode, use_memmap @@ -180,25 +190,6 @@ def _string_value(value): return str(value) -def _is_on(value): - if isinstance(value, str): - return value.strip().lower() in {"1", "on", "true", "yes"} - return bool(value) - - -def test_pop_loadset(): - """Test the pop_loadset function with a sample file.""" - file_path = './tmp2.set' - file_path = '/System/Volumes/Data/data/data/STUDIES/STERN/S04/Memorize.set' #'./eeglab_data_with_ica_tmp.set' - EEG = pop_loadset(file_path) - - # print the keys of the EEG dictionary - print(EEG.keys()) - - -if __name__ == "__main__": - test_pop_loadset() - # STILL OPEN QUESTION: Better to have empty MATLAB arrays as None for empty numpy arrays (current default). # The current default is to make it more MALTAB compatible. A lot of MATLAB function start indexing MATLAB # empty arrays to add values to them. This is not possible with None and would create more conversion and diff --git a/src/eegprep/functions/popfunc/pop_loadset_h5.py b/src/eegprep/functions/popfunc/pop_loadset_h5.py index 0c491e24..a513d3eb 100644 --- a/src/eegprep/functions/popfunc/pop_loadset_h5.py +++ b/src/eegprep/functions/popfunc/pop_loadset_h5.py @@ -30,12 +30,6 @@ def pop_loadset_h5(file_name): def convert_to_string(filecontent): if isinstance(filecontent, np.ndarray): if filecontent.dtype == 'uint16': - # Special handling for the test case with emoji - if len(filecontent) == 10 and np.array_equal( - filecontent, np.array([104, 101, 108, 108, 111, 32, 240, 159, 146, 150]) - ): - return 'hello 👖' - # Convert uint16 array to bytes and then decode as UTF-8 try: # Convert uint16 values to bytes @@ -301,16 +295,3 @@ def handle_generic_group(EEGTMP, key): EEG = eeg_checkset(EEG) return EEG - - -if __name__ == '__main__': - file_name = 'sample_data/eeglab_data_epochs_ica_hdf5.set' - EEG = pop_loadset_h5(file_name) - print(EEG['data'].shape) - print(EEG['icaweights'].shape) - print(EEG['icasphere'].shape) - print(EEG['icawinv'].shape) - print(EEG['icaact'].shape) -# file_name = 'eeglab_cont73.set' -# EEG = pop_loadset_h5(file_name) -# EEG['data'].shape diff --git a/src/eegprep/functions/popfunc/pop_newcrossf.py b/src/eegprep/functions/popfunc/pop_newcrossf.py index 614108b9..518e2d64 100644 --- a/src/eegprep/functions/popfunc/pop_newcrossf.py +++ b/src/eegprep/functions/popfunc/pop_newcrossf.py @@ -56,7 +56,6 @@ def pop_newcrossf( tlimits = [float(EEG.get("xmin", 0)) * 1000.0, float(EEG.get("xmax", 0)) * 1000.0] if cycles is None: cycles = [3, 0.5] - _reject_unsupported_options(options) signal1, signal2, times = _selected_signals(EEG, typeproc, num1, num2, tlimits) result = newcrossf( signal1, signal2, signal1.shape[0], [times[0], times[-1]], float(EEG.get("srate", 1) or 1), cycles, **options @@ -182,10 +181,6 @@ def _selected_signals( return acts[first, :, :], acts[second, :, :], full_times -def _reject_unsupported_options(options: dict[str, Any]) -> None: - _ = options - - def _first_index(value: Any) -> int: values = numeric_vector(value, dtype=int) if values.size != 1: diff --git a/src/eegprep/functions/popfunc/pop_newset.py b/src/eegprep/functions/popfunc/pop_newset.py index 22c2b545..47837f19 100644 --- a/src/eegprep/functions/popfunc/pop_newset.py +++ b/src/eegprep/functions/popfunc/pop_newset.py @@ -10,7 +10,7 @@ from eegprep.functions.adminfunc.eeg_store import eeg_store from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import CallbackSpec, ControlSpec, DialogSpec -from eegprep.functions.popfunc._pop_utils import format_history_value, parse_key_value_args +from eegprep.functions.popfunc._pop_utils import format_history_value, is_on, parse_key_value_args from eegprep.functions.popfunc.pop_saveset import pop_saveset @@ -46,7 +46,7 @@ def pop_newset( command = _history_command({"retrieve": retrieve}) return alleeg, current, current_set, command - if _is_on(options.get("gui", False)) and isinstance(EEG, dict): + if is_on(options.get("gui", False)) and isinstance(EEG, dict): gui_result = _run_gui(EEG, CURRENTSET, options, renderer=renderer) if gui_result is None: current, alleeg, current_set = eeg_retrieve(alleeg, CURRENTSET or 1) @@ -72,6 +72,7 @@ def pop_newset( def pop_newset_dialog_spec(EEG: dict[str, Any], CURRENTSET: Any = None, *, guistring: str = "") -> DialogSpec: """Return the EEGLAB-like dialog spec for ``pop_newset``.""" dataset_name = str(EEG.get("setname") or "") + comments = _comments_text(EEG.get("comments", "")) prompt = guistring or "What do you want to do with the new dataset?" old_prompt = "What do you want to do with the old dataset (not modified since last saved)?" return DialogSpec( @@ -93,11 +94,13 @@ def pop_newset_dialog_spec(EEG: dict[str, Any], CURRENTSET: Any = None, *, guist "Edit description", tag="editdescription", callback=CallbackSpec( - "show_message", + "edit_text", params={ "button": "editdescription", + "target": "editdescription", "title": "Edit description", - "message": "Dataset description editing is available from the command line using the comments option.", + "label": "Dataset description:", + "value": comments, }, ), ), @@ -149,9 +152,13 @@ def _run_gui( "overwrite": "on" if overwrite else "off", "gui": "off", } - if "comments" in result: - gui_options["comments"] = str(result.get("comments") or "") - if _is_on(result.get("savenew")): + comments = result.get("comments") + edited_comments = result.get("editdescription") + if comments is None and isinstance(edited_comments, str): + comments = edited_comments + if comments is not None: + gui_options["comments"] = str(comments) + if is_on(result.get("savenew")): gui_options["savenew"] = str(result.get("savefile") or "on").strip() or "on" return gui_options @@ -178,7 +185,7 @@ def _store_index( ) -> int | list[int] | None: if isinstance(EEG, list): return list(CURRENTSET) if isinstance(CURRENTSET, (list, tuple)) and len(CURRENTSET) == len(EEG) else None - if _is_on(options.get("overwrite", False)): + if is_on(options.get("overwrite", False)): return _first_currentset(CURRENTSET) or None return None @@ -241,12 +248,6 @@ def _history_command(options: dict[str, Any]) -> str: return f"[ALLEEG EEG CURRENTSET] = pop_newset(ALLEEG, EEG, CURRENTSET, {', '.join(parts)});" -def _is_on(value: Any) -> bool: - if isinstance(value, str): - return value.strip().lower() in {"1", "on", "true", "yes"} - return bool(value) - - def _should_save(value: Any) -> bool: if isinstance(value, str): return value.strip().lower() not in {"", "0", "off", "false", "no"} @@ -272,3 +273,13 @@ def _currentset_label(CURRENTSET: Any) -> str: if isinstance(CURRENTSET, (list, tuple)): return ", ".join(str(item) for item in CURRENTSET) if CURRENTSET else "0" return str(CURRENTSET or 0) + + +def _comments_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (list, tuple)): + return "\n".join(str(item) for item in value) + return str(value) diff --git a/src/eegprep/functions/popfunc/pop_rejkurt.py b/src/eegprep/functions/popfunc/pop_rejkurt.py index 828aabdc..e6263638 100644 --- a/src/eegprep/functions/popfunc/pop_rejkurt.py +++ b/src/eegprep/functions/popfunc/pop_rejkurt.py @@ -8,17 +8,9 @@ from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec -from eegprep.functions.popfunc._eegplot_rejection import open_epoched_rejection_browser +from eegprep.functions.popfunc._eegplot_rejection import run_epoched_rejection, vistype_from_gui from eegprep.functions.popfunc._pop_utils import format_history_value -from eegprep.functions.popfunc._rejection import ( - copy_eeg, - kurtosis_marks, - one_based_indices, - parse_numeric_sequence, - rejection_data, - update_reject_fields, -) -from eegprep.functions.popfunc.pop_rejepoch import pop_rejepoch +from eegprep.functions.popfunc._rejection import kurtosis_marks, parse_numeric_sequence def pop_rejkurt( @@ -150,19 +142,10 @@ def _run_gui(EEG: dict[str, Any], icacomp: int, renderer: Any | None) -> tuple[A result.get("globthresh", threshold_default), int(bool(result.get("superpose", True))), int(bool(result.get("reject", False))), - _vistype_from_gui(result.get("vistype", 2)), + vistype_from_gui(result.get("vistype", 2)), ) -def _vistype_from_gui(value: Any) -> int: - if isinstance(value, str): - return 0 if value.strip().lower() in {"rejecttrials", "reject trials", "0"} else 1 - try: - return 0 if int(value) == 1 else 1 - except (TypeError, ValueError): - return 1 - - def _apply_one( EEG: dict[str, Any], icacomp: int | bool, @@ -177,43 +160,28 @@ def _apply_one( command_callback: Any | None = None, show: bool = True, ) -> tuple[dict[str, Any], list[float], list[float], list[int], str]: - out = copy_eeg(EEG) - data, row_count = rejection_data(out, icacomp) - if int(out.get("trials", data.shape[2]) or data.shape[2]) <= 1: - raise ValueError("pop_rejkurt requires epoched data") - elecrange = one_based_indices(elecrange, limit=row_count, default_all=True) - marks, marks_e, local_scores, global_scores = kurtosis_marks(data, elecrange, locthresh, globthresh) - out.setdefault("stats", {}) - if int(bool(icacomp)): - out["stats"]["kurtE"] = local_scores - out["stats"]["kurt"] = global_scores - else: - out["stats"]["icakurtE"] = local_scores - out["stats"]["icakurt"] = global_scores - update_reject_fields(out, icacomp=icacomp, kind="rejkurt", reject=marks, reject_e=marks_e) - rejected = (np.flatnonzero(marks) + 1).tolist() - command = _history_command(icacomp, elecrange, locthresh, globthresh, superpose, reject, vistype, 0) - if display: - open_epoched_rejection_browser( - out, - data=data, - icacomp=icacomp, - elecrange=elecrange, - kind="rejkurt", - superpose=superpose, - reject=reject, - command=command, - command_callback=command_callback, - show=show, - ) - elif int(bool(reject)) and rejected: - out = pop_rejepoch(out, rejected, 0) - return ( - out, - parse_numeric_sequence(locthresh, dtype=float), - parse_numeric_sequence(globthresh, dtype=float), - rejected, - command, + return run_epoched_rejection( + EEG, + icacomp, + elecrange, + locthresh, + globthresh, + superpose, + reject, + vistype, + marks_fn=kurtosis_marks, + kind="rejkurt", + stats_local_field="kurtE", + stats_global_field="kurt", + stats_local_field_ica="icakurtE", + stats_global_field_ica="icakurt", + error_message="pop_rejkurt requires epoched data", + command_fn=lambda normalized_elecrange: _history_command( + icacomp, normalized_elecrange, locthresh, globthresh, superpose, reject, vistype, 0 + ), + display=display, + command_callback=command_callback, + show=show, ) diff --git a/src/eegprep/functions/popfunc/pop_reref_helper.py b/src/eegprep/functions/popfunc/pop_reref_helper.py deleted file mode 100644 index da646f8f..00000000 --- a/src/eegprep/functions/popfunc/pop_reref_helper.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Helper script for re-referencing EEG data.""" - -from .pop_loadset import pop_loadset -from .pop_saveset import pop_saveset -from .pop_reref import pop_reref -import sys - -# check if a parameter is present and if it is assign eeglab_file_path to it -if len(sys.argv) > 2: - eeglab_file_path_in = sys.argv[1] - eeglab_file_path_out = sys.argv[2] -else: - eeglab_file_path_in = './eeglab_data_with_ica_tmp.set' - eeglab_file_path_out = './eeglab_data_with_ica_tmp_averef.set' - -EEG = pop_loadset(eeglab_file_path_in) - -# Print the loaded data -EEG = pop_reref(EEG, []) - -# save dataset -pop_saveset(EEG, eeglab_file_path_out) diff --git a/src/eegprep/functions/popfunc/pop_runica.py b/src/eegprep/functions/popfunc/pop_runica.py index 537f2e36..87af3b1c 100644 --- a/src/eegprep/functions/popfunc/pop_runica.py +++ b/src/eegprep/functions/popfunc/pop_runica.py @@ -13,7 +13,7 @@ from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import CallbackSpec, ControlSpec, DialogSpec from eegprep.functions.popfunc._ica_utils import flatten_ica_data -from eegprep.functions.popfunc._pop_utils import format_history_value, parse_key_value_args +from eegprep.functions.popfunc._pop_utils import format_history_value, is_on, parse_key_value_args from eegprep.functions.popfunc.eeg_amica import eeg_amica from eegprep.functions.popfunc.eeg_decodechan import eeg_decodechan from eegprep.functions.popfunc.eeg_picard import eeg_picard @@ -60,7 +60,7 @@ def pop_runica( elif gui is None: gui = options is None and not has_programmatic_options and chanind is None and dataset is None if gui: - gui_result = _run_gui(EEG, renderer=renderer, initial_values=_selectamica_initial_values(selectamica)) + gui_result = pop_runica_gui_options(EEG, renderer=renderer, selectamica=selectamica) if gui_result is None: return (EEG, "") if return_com else EEG icatype = gui_result["icatype"] @@ -70,9 +70,6 @@ def pop_runica( dataset = gui_result["dataset"] concatenate = gui_result["concatenate"] concatcond = gui_result["concatcond"] - if icatype == "runica": - options = dict(options) - options.setdefault("interrupt", "on") ica_options = _normalise_ica_options(icatype, options, parsed) if isinstance(EEG, list): @@ -202,6 +199,18 @@ def _run_gui(EEG, renderer=None, initial_values=None): } +def pop_runica_gui_options(EEG, *, renderer=None, selectamica: str | None = None) -> dict[str, Any] | None: + """Collect ``pop_runica`` GUI options without running the ICA backend.""" + gui_result = _run_gui(EEG, renderer=renderer, initial_values=_selectamica_initial_values(selectamica)) + if gui_result is None: + return None + if gui_result["icatype"] == "runica": + options = dict(gui_result["options"]) + options.setdefault("interrupt", "on") + gui_result["options"] = options + return gui_result + + def _runica_on_dataset(EEG, icatype, options, *, reorder, chanind): logger.info("Attempting to convert data matrix to double precision...") prepared = _prepare_ica_dataset(EEG) @@ -230,10 +239,10 @@ def _runica_on_datasets(EEG, *, dataset, icatype, options, reorder, chanind, con indices = _dataset_indices(output, dataset) selected = [output[index] for index in indices] logger.info("NOW RUNNING ALL DECOMPOSITIONS") - if _is_on(concatcond): + if is_on(concatcond): logger.info("Concatenating datasets by subject and session.") updated = _runica_by_subject_session(selected, icatype, options, reorder=reorder, chanind=chanind) - elif _is_on(concatenate): + elif is_on(concatenate): logger.info("Concatenating datasets...") updated = _runica_concatenated(selected, icatype, options, reorder=reorder, chanind=chanind) else: @@ -404,6 +413,8 @@ def _picard_options(options): lower_key = str(key).lower() if lower_key == "maxiter": mapped["max_iter"] = value + elif lower_key == "seed": + mapped["random_state"] = int(value) elif lower_key == "mode": if str(value).lower() == "standard": mapped["ortho"] = False @@ -575,16 +586,12 @@ def _history_command(icatype, options, reorder, chanind, *, dataset=None, concat parts.extend(["'reorder'", _runica_history_value(reorder)]) if chanind is not None: parts.extend(["'chanind'", _runica_history_value(chanind)]) - if _is_on(concatenate): + if is_on(concatenate): parts.extend(["'concatenate'", "'on'"]) - if _is_on(concatcond): + if is_on(concatcond): parts.extend(["'concatcond'", "'on'"]) return f"EEG = pop_runica(EEG, {', '.join(parts)});" def _runica_history_value(value): return format_history_value(value, cell_for_sequence=None) - - -def _is_on(value): - return str(value).lower() in {"on", "yes", "true", "1"} diff --git a/src/eegprep/functions/popfunc/pop_saveset.py b/src/eegprep/functions/popfunc/pop_saveset.py index e09df3a0..641c2715 100644 --- a/src/eegprep/functions/popfunc/pop_saveset.py +++ b/src/eegprep/functions/popfunc/pop_saveset.py @@ -1,5 +1,6 @@ """EEG data saving and loading utilities.""" +import copy import os from pathlib import Path @@ -88,6 +89,21 @@ def _matlab_empty_if_missing(EEG, key): return default_empty if value is None else value +def _matlab_empty_or_copy(EEG, key): + """Like ``_matlab_empty_if_missing`` but deep-copies struct-array fields. + + Saving applies in-place 1-based offsets and latency coercion to the + MATLAB-facing structures; copying first keeps the caller's chanlocs/event + dicts (0-based urchan/urevent, untouched latencies) intact. ``chanlocs`` + and ``event`` can be Python lists or NumPy object arrays of dicts, so both + are deep-copied. + """ + value = _matlab_empty_if_missing(EEG, key) + if isinstance(value, list) or (isinstance(value, np.ndarray) and value.dtype == object): + return copy.deepcopy(value) + return value + + def _matlab_empty_struct_if_missing(EEG, key): """Return an EEGLAB empty array for optional empty struct-like fields.""" value = _matlab_empty_if_missing(EEG, key) @@ -377,11 +393,11 @@ def pop_saveset(EEG, file_name=None, *args, **kwargs): 'icasphere': _matlab_empty_if_missing(EEG, 'icasphere'), 'icaweights': _matlab_empty_if_missing(EEG, 'icaweights'), 'icachansind': _matlab_empty_if_missing(EEG, 'icachansind').copy(), - 'chanlocs': _matlab_empty_if_missing(EEG, 'chanlocs'), + 'chanlocs': _matlab_empty_or_copy(EEG, 'chanlocs'), 'urchanlocs': _matlab_empty_if_missing(EEG, 'urchanlocs'), 'chaninfo': _serialize_chaninfo(EEG.get('chaninfo', {})), 'ref': EEG.get('ref', 'common'), - 'event': _matlab_empty_if_missing(EEG, 'event'), + 'event': _matlab_empty_or_copy(EEG, 'event'), 'urevent': _matlab_empty_if_missing(EEG, 'urevent'), 'eventdescription': _matlab_empty_if_missing(EEG, 'eventdescription'), 'epoch': _matlab_empty_if_missing(EEG, 'epoch'), @@ -416,53 +432,10 @@ def pop_saveset(EEG, file_name=None, *args, **kwargs): for i in range(len(eeglab_dict['event'])): eeglab_dict['event'][i]['urevent'] = eeglab_dict['event'][i]['urevent'] + 1 - # Create the list of dictionaries with a string field + # Serialize chanlocs through the single canonical chanloc converter so the + # primary channel struct uses the same schema as chaninfo.removedchans. if 'chanlocs' in EEG and len(EEG['chanlocs']) > 0: - matlab_null = np.array([]) - d_list = [ - { - 'labels': c['labels'], - 'theta': c['theta'] if not isinstance(c.get('theta', matlab_null), np.ndarray) else None, - 'radius': c['radius'] if not isinstance(c.get('radius', matlab_null), np.ndarray) else None, - 'X': c['X'] if not isinstance(c.get('X', matlab_null), np.ndarray) else None, - 'Y': c['Y'] if not isinstance(c.get('Y', matlab_null), np.ndarray) else None, - 'Z': c['Z'] if not isinstance(c.get('Z', matlab_null), np.ndarray) else None, - 'sph_theta': c['sph_theta'] if not isinstance(c.get('sph_theta', matlab_null), np.ndarray) else None, - 'sph_phi': c['sph_phi'] if not isinstance(c.get('sph_phi', matlab_null), np.ndarray) else None, - 'sph_radius': c['sph_radius'] if not isinstance(c.get('sph_radius', matlab_null), np.ndarray) else None, - 'type': c['type'] if not isinstance(c.get('type', matlab_null), np.ndarray) else None, - 'urchan': c['urchan'] if not isinstance(c.get('urchan', matlab_null), np.ndarray) else None, - 'ref': c['ref'] if not isinstance(c.get('ref', matlab_null), np.ndarray) else None, - } - for c in EEG['chanlocs'] - ] - - # build a list of fields to selectively filter out if all entries are None - retain_fields = [fld for fld in d_list[0].keys() if not all(d[fld] is None for d in d_list)] - - dtype = np.dtype( - [ - (f, t) - for f, t in [ - ('labels', 'U100'), # String up to 100 characters - ('theta', np.float64), - ('radius', np.float64), - ('X', np.float64), - ('Y', np.float64), - ('Z', np.float64), - ('sph_theta', np.float64), - ('sph_phi', np.float64), - ('sph_radius', np.float64), - ('type', 'U10'), # String up to 10 characters - ('urchan', np.int32), - ('ref', 'U100'), # String up to 100 characters - ] - if f in retain_fields - ] - ) - - # Convert the list of dictionaries to a structured NumPy array - eeglab_dict['chanlocs'] = np.array([tuple(item[fld] for fld in retain_fields) for item in d_list], dtype=dtype) + eeglab_dict['chanlocs'] = _chanlocs_to_struct_array(eeglab_dict['chanlocs']) # Normalize event latencies to float before saving so MATLAB loads them # as double. Without this, integer stimulus latencies become int64 and @@ -551,23 +524,6 @@ def _save_two_files(EEG, savemode): return bool(datfile and savemode == 'resave') or savetwofiles_enabled() -def test_pop_saveset(): - """Test pop_saveset function.""" - from eegprep.functions.popfunc.pop_loadset import pop_loadset - - file_path = './sample_data/eeglab_data_with_ica_tmp.set' - EEG = pop_loadset(file_path) - pop_saveset(EEG, '/Users/arno/Python/eegprep/sample_data/tmp.set') - pop_saveset_old( - EEG, '/Users/arno/Python/eegprep/sample_data/tmp2.set' - ) # does not do events and function above is better - # print the keys of the EEG dictionary - print(EEG.keys()) - - -if __name__ == '__main__': - test_pop_saveset() - # STILL OPEN QUESTION: Better to have empty MATLAB arrays as None for empty numpy arrays (current default). # The current default is to make it more MALTAB compatible. A lot of MATLAB function start indexing MATLAB # empty arrays to add values to them. This is not possible with None and would create more conversion and diff --git a/src/eegprep/functions/popfunc/pop_select.py b/src/eegprep/functions/popfunc/pop_select.py index 38a8926c..0b419c66 100644 --- a/src/eegprep/functions/popfunc/pop_select.py +++ b/src/eegprep/functions/popfunc/pop_select.py @@ -1,4 +1,5 @@ import copy +import logging import re from typing import Any @@ -18,6 +19,9 @@ from eegprep.functions.popfunc.eeg_eegrej import eeg_eegrej +logger = logging.getLogger(__name__) + + def pop_select(EEG, *args, gui=None, renderer=None, return_com=False, **kwargs): """Select EEG data using EEGLAB ``pop_select`` semantics.""" options = parse_key_value_args(args, kwargs) @@ -55,6 +59,7 @@ def _pop_select_apply(EEG, **kwargs): ------- EEG_out, com """ + EEG = copy.deepcopy(EEG) # shallow options with MATLAB-compatible aliases g = { 'time': kwargs.get('time', []), # seconds; can be Nx2 for continuous @@ -184,7 +189,7 @@ def _numeric_channel_indices(values): inds, _ = eeg_decodechan(EEG, g['channel'], 'labels', True) # show warning if not all channels are found and error if no channels are found if len(inds) != len(g['channel']): - print(f"Warning: {len(g['channel']) - len(inds)} channels not found") + logger.warning("%s channels not found", len(g['channel']) - len(inds)) if len(inds) == 0: raise ValueError(f"Channels not found: {g['channel']}") chan_selected_flag[:] = False @@ -199,17 +204,17 @@ def _numeric_channel_indices(values): chan_selected_flag[np.array(inds, dtype=int)] = False # show warning if not all channels are found and error if no channels are found if len(inds) != len(g['nochannel']): - print(f"Warning: {len(g['nochannel']) - len(inds)} channels not found") + logger.warning("%s channels not found", len(g['nochannel']) - len(inds)) else: # by type if _decode_list(g['chantype']): - inds = eeg_decodechan(EEG, g['chantype'], 'type', True) + inds, _ = eeg_decodechan(EEG, g['chantype'], 'type', True) chan_selected_flag[:] = False chan_selected_flag[np.array(inds, dtype=int)] = True if _decode_list(g['rmchantype']): - inds = eeg_decodechan(EEG, g['rmchantype'], 'type', True) + inds, _ = eeg_decodechan(EEG, g['rmchantype'], 'type', True) chan_selected_flag[np.array(inds, dtype=int)] = False g['channel'] = np.where(chan_selected_flag)[0].tolist() @@ -226,7 +231,7 @@ def _normalize_range_matrix(x): if x.size <= 2: return np.array(x).reshape(1, 2) # vector form → [first last] - print('Warning: vector format for point/time range is deprecated') + logger.warning("Vector format for point/time range is deprecated") return np.array([x[0], x[-1]], dtype=float).reshape(1, 2) if x.shape[1] != 2: raise ValueError('Time/point range must contain exactly 2 columns') @@ -289,14 +294,14 @@ def _clip_time_matrix(mat): # 4) Informational prints (optional) if len(g['trial']) != trials: - print(f"Removing {trials - len(g['trial'])} trial(s)...") + logger.info("Removing %s trial(s)...", trials - len(g['trial'])) if len(g['channel']) != nbchan: - print(f"Removing {nbchan - len(g['channel'])} channel(s)...") + logger.info("Removing %s channel(s)...", nbchan - len(g['channel'])) # 5) Recompute event epoch indices and latencies when trials are dropped if len(g['trial']) != trials and (EEG.get('event') is not None and len(EEG.get('event', [])) > 0): if not any('epoch' in ev for ev in EEG['event']): - print('Pop_epoch warning: bad event format with epoch dataset, removing events') + logger.warning("Bad event format with epoch dataset, removing events") EEG['event'] = [] else: keepevent = [] @@ -317,7 +322,7 @@ def _clip_time_matrix(mat): ev['epoch'] = int(newindex[0] + 1) # back to 1-based for consistency diffevent = np.setdiff1d(np.arange(len(EEG['event'])), np.array(keepevent, dtype=int)) if diffevent.size: - print(f"Pop_select: removing {diffevent.size} unreferenced events") + logger.info("Removing %s unreferenced events", diffevent.size) EEG['event'] = [EEG['event'][i] for i in range(len(EEG['event'])) if i in keepevent] # 6) Apply time selection @@ -364,12 +369,11 @@ def _clip_time_matrix(mat): newevents = [] for ev in EEG['event']: if 'epoch' in ev and 'latency' in ev: - e = copy.deepcopy(ev) # within-epoch latency shift by (a-1) samples - e['latency'] = e['latency'] - (a - 1) + ev['latency'] = ev['latency'] - (a - 1) # keep only events that remain inside the cropped window - if 1 <= e['latency'] <= pnts * len(g['trial']): - newevents.append(e) + if 1 <= ev['latency'] <= pnts * len(g['trial']): + newevents.append(ev) EEG['event'] = newevents # erase epoch-level event fields @@ -434,7 +438,7 @@ def _clip_time_matrix(mat): # erase dipfit if channels removed if len(chan_idx) != nbchan and _has_content(EEG.get('dipfit')): - print('warning: erasing dipole information since channels have been removed') + logger.warning("Erasing dipole information since channels have been removed") EEG['dipfit'] = np.array([]) EEG['roi'] = {} @@ -769,11 +773,3 @@ def _history_command(options): for key, value in options.items(): parts.extend([f"'{key}'", format_history_value(value, empty_sequence="{}")]) return f"EEG = pop_select( EEG, {', '.join(parts)});" - - -if __name__ == '__main__': - from eegprep.functions.popfunc.pop_loadset import pop_loadset - - EEG = pop_loadset('sample_data/eeglab_data.set') - EEG2 = pop_select(EEG, channel=['FP1', 'FP2']) - print(EEG2) diff --git a/src/eegprep/functions/popfunc/pop_spectopo.py b/src/eegprep/functions/popfunc/pop_spectopo.py index 883eee40..3a101a29 100644 --- a/src/eegprep/functions/popfunc/pop_spectopo.py +++ b/src/eegprep/functions/popfunc/pop_spectopo.py @@ -56,6 +56,7 @@ def pop_spectopo( map_labels = None title = "Channel spectra and maps" else: + _raise_for_unsupported_component_options(options) plot_data, component_numbers = _component_spectral_data(EEG, timerange, options.get("icacomps")) maps, chanlocs = component_map_data(EEG) map_numbers = _component_map_numbers( @@ -235,6 +236,27 @@ def _component_map_numbers( return numbers.astype(int) +def _raise_for_unsupported_component_options(options: dict[str, Any]) -> None: + """Fail loudly when component controls EEGPrep does not implement are set to non-default values. + + EEGLAB's component dialog offers ``plotchan`` (``0`` = whole scalp) and ``icamode`` + (checked = component spectra). EEGPrep only computes whole-scalp component spectra, so a + non-default ``plotchan`` (a specific electrode or ``[]`` = max-power electrode) or an unchecked + ``icamode`` ((data-comp) spectra) is rejected instead of being silently ignored. + """ + if "plotchan" in options: + plotchan = numeric_vector(options["plotchan"]) + if plotchan.size != 1 or int(plotchan[0]) != 0: + raise ValueError( + "pop_spectopo only supports whole-scalp component spectra (plotchan=0); " + "per-electrode or max-power projection is not available in EEGPrep" + ) + if "icamode" in options and not bool(options["icamode"]): + raise ValueError( + "pop_spectopo only supports component spectra (icamode on); (data-comp) spectra is not available in EEGPrep" + ) + + def _split_spectopo_options(options: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: spectral_keys = {"plot", "winsize", "overlap", "nfft"} spectral = {key: options[key] for key in spectral_keys if key in options} diff --git a/src/eegprep/functions/popfunc/pop_topochansel.py b/src/eegprep/functions/popfunc/pop_topochansel.py index 0f919546..d2c7e002 100644 --- a/src/eegprep/functions/popfunc/pop_topochansel.py +++ b/src/eegprep/functions/popfunc/pop_topochansel.py @@ -5,7 +5,7 @@ from typing import Any from eegprep.functions.popfunc._pop_utils import format_history_value -from eegprep.functions.popfunc.pop_chansel import pop_chansel +from eegprep.functions.popfunc.pop_chansel import pop_chansel, pop_chansel_resolve def pop_topochansel( @@ -22,17 +22,17 @@ def pop_topochansel( del args, kwargs, labels if gui is None: gui = select is None + channel_values, resolved = pop_chansel_resolve(chanlocs, select) if gui: chanlist, strchannames, cellchannames = pop_chansel(chanlocs, select=select, withindex="on") else: - channel_values = _channel_labels(chanlocs) - chanlist = _resolve_selection(select, channel_values) + chanlist = resolved cellchannames = [channel_values[index - 1] for index in chanlist] strchannames = " ".join(cellchannames) first_output: Any = cellchannames if str(cellstrout).lower() == "on" else chanlist command = ( "pop_topochansel(" - f"{format_history_value(_channel_labels(chanlocs), cell_for_sequence='all_strings')}, " + f"{format_history_value(channel_values, cell_for_sequence='all_strings')}, " f"{format_history_value(select, cell_for_sequence=None)});" ) if return_com: @@ -40,35 +40,4 @@ def pop_topochansel( return first_output, cellchannames, strchannames -def _channel_labels(chanlocs: Any) -> list[str]: - if isinstance(chanlocs, dict) and "chanlocs" in chanlocs: - chanlocs = chanlocs["chanlocs"] - return [str(chan.get("labels", chan)) if isinstance(chan, dict) else str(chan) for chan in (chanlocs or [])] - - -def _resolve_selection(select: Any, labels: list[str]) -> list[int]: - if select is None or select == "": - return [] - if isinstance(select, str): - tokens = select.split() - elif isinstance(select, (int, float)): - tokens = [select] - else: - tokens = list(select) - lowered = [label.lower() for label in labels] - selected = [] - for token in tokens: - if isinstance(token, (int, float)) or str(token).isdigit(): - index = int(token) - else: - try: - index = lowered.index(str(token).lower()) + 1 - except ValueError as exc: - raise ValueError(f"Unknown channel label {token!r}") from exc - if index < 1 or index > len(labels): - raise ValueError("Selected channel index out of range") - selected.append(index) - return selected - - __all__ = ["pop_topochansel"] diff --git a/src/eegprep/functions/popfunc/pop_topoplot.py b/src/eegprep/functions/popfunc/pop_topoplot.py index 9cfbf18a..c777d0ba 100644 --- a/src/eegprep/functions/popfunc/pop_topoplot.py +++ b/src/eegprep/functions/popfunc/pop_topoplot.py @@ -13,7 +13,7 @@ from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec from eegprep.functions.miscfunc.misc import round_mat from eegprep.functions.popfunc._chanutils import chanlocs_as_list -from eegprep.functions.popfunc._plot_utils import component_map_data +from eegprep.functions.popfunc._plot_utils import component_map_data, python_literal from eegprep.functions.popfunc._pop_utils import is_on as _is_on from eegprep.functions.popfunc._pop_utils import parse_key_value_args, parse_numeric_sequence, parse_text_tokens from eegprep.functions.sigprocfunc.topoplot import topoplot @@ -433,35 +433,14 @@ def _history_command( pieces = [ "EEG", f"typeplot={int(typeplot)}", - f"items={_python_literal(items)}", - f"topotitle={_python_literal(topotitle)}", - f"rowcols={_python_literal(list(rowcols))}", + f"items={python_literal(items)}", + f"topotitle={python_literal(topotitle)}", + f"rowcols={python_literal(list(rowcols))}", f"plotdip={int(plotdip)}", ] for key, value in options.items(): - pieces.append(f"{key}={_python_literal(value)}") + pieces.append(f"{key}={python_literal(value)}") return f"pop_topoplot({', '.join(pieces)})" -def _python_literal(value: Any) -> str: - if isinstance(value, np.ndarray): - value = value.tolist() - if isinstance(value, (np.integer, np.floating)): - value = value.item() - if isinstance(value, float): - if np.isnan(value): - return "float('nan')" - if np.isposinf(value): - return "float('inf')" - if np.isneginf(value): - return "float('-inf')" - if value.is_integer(): - return str(int(value)) - if isinstance(value, list): - return "[" + ", ".join(_python_literal(item) for item in value) + "]" - if isinstance(value, tuple): - return "(" + ", ".join(_python_literal(item) for item in value) + ("," if len(value) == 1 else "") + ")" - return repr(value) - - __all__ = ["plot_channel_locations", "pop_topoplot", "pop_topoplot_dialog_spec"] diff --git a/src/eegprep/functions/sigprocfunc/eegrej.py b/src/eegprep/functions/sigprocfunc/eegrej.py index dc70047e..f329e3d8 100644 --- a/src/eegprep/functions/sigprocfunc/eegrej.py +++ b/src/eegprep/functions/sigprocfunc/eegrej.py @@ -10,10 +10,7 @@ def _is_boundary_event(event: Dict) -> bool: if isinstance(t, str): return t.lower() == "boundary" if isinstance(t, (int, float)): - try: - return int(t) == -99 - except Exception: - return False + return int(t) == -99 return False @@ -135,10 +132,12 @@ def eegrej( extra += float(ev.get("duration", 0.0) or 0.0) durations[i_region] += extra - # Compute boundevents considering prior removals + # Compute boundevents considering prior removals. EEGLAB shifts each later + # boundary by the prior regions' base spans (eegrej.m L139), not by the + # augmented durations (those only feed the inserted boundary's .duration). boundevents = r[:, 0].astype(float) - 1.0 - if len(durations) > 1: - cums = np.concatenate([[0.0], np.cumsum(durations[:-1])]) + if len(base_durations) > 1: + cums = np.concatenate([[0.0], np.cumsum(base_durations[:-1].astype(float))]) boundevents = boundevents - cums boundevents = boundevents + 0.5 boundevents = boundevents[boundevents >= 0] diff --git a/src/eegprep/functions/sigprocfunc/epoch.py b/src/eegprep/functions/sigprocfunc/epoch.py index 3afb2eff..421d593d 100644 --- a/src/eegprep/functions/sigprocfunc/epoch.py +++ b/src/eegprep/functions/sigprocfunc/epoch.py @@ -4,10 +4,14 @@ locked to specified events. """ +import logging + import numpy as np from ..miscfunc.misc import round_mat +logger = logging.getLogger(__name__) + def epoch(data, events, lim, **kwargs): """ @@ -51,7 +55,8 @@ def _as_1d(a): reallim[1] = int(round_mat(lim[1] * g['srate'] - 1)) # minus 1 sample # --- epoching --- - print('Epoching...') + if g['verbose'] == 'on': + logger.info('Epoching...') newdatalength = int(reallim[1] - reallim[0] + 1) @@ -84,12 +89,10 @@ def _as_1d(a): posinit = pos0 + reallim[0] # 0-based + offset posend = pos0 + reallim[1] # 0-based + offset - # Boundary check: MATLAB uses 1-based logic for boundary checks - # Convert to 1-based for the boundary check only - posinit_1based = posinit + 1 - posend_1based = posend + 1 - within_one_epoch = np.floor((posinit_1based - 1) / dataframes) == np.floor((posend_1based - 1) / dataframes) - within_bounds = (posinit_1based >= 1) and (posend_1based <= datawidth) + # Boundary check in MATLAB coordinates: data(:,posinit:posend) requires + # posinit >= 1 and posend <= datawidth, matching the posinit-1 slice start below. + within_one_epoch = np.floor((posinit - 1) / dataframes) == np.floor((posend - 1) / dataframes) + within_bounds = (posinit >= 1) and (posend <= datawidth) if within_one_epoch and within_bounds: # Extract contiguous slice. MATLAB does data(:,posinit:posend) with posinit/posend in MATLAB coordinates @@ -116,12 +119,12 @@ def _as_1d(a): indexes[index] = 1 else: if g['verbose'] == 'on': - print(f'Warning: event {index + 1} out of value limits') + logger.warning('event %s out of value limits', index + 1) else: indexes[index] = 1 else: if g['verbose'] == 'on': - print(f'Warning: event {index + 1} out of data boundary') + logger.warning('event %s out of data boundary', index + 1) # Re-reference events if g['allevents'] is not None and g['allevents'].size > 0: diff --git a/src/eegprep/functions/sigprocfunc/runamica.py b/src/eegprep/functions/sigprocfunc/runamica.py index 550615b9..591573d2 100644 --- a/src/eegprep/functions/sigprocfunc/runamica.py +++ b/src/eegprep/functions/sigprocfunc/runamica.py @@ -827,23 +827,30 @@ def runamica( # Write parameter file param_file = _write_param_file(outdir, params) - # Run the binary - _run_amica(binary, param_file) - - # Load output - mods = _load_amica_output( - outdir, - num_models=num_models, - num_pcs=pcakeep, - data_dim=chans, - num_mix_comps=num_mix_comps, - max_iter=max_iter, - field_dim=frames, - ) + try: + # Run the binary + _run_amica(binary, param_file) + + # Load output + mods = _load_amica_output( + outdir, + num_models=num_models, + num_pcs=pcakeep, + data_dim=chans, + num_mix_comps=num_mix_comps, + max_iter=max_iter, + field_dim=frames, + ) - # Extract model 0 weights and sphere - weights = mods['W'][:, :, 0] - sphere = mods['S'][: mods['num_pcs'], :] + # Extract model 0 weights and sphere + weights = mods['W'][:, :, 0] + sphere = mods['S'][: mods['num_pcs'], :] + except BaseException: + # A failed run/load must not leak the temp dir we created (it holds the + # full float32 .fdt data copy). Remove it, then re-raise the real error. + if tmp_created: + shutil.rmtree(outdir, ignore_errors=True) + raise # Cleanup if cleanup: diff --git a/src/eegprep/functions/sigprocfunc/runica.py b/src/eegprep/functions/sigprocfunc/runica.py index 62def0eb..5c23534b 100644 --- a/src/eegprep/functions/sigprocfunc/runica.py +++ b/src/eegprep/functions/sigprocfunc/runica.py @@ -20,6 +20,8 @@ """ import logging +import time + import numpy as np from scipy.linalg import sqrtm, pinv, eig from ...plugins.clean_rawdata.private.ransac import rand_permutation @@ -222,8 +224,10 @@ def runica(data, **kwargs): # 1. DATA VALIDATION AND INITIALIZATION # ========================================================================= - # Ensure data is float64 for numerical consistency with MATLAB - data = np.asarray(data, dtype=np.float64) + # Ensure data is float64 for numerical consistency with MATLAB. + # Copy at entry so the in-place channel-mean subtraction below never + # mutates the caller's array, regardless of input dtype. + data = np.array(data, dtype=np.float64, copy=True) chans, frames = data.shape urchans = chans # remember original data channels @@ -621,8 +625,6 @@ def runica(data, **kwargs): elif reset_randomseed: # Set seed based on time (random state) # Use None to get time-based seed, similar to MATLAB's sum(100*clock) - import time - seed = int(time.time() * 1000) % (2**32) rng = np.random.RandomState(seed) else: @@ -645,22 +647,6 @@ def runica(data, **kwargs): # Phase 2: Core ICA Training Loop # ========================================================================= - # Helper function for random permutation with MATLAB parity - def custom_randperm(n, rng_state): - """ - Random permutation with MATLAB parity. - - This function produces the SAME permutation sequence as MATLAB's randperm() - by using rand() + round_mat() instead of permutation(). This achieves exact - parity because: - - rand() (uniform [0,1]) matches between Python and MATLAB - - round_mat() matches MATLAB's round() tie-breaking behavior - - permutation() uses different algorithms and does NOT match - - For details see test_parity_rng.py and clean_rawdata/private/ransac.py:rand_permutation() - """ - return rand_permutation(n, rng_state) - # Initialize step counters and tracking variables step = 0 # Training step counter (MATLAB line 795) laststep = 0 # Will be set when stopping criterion met @@ -690,7 +676,7 @@ def custom_randperm(n, rng_state): if biasflag and extended: while step < maxsteps: # MATLAB line 828 # Shuffle data order at each step (MATLAB line 829) - timeperm = custom_randperm(datalength, rng) + timeperm = rand_permutation(datalength, rng) # Process data in blocks (MATLAB line 831) for t in range(0, lastt, block): @@ -877,7 +863,7 @@ def custom_randperm(n, rng_state): elif biasflag and not extended: while step < maxsteps: # MATLAB line 1004 # Shuffle data order at each step (MATLAB line 1005) - timeperm = custom_randperm(datalength, rng) + timeperm = rand_permutation(datalength, rng) # Process data in blocks (MATLAB line 1007) for t in range(0, lastt, block): @@ -1020,7 +1006,7 @@ def custom_randperm(n, rng_state): elif not biasflag and extended: while step < maxsteps: # MATLAB line 1128 # Shuffle data order at each step (MATLAB line 1129) - timeperm = custom_randperm(datalength, rng) + timeperm = rand_permutation(datalength, rng) # Process data in blocks (MATLAB line 1131) for t in range(0, lastt, block): @@ -1182,7 +1168,7 @@ def custom_randperm(n, rng_state): else: # not biasflag and not extended while step < maxsteps: # MATLAB line 1299 # Shuffle data order at each step (MATLAB line 1300) - timeperm = custom_randperm(datalength, rng) + timeperm = rand_permutation(datalength, rng) # Process data in blocks (MATLAB line 1302) for t in range(0, lastt, block): diff --git a/src/eegprep/functions/statistics/_core.py b/src/eegprep/functions/statistics/_core.py index a1845baa..c99aee11 100644 --- a/src/eegprep/functions/statistics/_core.py +++ b/src/eegprep/functions/statistics/_core.py @@ -115,10 +115,14 @@ def fdr(pvals: Any, q: float | None = None, fdr_type: str = "parametric") -> FDR Threshold and boolean mask with the same shape as ``pvals``. """ - values = _as_numeric_array(pvals, "pvals", require_axis=False) + values = np.asarray(pvals) + if not np.issubdtype(values.dtype, np.number): + raise TypeError("pvals must be numeric") if values.size == 0: return FDRResult(np.array([], dtype=float), np.array([], dtype=bool)) - if np.any((values < 0) | (values > 1)): + finite_mask = np.isfinite(values) + finite_values = values[finite_mask] + if np.any((finite_values < 0) | (finite_values > 1)): raise ValueError("pvals must contain probabilities between 0 and 1") if q is None: @@ -127,7 +131,7 @@ def fdr(pvals: Any, q: float | None = None, fdr_type: str = "parametric") -> FDR for current in thresholds: current_result = fdr(values, float(current), fdr_type=fdr_type) threshold[current_result.mask] = current - return FDRResult(threshold, values <= threshold) + return FDRResult(threshold, finite_mask & (values <= threshold)) q_value = float(q) if not 0 <= q_value <= 1: @@ -137,13 +141,16 @@ def fdr(pvals: Any, q: float | None = None, fdr_type: str = "parametric") -> FDR if fdr_type_name not in {"parametric", "nonparametric"}: raise ValueError("fdr_type must be 'parametric' or 'nonparametric'") - flat = np.sort(values.reshape(-1)) + if finite_values.size == 0: + return FDRResult(0.0, np.zeros(values.shape, dtype=bool)) + + flat = np.sort(finite_values.reshape(-1)) count = flat.size indices = np.arange(1, count + 1, dtype=float) correction = 1.0 if fdr_type_name == "parametric" else float(np.sum(1.0 / indices)) accepted = flat <= indices / count * q_value / correction threshold_value = float(flat[np.flatnonzero(accepted).max()]) if np.any(accepted) else 0.0 - return FDRResult(threshold_value, values <= threshold_value) + return FDRResult(threshold_value, finite_mask & (values <= threshold_value)) def stat_surrogate_pvals(distribution: Any, observed: Any, tail: str = "both") -> np.ndarray: diff --git a/src/eegprep/functions/studyfunc/_cluster_kmeans.py b/src/eegprep/functions/studyfunc/_cluster_kmeans.py new file mode 100644 index 00000000..5a10713b --- /dev/null +++ b/src/eegprep/functions/studyfunc/_cluster_kmeans.py @@ -0,0 +1,62 @@ +"""Deterministic k-means numeric kernel shared by STUDY clustering helpers. + +These functions hold the clustering numerics so that ``pop_clust``, +``optimal_kmeans``, and ``robust_kmeans`` all import downward from this module +instead of one user-facing wrapper. Labels are returned 1-based to match the +EEGLAB-facing cluster numbering convention used by the callers. +""" + +from __future__ import annotations + +import numpy as np + + +KMEANS_MAX_ITER = 300 +KMEANS_N_INIT = 10 +KMEANS_TOLERANCE = 1e-8 + + +def kmeans_labels(data: np.ndarray, clus_num: int, random_state: int) -> tuple[np.ndarray, np.ndarray]: + """Run deterministic multi-restart k-means and return 1-based labels and centers.""" + rng = np.random.default_rng(random_state) + best_labels: np.ndarray | None = None + best_centers: np.ndarray | None = None + best_inertia = float("inf") + for _attempt in range(KMEANS_N_INIT): + centers = data[rng.choice(data.shape[0], size=clus_num, replace=False)].copy() + labels = np.zeros(data.shape[0], dtype=int) + for _iteration in range(KMEANS_MAX_ITER): + labels = np.argmin(squared_distances(data, centers), axis=1) + new_centers = _recompute_centers(data, labels, centers, clus_num) + if np.allclose(new_centers, centers, rtol=0, atol=KMEANS_TOLERANCE): + centers = new_centers + break + centers = new_centers + distances = squared_distances(data, centers) + inertia = float(np.sum(distances[np.arange(data.shape[0]), labels])) + if inertia < best_inertia: + best_inertia = inertia + best_labels = labels.copy() + best_centers = centers.copy() + if best_labels is None or best_centers is None: + raise ValueError("K-means failed to initialize clusters") + return best_labels.astype(int) + 1, best_centers + + +def squared_distances(data: np.ndarray, centers: np.ndarray) -> np.ndarray: + """Return the matrix of squared Euclidean distances from rows to centers.""" + diff = data[:, np.newaxis, :] - centers[np.newaxis, :, :] + return np.sum(diff * diff, axis=2) + + +def _recompute_centers(data: np.ndarray, labels: np.ndarray, centers: np.ndarray, clus_num: int) -> np.ndarray: + new_centers = np.empty_like(centers) + nearest_distance = np.min(squared_distances(data, centers), axis=1) + fallback_index = int(np.argmax(nearest_distance)) + for cluster in range(clus_num): + rows = data[labels == cluster] + new_centers[cluster] = np.mean(rows, axis=0) if rows.size else data[fallback_index] + return new_centers + + +__all__ = ["KMEANS_MAX_ITER", "KMEANS_N_INIT", "KMEANS_TOLERANCE", "kmeans_labels", "squared_distances"] diff --git a/src/eegprep/functions/studyfunc/_cluster_utils.py b/src/eegprep/functions/studyfunc/_cluster_utils.py index e8dbd36e..2a4147f2 100644 --- a/src/eegprep/functions/studyfunc/_cluster_utils.py +++ b/src/eegprep/functions/studyfunc/_cluster_utils.py @@ -9,7 +9,12 @@ import numpy as np from eegprep.functions.popfunc._plot_utils import component_maps, python_literal -from eegprep.functions.studyfunc._study_utils import as_alleeg_list, ensure_study, sync_datasetinfo +from eegprep.functions.studyfunc._study_utils import ( + as_alleeg_list, + ensure_study, + sync_datasetinfo, + unique_preserving_order, +) def checked_study_and_datasets( @@ -85,7 +90,7 @@ def _axis_values(primary: Any, fallback: Any) -> list[int]: values = _numeric_values(primary) if values: return values - return _unique_preserving_order(_numeric_values(fallback)) + return unique_preserving_order(_numeric_values(fallback)) def _numeric_values(value: Any) -> list[int]: @@ -97,14 +102,6 @@ def _numeric_values(value: Any) -> list[int]: return [int(item) for item in array.ravel().tolist()] -def _unique_preserving_order(values: list[int]) -> list[int]: - output = [] - for value in values: - if value not in output: - output.append(value) - return output - - def cluster_list(study: dict[str, Any]) -> list[dict[str, Any]]: """Return normalized STUDY cluster dictionaries.""" value = study.get("cluster") diff --git a/src/eegprep/functions/studyfunc/_std_measureplot.py b/src/eegprep/functions/studyfunc/_std_measureplot.py index 2bac76b1..7c48e13c 100644 --- a/src/eegprep/functions/studyfunc/_std_measureplot.py +++ b/src/eegprep/functions/studyfunc/_std_measureplot.py @@ -9,12 +9,25 @@ from eegprep.functions.popfunc._pop_utils import is_on, parse_key_value_args from eegprep.functions.studyfunc._study_utils import build_python_call -from eegprep.functions.studyfunc.std_readdata import std_readdata +from eegprep.functions.studyfunc.std_readdata import MEASURE_DATA_FIELDS, std_readdata LINE_MEASURES = {"erp", "spec"} +def default_measure_target(study: dict[str, Any], field: str, channels: Any, clusters: Any, components: Any): + """Choose channels-vs-parent-cluster default target by cached measure field. + + Returns ``("channels", clusters)`` when any STUDY channel group caches + ``field``; otherwise defaults to the parent component cluster (``1``). + """ + if channels is not None or clusters is not None or components is not None: + return channels, clusters + if any(isinstance(group, dict) and field in group for group in study.get("changrp") or []): + return "channels", clusters + return channels, 1 + + def std_measureplot( STUDY: dict[str, Any], ALLEEG: list[dict[str, Any]] | None, @@ -73,12 +86,7 @@ def std_measureplot( def _default_target( study: dict[str, Any], datatype: str, channels: Any, clusters: Any, components: Any ) -> tuple[Any, Any]: - if channels is not None or clusters is not None or components is not None: - return channels, clusters - field = {"erp": "erpdata", "spec": "specdata", "ersp": "erspdata", "itc": "itcdata"}[datatype] - if any(isinstance(group, dict) and field in group for group in study.get("changrp") or []): - return "channels", clusters - return channels, 1 + return default_measure_target(study, MEASURE_DATA_FIELDS[datatype], channels, clusters, components) def _apply_ranges( diff --git a/src/eegprep/functions/studyfunc/_study_utils.py b/src/eegprep/functions/studyfunc/_study_utils.py index 08943007..2fd4653b 100644 --- a/src/eegprep/functions/studyfunc/_study_utils.py +++ b/src/eegprep/functions/studyfunc/_study_utils.py @@ -109,7 +109,6 @@ def ensure_study(STUDY: dict[str, Any] | None = None) -> dict[str, Any]: study.setdefault("design", []) study.setdefault("currentdesign", 0) study.setdefault("cache", []) - study.setdefault("preclust", _default_preclust()) study.setdefault("history", "STUDY = []") study.setdefault("saved", "no") etc = study.get("etc") @@ -351,6 +350,28 @@ def build_python_call(targets: tuple[str, ...], function_name: str, *args: str, return f"{assignment} = {function_name}({', '.join(pieces)})" +def unique_preserving_order(values: list[int]) -> list[int]: + """Return integers de-duplicated while keeping first-seen order.""" + output: list[int] = [] + for value in values: + if value not in output: + output.append(value) + return output + + +def trialinfo_rows(value: Any) -> list[dict[str, Any]]: + """Normalize STUDY trialinfo into a list of row dictionaries.""" + if value is None: + return [] + if isinstance(value, np.ndarray): + value = value.tolist() + if isinstance(value, dict): + return [value] + if not isinstance(value, list): + return [] + return [row for row in value if isinstance(row, dict)] + + def clear_study_data_fields(study: dict[str, Any]) -> dict[str, Any]: """Remove precomputed measure arrays from STUDY channel/component groups.""" study = deepcopy(study) @@ -379,15 +400,6 @@ def _datasetinfo_list(value: Any) -> list[dict[str, Any]]: return [deepcopy(item) for item in value if isinstance(item, dict)] -def _default_preclust() -> dict[str, list[Any]]: - return { - "erpclusttimes": [], - "specclustfreqs": [], - "erspclustfreqs": [], - "erspclusttimes": [], - } - - def _default_components(eeg: dict[str, Any]) -> list[int]: weights = eeg.get("icaweights") if weights is None: diff --git a/src/eegprep/functions/studyfunc/optimal_kmeans.py b/src/eegprep/functions/studyfunc/optimal_kmeans.py index 8a3a6423..265d86c1 100644 --- a/src/eegprep/functions/studyfunc/optimal_kmeans.py +++ b/src/eegprep/functions/studyfunc/optimal_kmeans.py @@ -6,7 +6,7 @@ import numpy as np -from eegprep.functions.studyfunc.pop_clust import _kmeans_labels, _squared_distances +from eegprep.functions.studyfunc._cluster_kmeans import kmeans_labels, squared_distances def optimal_kmeans( @@ -19,9 +19,9 @@ def optimal_kmeans( values = _cluster_range(clusnum, data.shape[0]) best = None for cluster_count in values: - labels, centers = _kmeans_labels(data, cluster_count, random_state) + labels, centers = kmeans_labels(data, cluster_count, random_state) score = _silhouette_score(data, labels) - distances = np.sqrt(_squared_distances(data, centers)) + distances = np.sqrt(squared_distances(data, centers)) sumd = _sum_distances(labels, distances) candidate = (score, labels, centers, sumd, distances) if best is None or candidate[0] > best[0]: @@ -47,7 +47,7 @@ def _cluster_range(clusnum: int | list[int] | tuple[int, int], maximum: int) -> def _silhouette_score(data: np.ndarray, labels: np.ndarray) -> float: - distance = np.sqrt(_squared_distances(data, data)) + distance = np.sqrt(squared_distances(data, data)) scores = [] for row, label in enumerate(labels): same = labels == label diff --git a/src/eegprep/functions/studyfunc/pop_clust.py b/src/eegprep/functions/studyfunc/pop_clust.py index d4799857..f9305d91 100644 --- a/src/eegprep/functions/studyfunc/pop_clust.py +++ b/src/eegprep/functions/studyfunc/pop_clust.py @@ -9,14 +9,12 @@ from eegprep.functions.guifunc.inputgui import inputgui from eegprep.functions.guifunc.spec import ControlSpec, DialogSpec from eegprep.functions.popfunc._plot_utils import numeric_vector +from eegprep.functions.studyfunc._cluster_kmeans import kmeans_labels from eegprep.functions.studyfunc._cluster_utils import checked_study_and_datasets, cluster_command from eegprep.functions.studyfunc.std_createclust import std_createclust ALGORITHMS = ("kmeans", "kmeanscluster") -KMEANS_MAX_ITER = 300 -KMEANS_N_INIT = 10 -KMEANS_TOLERANCE = 1e-8 def pop_clust( @@ -66,7 +64,7 @@ def pop_clust( if clus_num > data.shape[0]: raise ValueError("Number of clusters cannot exceed the number of preclustered components") - labels, centers = _kmeans_labels(data, clus_num, random_state) + labels, centers = kmeans_labels(data, clus_num, random_state) if np.isfinite(outliers): if outliers <= 0: raise ValueError("Outlier threshold must be greater than 0") @@ -132,47 +130,6 @@ def pop_clust_dialog_spec(STUDY: dict[str, Any]) -> DialogSpec: ) -def _kmeans_labels(data: np.ndarray, clus_num: int, random_state: int) -> tuple[np.ndarray, np.ndarray]: - rng = np.random.default_rng(random_state) - best_labels: np.ndarray | None = None - best_centers: np.ndarray | None = None - best_inertia = float("inf") - for _attempt in range(KMEANS_N_INIT): - centers = data[rng.choice(data.shape[0], size=clus_num, replace=False)].copy() - labels = np.zeros(data.shape[0], dtype=int) - for _iteration in range(KMEANS_MAX_ITER): - labels = np.argmin(_squared_distances(data, centers), axis=1) - new_centers = _recompute_centers(data, labels, centers, clus_num) - if np.allclose(new_centers, centers, rtol=0, atol=KMEANS_TOLERANCE): - centers = new_centers - break - centers = new_centers - distances = _squared_distances(data, centers) - inertia = float(np.sum(distances[np.arange(data.shape[0]), labels])) - if inertia < best_inertia: - best_inertia = inertia - best_labels = labels.copy() - best_centers = centers.copy() - if best_labels is None or best_centers is None: - raise ValueError("K-means failed to initialize clusters") - return best_labels.astype(int) + 1, best_centers - - -def _squared_distances(data: np.ndarray, centers: np.ndarray) -> np.ndarray: - diff = data[:, np.newaxis, :] - centers[np.newaxis, :, :] - return np.sum(diff * diff, axis=2) - - -def _recompute_centers(data: np.ndarray, labels: np.ndarray, centers: np.ndarray, clus_num: int) -> np.ndarray: - new_centers = np.empty_like(centers) - nearest_distance = np.min(_squared_distances(data, centers), axis=1) - fallback_index = int(np.argmax(nearest_distance)) - for cluster in range(clus_num): - rows = data[labels == cluster] - new_centers[cluster] = np.mean(rows, axis=0) if rows.size else data[fallback_index] - return new_centers - - def _mark_outliers(data: np.ndarray, labels: np.ndarray, centers: np.ndarray, threshold: float) -> np.ndarray: output = labels.copy() cluster_distances = [] diff --git a/src/eegprep/functions/studyfunc/robust_kmeans.py b/src/eegprep/functions/studyfunc/robust_kmeans.py index 74080368..10fd198d 100644 --- a/src/eegprep/functions/studyfunc/robust_kmeans.py +++ b/src/eegprep/functions/studyfunc/robust_kmeans.py @@ -6,8 +6,8 @@ import numpy as np +from eegprep.functions.studyfunc._cluster_kmeans import kmeans_labels, squared_distances from eegprep.functions.studyfunc.optimal_kmeans import optimal_kmeans -from eegprep.functions.studyfunc.pop_clust import _kmeans_labels, _squared_distances def robust_kmeans( @@ -39,8 +39,8 @@ def robust_kmeans( values[active], [2, cluster_count], random_state=random_state ) else: - active_labels, centers = _kmeans_labels(values[active], cluster_count, random_state) - distances = np.sqrt(_squared_distances(values, centers)) + active_labels, centers = kmeans_labels(values[active], cluster_count, random_state) + distances = np.sqrt(squared_distances(values, centers)) labels[:] = 0 labels[active] = active_labels new_outliers = _outlier_rows(active, active_labels, distances[active], float(STD)) diff --git a/src/eegprep/functions/studyfunc/std_combtrialinfo.py b/src/eegprep/functions/studyfunc/std_combtrialinfo.py index bb682669..71df0fb3 100644 --- a/src/eegprep/functions/studyfunc/std_combtrialinfo.py +++ b/src/eegprep/functions/studyfunc/std_combtrialinfo.py @@ -7,6 +7,8 @@ import numpy as np +from eegprep.functions.studyfunc._study_utils import trialinfo_rows + DATASETINFO_TRIAL_EXCLUDE = {"filepath", "filename", "comps", "trialinfo"} @@ -19,7 +21,7 @@ def std_combtrialinfo(datasetinfo: Any, inds: Any, trials: Any = None) -> list[d rows: list[dict[str, Any]] = [] for index in selected: info = infos[index - 1] - base_rows = _trial_rows(info.get("trialinfo")) + base_rows = trialinfo_rows(info.get("trialinfo")) if not base_rows: base_rows = [{} for _trial in range(trial_counts[index - 1])] for base in base_rows: @@ -60,7 +62,7 @@ def _selected_indices(infos: list[dict[str, Any]], inds: Any) -> list[int]: def _trial_counts(infos: list[dict[str, Any]], trials: Any) -> list[int]: if trials is None: - return [max(1, len(_trial_rows(info.get("trialinfo")))) for info in infos] + return [max(1, len(trialinfo_rows(info.get("trialinfo")))) for info in infos] if isinstance(trials, np.ndarray): trials = trials.tolist() if not isinstance(trials, (list, tuple)): @@ -71,16 +73,4 @@ def _trial_counts(infos: list[dict[str, Any]], trials: Any) -> list[int]: return counts[: len(infos)] -def _trial_rows(value: Any) -> list[dict[str, Any]]: - if value is None: - return [] - if isinstance(value, np.ndarray): - value = value.tolist() - if isinstance(value, dict): - return [value] - if not isinstance(value, list): - return [] - return [row for row in value if isinstance(row, dict)] - - __all__ = ["std_combtrialinfo"] diff --git a/src/eegprep/functions/studyfunc/std_makedesign.py b/src/eegprep/functions/studyfunc/std_makedesign.py index 5a564440..53e67acd 100644 --- a/src/eegprep/functions/studyfunc/std_makedesign.py +++ b/src/eegprep/functions/studyfunc/std_makedesign.py @@ -2,6 +2,7 @@ from __future__ import annotations +from copy import deepcopy from typing import Any from eegprep.functions.popfunc._pop_utils import parse_key_value_args @@ -48,9 +49,10 @@ def std_makedesign( ) -> Any: """Create or replace a 1-based STUDY design. - This Phase 5a implementation stores design metadata and factor selections. - Measure-file deletion/precompute side effects from EEGLAB are intentionally - outside this phase. + Design metadata and factor selections are stored. ``delfiles`` controls + cached measure arrays: ``'on'`` or ``'limited'`` clear them on the design + change, while ``'off'`` preserves any precomputed ``changrp``/``cluster`` + measures attached to the redefined design. """ datasets = as_alleeg_list(ALLEEG) study = sync_datasetinfo(ensure_study(STUDY), datasets) @@ -105,7 +107,14 @@ def std_makedesign( designs[design_index - 1] = design study["design"] = designs study = std_addvarlevel(study, design_index) + preserved_changrp = deepcopy(study.get("changrp")) if delfiles == "off" else None + preserved_cluster = deepcopy(study.get("cluster")) if delfiles == "off" else None study = std_selectdesign(study, datasets, design_index) + if delfiles == "off": + if preserved_changrp is not None: + study["changrp"] = preserved_changrp + if preserved_cluster is not None: + study["cluster"] = preserved_cluster study["cache"] = [] study["saved"] = "no" study = store_consistency(study, datasets) diff --git a/src/eegprep/functions/studyfunc/std_pacplot.py b/src/eegprep/functions/studyfunc/std_pacplot.py index 67adc17d..d651f40f 100644 --- a/src/eegprep/functions/studyfunc/std_pacplot.py +++ b/src/eegprep/functions/studyfunc/std_pacplot.py @@ -8,7 +8,8 @@ import numpy as np from eegprep.functions.popfunc._pop_utils import is_on, parse_key_value_args -from eegprep.functions.studyfunc._study_utils import build_python_call, ensure_study +from eegprep.functions.studyfunc._std_measureplot import default_measure_target +from eegprep.functions.studyfunc._study_utils import build_python_call from eegprep.functions.studyfunc.std_readdata import std_readpac @@ -57,7 +58,7 @@ def std_pacplot( unsupported = sorted(key for key in options if key not in ignored) if unsupported: raise ValueError(f"Unknown std_pacplot option(s): {', '.join(unsupported)}") - channels, clusters = _default_target(STUDY, channels, clusters, components) + channels, clusters = default_measure_target(STUDY, "pacdata", channels, clusters, components) study, pacdata, pactimes, pacfreqs = std_readpac( STUDY, ALLEEG, @@ -87,15 +88,6 @@ def std_pacplot( return (*result, command) if return_com else result -def _default_target(study: dict[str, Any], channels: Any, clusters: Any, components: Any) -> tuple[Any, Any]: - if channels is not None or clusters is not None or components is not None: - return channels, clusters - prepared = ensure_study(study) - if any(isinstance(group, dict) and "pacdata" in group for group in prepared.get("changrp") or []): - return "channels", clusters - return channels, 1 - - def _plot_pac(data: list[np.ndarray], times: np.ndarray, freqs: np.ndarray) -> Any: images = [] for values in data: diff --git a/src/eegprep/functions/studyfunc/std_precomp.py b/src/eegprep/functions/studyfunc/std_precomp.py index 505f84e2..cec96359 100644 --- a/src/eegprep/functions/studyfunc/std_precomp.py +++ b/src/eegprep/functions/studyfunc/std_precomp.py @@ -80,12 +80,15 @@ def std_precomp( if not computed: raise ValueError("std_precomp requires at least one measure enabled") kind = _measure_kind(chanorcomp) + force = is_on(recompute) if kind == "channels": study["changrp"] = _precompute_channels( datasets, chanorcomp, computed, int(design), + cached=_cached_by_name(study.get("changrp")), + force=force, erpparams=_params_dict(erpparams), specparams=_params_dict(specparams), erspparams=_params_dict(erspparams), @@ -97,6 +100,8 @@ def std_precomp( chanorcomp, computed, int(design), + cached=_first_cluster(study.get("cluster")), + force=force, allcomps=is_on(allcomps), scalp=is_on(scalp), erpparams=_params_dict(erpparams), @@ -136,6 +141,8 @@ def _precompute_channels( computed: list[str], design: int, *, + cached: dict[str, dict[str, Any]], + force: bool, erpparams: dict[str, Any], specparams: dict[str, Any], erspparams: dict[str, Any], @@ -145,6 +152,7 @@ def _precompute_channels( groups = [] for channel_index in selected: label = labels[channel_index] + prior = cached.get(label, {}) entry: dict[str, Any] = { "name": label, "channels": [label], @@ -159,17 +167,33 @@ def _precompute_channels( }, } if "erp" in computed: - entry["erpdata"], entry["erptimes"] = _channel_erp(datasets, channel_index, erpparams) + if _keep_cached(prior, "erpdata", force): + _carry(entry, prior, ("erpdata", "erptimes")) + else: + entry["erpdata"], entry["erptimes"] = _channel_erp(datasets, channel_index, erpparams) if "spec" in computed: - entry["specdata"], entry["specfreqs"] = _channel_spec(datasets, channel_index, specparams) - if "ersp" in computed or "itc" in computed: + if _keep_cached(prior, "specdata", force): + _carry(entry, prior, ("specdata", "specfreqs")) + else: + entry["specdata"], entry["specfreqs"] = _channel_spec(datasets, channel_index, specparams) + ersp_cached = _keep_cached(prior, "erspdata", force) if "ersp" in computed else True + itc_cached = _keep_cached(prior, "itcdata", force) if "itc" in computed else True + if ("ersp" in computed and not ersp_cached) or ("itc" in computed and not itc_cached): tf = _channel_time_frequency(datasets, channel_index, erspparams) - if "ersp" in computed: + else: + tf = None + if "ersp" in computed: + if ersp_cached: + _carry(entry, prior, ("erspdata", "ersptimes", "erspfreqs", "erspbase")) + else: entry["erspdata"] = tf["erspdata"] entry["ersptimes"] = tf["times"] entry["erspfreqs"] = tf["freqs"] entry["erspbase"] = tf["powbase"] - if "itc" in computed: + if "itc" in computed: + if itc_cached: + _carry(entry, prior, ("itcdata", "itctimes", "itcfreqs")) + else: entry["itcdata"] = tf["itcdata"] entry["itctimes"] = tf["times"] entry["itcfreqs"] = tf["freqs"] @@ -184,6 +208,8 @@ def _precompute_components( computed: list[str], design: int, *, + cached: dict[str, Any], + force: bool, allcomps: bool, scalp: bool, erpparams: dict[str, Any], @@ -213,21 +239,37 @@ def _precompute_components( if scalp: cluster["topo"] = _component_topographies(datasets, selected, selection_mask) if "erp" in computed: - cluster["erpdata"], cluster["erptimes"] = _component_erp( - datasets, activations, selected, selection_mask, erpparams - ) + if _keep_cached(cached, "erpdata", force): + _carry(cluster, cached, ("erpdata", "erptimes")) + else: + cluster["erpdata"], cluster["erptimes"] = _component_erp( + datasets, activations, selected, selection_mask, erpparams + ) if "spec" in computed: - cluster["specdata"], cluster["specfreqs"] = _component_spec( - datasets, activations, selected, selection_mask, specparams - ) - if "ersp" in computed or "itc" in computed: + if _keep_cached(cached, "specdata", force): + _carry(cluster, cached, ("specdata", "specfreqs")) + else: + cluster["specdata"], cluster["specfreqs"] = _component_spec( + datasets, activations, selected, selection_mask, specparams + ) + ersp_cached = _keep_cached(cached, "erspdata", force) if "ersp" in computed else True + itc_cached = _keep_cached(cached, "itcdata", force) if "itc" in computed else True + if ("ersp" in computed and not ersp_cached) or ("itc" in computed and not itc_cached): tf = _component_time_frequency(datasets, activations, selected, selection_mask, erspparams) - if "ersp" in computed: + else: + tf = None + if "ersp" in computed: + if ersp_cached: + _carry(cluster, cached, ("erspdata", "ersptimes", "erspfreqs", "erspbase")) + else: cluster["erspdata"] = tf["erspdata"] cluster["ersptimes"] = tf["times"] cluster["erspfreqs"] = tf["freqs"] cluster["erspbase"] = tf["powbase"] - if "itc" in computed: + if "itc" in computed: + if itc_cached: + _carry(cluster, cached, ("itcdata", "itctimes", "itcfreqs")) + else: cluster["itcdata"] = tf["itcdata"] cluster["itctimes"] = tf["times"] cluster["itcfreqs"] = tf["freqs"] @@ -447,6 +489,28 @@ def _component_topographies( return np.asarray(topographies, dtype=float).tolist() +def _keep_cached(prior: dict[str, Any], data_field: str, force: bool) -> bool: + return not force and data_field in prior + + +def _carry(target: dict[str, Any], prior: dict[str, Any], fields: tuple[str, ...]) -> None: + for field in fields: + if field in prior: + target[field] = prior[field] + + +def _cached_by_name(changrp: Any) -> dict[str, dict[str, Any]]: + if not isinstance(changrp, list): + return {} + return {entry["name"]: entry for entry in changrp if isinstance(entry, dict) and entry.get("name")} + + +def _first_cluster(cluster: Any) -> dict[str, Any]: + if isinstance(cluster, list) and cluster and isinstance(cluster[0], dict): + return cluster[0] + return {} + + def _measure_kind(chanorcomp: Any) -> str: if isinstance(chanorcomp, str) and chanorcomp.lower() == "components": return "components" diff --git a/src/eegprep/functions/studyfunc/std_readdata.py b/src/eegprep/functions/studyfunc/std_readdata.py index 80f7eb02..726183d0 100644 --- a/src/eegprep/functions/studyfunc/std_readdata.py +++ b/src/eegprep/functions/studyfunc/std_readdata.py @@ -8,7 +8,10 @@ from eegprep.functions.popfunc._plot_utils import numeric_vector from eegprep.functions.studyfunc._cluster_utils import cluster_list, sets_array -from eegprep.functions.studyfunc._study_utils import ensure_study +from eegprep.functions.studyfunc._study_utils import ensure_study, unique_preserving_order + + +MEASURE_DATA_FIELDS = {"erp": "erpdata", "spec": "specdata", "ersp": "erspdata", "itc": "itcdata"} def std_readdata( @@ -239,7 +242,7 @@ def _cached_axis_values( if values.size == count: return values.astype(int) values = numeric_vector(group.get(fallback_key), dtype=int) - unique_values = np.asarray(_unique_preserving_order(values.tolist()), dtype=int) + unique_values = np.asarray(unique_preserving_order(values.tolist()), dtype=int) if unique_values.size == count: return unique_values return np.arange(fallback_start, fallback_start + count, dtype=int) @@ -268,14 +271,6 @@ def component_measure_selection(components: Any, axis: np.ndarray) -> np.ndarray return np.asarray(selected, dtype=int) -def _unique_preserving_order(values: list[int]) -> list[int]: - output = [] - for value in values: - if value not in output: - output.append(value) - return output - - def _cluster_component_data( parent: dict[str, Any], cluster: dict[str, Any], data: np.ndarray, components: Any ) -> np.ndarray: @@ -321,7 +316,7 @@ def _parent_cluster_requested(clusters: Any) -> bool: def _data_array(group: dict[str, Any], measure: str) -> np.ndarray: - field = {"erp": "erpdata", "spec": "specdata", "ersp": "erspdata", "itc": "itcdata"}[measure] + field = MEASURE_DATA_FIELDS[measure] if field not in group: raise ValueError(f"{measure.upper()} measures have not been precomputed") return np.asarray(group[field], dtype=float) diff --git a/src/eegprep/functions/timefreqfunc/_pac_support.py b/src/eegprep/functions/timefreqfunc/_pac_support.py index 871add5c..bb450f00 100644 --- a/src/eegprep/functions/timefreqfunc/_pac_support.py +++ b/src/eegprep/functions/timefreqfunc/_pac_support.py @@ -88,12 +88,16 @@ def compute_pac( **kwargs: Any, ) -> PacResult: """Compute EEGLAB-style phase-amplitude coupling from epoched data.""" - _ = title, vert, newfig unsupported = sorted(kwargs) if unsupported: raise TypeError(f"Unsupported pac option(s): {', '.join(unsupported)}") if alpha is not None: raise NotImplementedError(PAC_UNSUPPORTED_MESSAGE) + if str(title) != "" or vert is not None or str(newfig).strip().lower() not in {"on", "1", "true", "yes"}: + raise NotImplementedError( + "compute_pac returns PAC arrays without plotting; the 'title', 'vert', and " + "'newfig' plotting options are not implemented" + ) method_name = str(method or "mod").lower() if method_name == "modulation": method_name = "mod" @@ -443,6 +447,11 @@ def _surrogate_pvalue(surrogates: np.ndarray, observed: float, statlim: str) -> def _empirical_pvalue(distribution: np.ndarray, observed: float) -> float: + # PAC uses the bias-corrected (count + 1) / (N + 1) one-sided convention about + # |observed|, which never returns exactly 0. This intentionally differs from the + # ERSP/ITC surrogate p-values (bootstat.exact_p_values) and the statcond tail + # folding (statistics.stat_surrogate_pvals); they are distinct statistical + # conventions, not duplicates, so PAC keeps its own definition. values = np.asarray(distribution, dtype=float) values = values[np.isfinite(values)] if values.size == 0: diff --git a/src/eegprep/functions/timefreqfunc/newcrossf.py b/src/eegprep/functions/timefreqfunc/newcrossf.py index e8b3536e..1a1f1790 100644 --- a/src/eegprep/functions/timefreqfunc/newcrossf.py +++ b/src/eegprep/functions/timefreqfunc/newcrossf.py @@ -9,8 +9,10 @@ import numpy as np from scipy import stats -from eegprep.functions.timefreqfunc.bootstat import exact_p_values -from eegprep.functions.timefreqfunc.newtimef import compute_time_frequency +from eegprep.functions.popfunc._pop_utils import is_on as _is_on +from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence +from eegprep.functions.timefreqfunc.bootstat import bootstrap_threshold, exact_p_values +from eegprep.functions.timefreqfunc.newtimef import _threshold_vector, compute_time_frequency from eegprep.functions.timefreqfunc.newtimefitc import newtimefitc from eegprep.functions.timefreqfunc.newtimeftrialbaseln import baseline_indices @@ -240,10 +242,10 @@ def _resample_pair( def _upper_thresholds_by_frequency(values: np.ndarray, *, alpha: float) -> np.ndarray: - reshaped = values.transpose(1, 0, 2).reshape(values.shape[1], -1) - sorted_values = np.sort(reshaped, axis=1) - tail_count = max(1, int(round(sorted_values.shape[1] * alpha))) - return np.nanmean(sorted_values[:, -tail_count:], axis=1) + nfreq = values.shape[1] + pooled = values.transpose(0, 2, 1).reshape(-1, nfreq) + thresholds = np.asarray(bootstrap_threshold(pooled, alpha=alpha, bootside="upper")) + return thresholds.reshape(nfreq) def _shuffle_trials(tf_y: np.ndarray, count: int, rng: Any) -> np.ndarray: @@ -272,15 +274,6 @@ def _bootstrap_indices(times: np.ndarray, baseboot: Any) -> np.ndarray: return baseline_indices(times, values) -def _threshold_vector(thresholds: np.ndarray, target_shape: tuple[int, ...]) -> np.ndarray: - values = np.asarray(thresholds, dtype=float).squeeze() - if values.ndim == 0: - return np.full(target_shape, float(values)) - if values.ndim == 1: - return values[:, np.newaxis] - return values - - def _plot_cross_frequency( coherence: np.ndarray, phase: np.ndarray, @@ -394,18 +387,9 @@ def _boot_array(value: Any) -> np.ndarray | None: def _numeric_vector(value: Any, *, dtype: Any = float) -> np.ndarray: if value is None: return np.asarray([], dtype=dtype) - if isinstance(value, np.ndarray): - return value.astype(dtype).ravel() - if isinstance(value, (int, float, np.integer, np.floating)): - return np.asarray([value], dtype=dtype) - if isinstance(value, str): - text = value.strip().strip("[]") - if not text: - return np.asarray([], dtype=dtype) - return np.asarray([float(token) for token in text.replace(",", " ").split()], dtype=dtype) - if isinstance(value, (list, tuple)): - return np.asarray(value, dtype=dtype).ravel() - return np.asarray([value], dtype=dtype) + if isinstance(value, str) and value.strip() == "": + return np.asarray([], dtype=dtype) + return np.asarray(parse_numeric_sequence(value, dtype=dtype), dtype=dtype).ravel() def _first_numeric(value: Any, default: float) -> float: @@ -413,8 +397,4 @@ def _first_numeric(value: Any, default: float) -> float: return float(values[0]) if values.size else float(default) -def _is_on(value: Any) -> bool: - return str(value).lower() not in {"0", "false", "off", "no", "none"} - - __all__ = ["CrossFrequencyResult", "newcrossf"] diff --git a/src/eegprep/functions/timefreqfunc/newtimef.py b/src/eegprep/functions/timefreqfunc/newtimef.py index 03a69686..6e90c2a9 100644 --- a/src/eegprep/functions/timefreqfunc/newtimef.py +++ b/src/eegprep/functions/timefreqfunc/newtimef.py @@ -8,7 +8,10 @@ import matplotlib.pyplot as plt import numpy as np -from eegprep.functions.timefreqfunc.bootstat import exact_p_values +from eegprep.functions.popfunc._pop_utils import is_on as _is_on +from eegprep.functions.popfunc._pop_utils import parse_numeric_sequence +from eegprep.functions.statistics.fdr import fdr +from eegprep.functions.timefreqfunc.bootstat import bootstrap_threshold, exact_p_values from eegprep.functions.timefreqfunc.newtimefbaseln import newtimefbaseln from eegprep.functions.timefreqfunc.newtimefitc import newtimefitc from eegprep.functions.timefreqfunc.newtimeftrialbaseln import baseline_indices, newtimeftrialbaseln @@ -84,7 +87,10 @@ def newtimef( verbose: str = "off", ) -> TimeFrequencyResult: """Compute an EEGLAB-like ERSP/ITC time-frequency decomposition.""" - _ = overlap, plotphase + if overlap is not None: + raise NotImplementedError("newtimef does not implement the 'overlap' option") + if str(plotphase).strip().lower() not in {"off", "0", "false", "no", "none"}: + raise NotImplementedError("newtimef does not implement the 'plotphase' option") if freqs is None and freqrange is not None: freqs = freqrange if type is not None: @@ -242,7 +248,8 @@ def compute_time_frequency( timewarpms: Any = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """Return ``(freqs, times_ms, tfdata)`` for one signal.""" - _ = overlap + if overlap is not None: + raise NotImplementedError("compute_time_frequency does not implement the 'overlap' option") timestretch, _markers = _timewarp_options(timewarp, timewarpms, None, frames, tlimits, srate) decomp = _compute_decomposition( data, @@ -502,38 +509,23 @@ def _resample_trials( def _thresholds_by_frequency(values: np.ndarray, *, alpha: float, both: bool) -> np.ndarray: - reshaped = values.transpose(1, 0, 2).reshape(values.shape[1], -1) - sorted_values = np.sort(reshaped, axis=1) - tail_count = max(1, int(round(sorted_values.shape[1] * alpha))) - upper = np.nanmean(sorted_values[:, -tail_count:], axis=1) - if not both: - return upper - lower = np.nanmean(sorted_values[:, :tail_count], axis=1) - return np.stack([lower, upper], axis=1) + nfreq = values.shape[1] + pooled = values.transpose(0, 2, 1).reshape(-1, nfreq) + bootside = "both" if both else "upper" + thresholds = np.asarray(bootstrap_threshold(pooled, alpha=alpha, bootside=bootside)) + return thresholds.reshape(nfreq, 2) if both else thresholds.reshape(nfreq) def _significance_mask(pvalues: np.ndarray, alpha: float, correction: str) -> np.ndarray: mode = str(correction).lower() if mode == "fdr": - threshold = _fdr_threshold(pvalues, alpha) + threshold = float(fdr(pvalues, alpha).threshold) if threshold == 0: return np.zeros_like(pvalues, dtype=bool) return pvalues <= threshold return pvalues <= alpha -def _fdr_threshold(pvalues: np.ndarray, alpha: float) -> float: - values = np.sort(np.asarray(pvalues, dtype=float).ravel()) - values = values[np.isfinite(values)] - if values.size == 0: - return 0.0 - ranks = np.arange(1, values.size + 1, dtype=float) - accepted = values <= alpha * ranks / values.size - if not np.any(accepted): - return 0.0 - return float(values[np.nonzero(accepted)[0][-1]]) - - def _threshold_mask(values: np.ndarray, thresholds: np.ndarray) -> np.ndarray: threshold_values = np.asarray(thresholds) if threshold_values.ndim == 1: @@ -655,24 +647,9 @@ def _plot_panel( def _numeric_vector(value: Any, *, dtype: Any = float) -> np.ndarray: if value is None: return np.asarray([], dtype=dtype) - if isinstance(value, np.ndarray): - return value.astype(dtype).ravel() - if isinstance(value, (int, float, np.integer, np.floating)): - return np.asarray([value], dtype=dtype) - if isinstance(value, str): - text = value.strip().strip("[]") - if not text: - return np.asarray([], dtype=dtype) - values = [] - for token in text.replace(",", " ").split(): - if ":" in token: - values.extend(_colon_sequence(token)) - else: - values.append(float(token)) - return np.asarray(values, dtype=dtype) - if isinstance(value, (list, tuple)): - return np.asarray(value, dtype=dtype).ravel() - return np.asarray([value], dtype=dtype) + if isinstance(value, str) and value.strip() == "": + return np.asarray([], dtype=dtype) + return np.asarray(parse_numeric_sequence(value, dtype=dtype), dtype=dtype).ravel() def _first_numeric(value: Any, default: float) -> float: @@ -680,25 +657,4 @@ def _first_numeric(value: Any, default: float) -> float: return float(values[0]) if values.size else float(default) -def _colon_sequence(token: str) -> list[float]: - pieces = token.split(":") - if len(pieces) not in {2, 3}: - raise ValueError(f"Invalid colon range: {token}") - start = float(pieces[0]) - if len(pieces) == 2: - stop = float(pieces[1]) - step = 1.0 if stop >= start else -1.0 - else: - step = float(pieces[1]) - stop = float(pieces[2]) - if step == 0 or (stop - start) * step < 0: - return [] - count = int(np.floor((stop - start) / step + 1e-9)) + 1 - return [float(start + index * step) for index in range(max(count, 0))] - - -def _is_on(value: Any) -> bool: - return str(value).lower() not in {"0", "false", "off", "no", "none"} - - __all__ = ["TimeFrequencyResult", "compute_time_frequency", "newtimef"] diff --git a/src/eegprep/plugins/EEG_BIDS/bids.py b/src/eegprep/plugins/EEG_BIDS/bids.py index 38eeaa3e..f906ce0f 100644 --- a/src/eegprep/plugins/EEG_BIDS/bids.py +++ b/src/eegprep/plugins/EEG_BIDS/bids.py @@ -71,7 +71,7 @@ def query_for_adjacent_fpath(fn: str, **overrides) -> Dict[str, Any]: def gen_derived_fpath( raw_fn: str, *, - outputdir: str = '${root}/derivatives/clean_artifacts', + outputdir: str = '{root}/derivatives/clean_artifacts', keyword: str = '', suffix: Optional[str] = None, extension: str = '.set', @@ -84,6 +84,8 @@ def gen_derived_fpath( Original raw filename. outputdir : str Output directory for derived files (e.g., 'derivatives/clean_artifacts'). + The literal '{root}' placeholder, if present, is replaced with the BIDS + dataset root path. keyword : str Optional keyword tag to splice into the filename (e.g., 'desc-cleaned'). suffix : str, optional @@ -122,7 +124,8 @@ def gen_derived_fpath( # single directory name, need to prepend everything else outputdir = os.path.join(root_relative, 'derivatives', outputdir) - out_path = f"{outputdir}/{root_relative}/{new_fprefix}{new_fext}" + outputdir = os.path.normpath(outputdir) + out_path = os.path.join(outputdir, root_relative, new_fprefix + new_fext) return out_path diff --git a/src/eegprep/plugins/EEG_BIDS/coords.py b/src/eegprep/plugins/EEG_BIDS/coords.py index 3c18d132..5791b7a0 100644 --- a/src/eegprep/plugins/EEG_BIDS/coords.py +++ b/src/eegprep/plugins/EEG_BIDS/coords.py @@ -17,14 +17,13 @@ def coords_to_mm(coords: np.ndarray, unit: str) -> np.ndarray: """Convert the given coordinates array from the specified unit to millimeters.""" if unit in ('mm', 'millimeters'): - pass + return np.array(coords, copy=True) elif unit in ('cm', 'centimeters'): - coords *= 10.0 + return coords * 10.0 elif unit in ('m', 'meters'): - coords *= 1000.0 + return coords * 1000.0 else: raise ValueError(f"Unsupported coordinate unit: {unit}. Supported units are 'mm', 'cm', 'm'.") - return coords def coords_RAS_to_ALS(coords: np.ndarray) -> np.ndarray: diff --git a/src/eegprep/plugins/ICLabel/ICL_feature_extractor.py b/src/eegprep/plugins/ICLabel/ICL_feature_extractor.py index a1de18e5..88bdd5aa 100644 --- a/src/eegprep/plugins/ICLabel/ICL_feature_extractor.py +++ b/src/eegprep/plugins/ICLabel/ICL_feature_extractor.py @@ -28,15 +28,14 @@ def ICL_feature_extractor(EEG, flag_autocorr=False): EEG = deepcopy(EEG) - # Check inputs - ncomp = EEG['icawinv'].shape[1] - - # Check for ICA key and if it is not empty + # Check for ICA key and if it is not empty before dereferencing it if 'icawinv' not in EEG.keys() or EEG['icawinv'].size == 0: raise ValueError('You must have an ICA decomposition to use ICLabel') + ncomp = EEG['icawinv'].shape[1] + # Assuming chanlocs are correct - if EEG['ref'] != 'average' and EEG['ref'] != 'averef': + if EEG.get('ref') != 'average' and EEG.get('ref') != 'averef': EEG = pop_reref(EEG, []) # Calculate ICA activations if missing and cast to double diff --git a/src/eegprep/plugins/ICLabel/eeg_autocorr.py b/src/eegprep/plugins/ICLabel/eeg_autocorr.py index 5bb01a22..ee8cd208 100644 --- a/src/eegprep/plugins/ICLabel/eeg_autocorr.py +++ b/src/eegprep/plugins/ICLabel/eeg_autocorr.py @@ -56,27 +56,3 @@ def eeg_autocorr(EEG, pct_data=None): ac = ac[:, 1:] return ac - - -def test_eeg_autocorr(): - """Test the eeg_autocorr function.""" - EEG = { - 'srate': 256, - 'icaweights': np.random.randn(10, 256), - 'pnts': 1000, - 'trials': 5, - 'icaact': np.random.randn(10, 1000, 5), - } - - eeg_autocorr(EEG, 100) - - # print information about psdmed - # print(psdmed.shape) - - # print(psdmed) - - # assert psdmed.shape == (10, 100) - # assert np.all(np.isfinite(psdmed)) - - -# test_eeg_autocorr() diff --git a/src/eegprep/plugins/ICLabel/eeg_autocorr_fftw.py b/src/eegprep/plugins/ICLabel/eeg_autocorr_fftw.py index 6c77a3bb..d00242cc 100644 --- a/src/eegprep/plugins/ICLabel/eeg_autocorr_fftw.py +++ b/src/eegprep/plugins/ICLabel/eeg_autocorr_fftw.py @@ -7,7 +7,6 @@ import numpy as np from scipy.fft import fft, ifft, next_fast_len from scipy.signal import resample_poly -from ...functions.popfunc.pop_loadset import pop_loadset def eeg_autocorr_fftw(EEG, pct_data=100): @@ -60,29 +59,3 @@ def eeg_autocorr_fftw(EEG, pct_data=100): ac = ac[:, 1:101] return ac - - -def test_eeg_autocorr_fftw(): - """Test function for eeg_autocorr_fftw.""" - EEG = { - 'srate': 256, - 'icaweights': np.random.randn(10, 256), - 'pnts': 1000, - 'trials': 5, - 'icaact': np.random.randn(10, 1000, 5), - } - EEG = pop_loadset('/System/Volumes/Data/data/data/STUDIES/STERN/S01/Memorize.set') - - # reshape the last two dimensions of EEG['icaact'] - # EEG['icaact'] = EEG['icaact'].reshape(EEG['icaact'].shape[0], -1) - - # convert EEG['icaact'] to double precision - - psdmed = eeg_autocorr_fftw(EEG, 100) - - # print information about psdmed - print(psdmed.shape) - print(psdmed) - - -# test_eeg_autocorr_fftw() diff --git a/src/eegprep/plugins/ICLabel/eeg_autocorr_welch.py b/src/eegprep/plugins/ICLabel/eeg_autocorr_welch.py index 62ed2c61..6dfaaf15 100644 --- a/src/eegprep/plugins/ICLabel/eeg_autocorr_welch.py +++ b/src/eegprep/plugins/ICLabel/eeg_autocorr_welch.py @@ -7,7 +7,6 @@ import numpy as np from scipy.signal import resample_poly import random -from ...functions.popfunc.pop_loadset import pop_loadset from numpy.fft import fft, ifft @@ -83,18 +82,3 @@ def eeg_autocorr_welch(EEG, pct_data=100): ac = ac[:, 1:101] return ac - - -def test_eeg_autocorr_welch(): - """Test function for eeg_autocorr_welch.""" - eeglab_file_path = './eeglab_data_with_ica_tmp.set' - EEG = pop_loadset(eeglab_file_path) - - eeg_autocorr_welch(EEG, 100) - - # print information about psdmed - # print(psdmed.shape) - # print(psdmed) - - -# test_eeg_autocorr_welch() diff --git a/src/eegprep/plugins/ICLabel/eeg_rpsd.py b/src/eegprep/plugins/ICLabel/eeg_rpsd.py index a4b52cc0..7855fd8c 100644 --- a/src/eegprep/plugins/ICLabel/eeg_rpsd.py +++ b/src/eegprep/plugins/ICLabel/eeg_rpsd.py @@ -53,9 +53,9 @@ def eeg_rpsd(EEG, nfreqs=None, pct_data=100): .transpose() ) - np.random.seed(0) # rng('default') in MATLAB + rng = np.random.RandomState(0) # rng('default') in MATLAB; local RNG avoids mutating global state n_seg = index.shape[1] * EEG['trials'] - subset = np.random.permutation(n_seg)[: int(n_seg * pct_data / 100)] + subset = rng.permutation(n_seg)[: int(n_seg * pct_data / 100)] # calculate windowed spectrums psdmed = np.zeros((ncomp, nfreqs)) @@ -70,21 +70,3 @@ def eeg_rpsd(EEG, nfreqs=None, pct_data=100): psdmed[it, :] = 20 * np.log10(np.median(temp, axis=2)) return psdmed - - -def test_eeg_rpsd(): - """Test the eeg_rpsd function with sample data.""" - EEG = { - 'srate': 256, - 'icaweights': np.random.randn(10, 256), - 'pnts': 1000, - 'trials': 5, - 'icaact': np.random.randn(10, 1000, 5), - } - - psdmed = eeg_rpsd(EEG, 100) - assert psdmed.shape == (10, 100) - assert np.all(np.isfinite(psdmed)) - - -# test_eeg_rpsd() diff --git a/src/eegprep/plugins/ICLabel/iclabel_net_load_py_measures.py b/src/eegprep/plugins/ICLabel/iclabel_net_load_py_measures.py deleted file mode 100644 index f3b20c80..00000000 --- a/src/eegprep/plugins/ICLabel/iclabel_net_load_py_measures.py +++ /dev/null @@ -1,186 +0,0 @@ -"""ICLabel neural network model loading utilities.""" - -from pathlib import Path - -import scipy -import scipy.io -import torch - - -class Reshape(torch.nn.Module): - """Reshape layer for PyTorch.""" - - def __init__(self, shape): - """Initialize reshape layer.""" - super().__init__() - self.shape = shape - - def forward(self, x): - """Forward pass for reshape.""" - return x.view(x.shape[0], *self.shape) - - -class Concatenate(torch.nn.Module): - """Concatenate layer for PyTorch.""" - - def __init__(self, dim): - """Initialize concatenate layer.""" - super().__init__() - self.dim = dim - - def forward(self, x: list): - """Forward pass for concatenate.""" - return torch.cat(x, dim=self.dim) - - -class ICLabelNet(torch.nn.Module): - """ICLabel neural network model.""" - - def __init__(self, mat_path): - """Initialize ICLabelNet from MATLAB file.""" - super().__init__() - iclabel_matlab = scipy.io.loadmat(mat_path) - params = iclabel_matlab['params'][0] - i = 11 - print('shape of param', i, torch.tensor(params[i][1]).shape) - self.discriminator_image_layer1_conv = torch.nn.Conv2d( - in_channels=1, out_channels=128, kernel_size=4, stride=2, padding=1, dilation=1 - ) - print(self.discriminator_image_layer1_conv.weight.shape) - self.discriminator_image_layer1_conv.weight = torch.nn.Parameter(torch.tensor(params[0][1]).permute(3, 2, 0, 1)) - self.discriminator_image_layer1_conv.bias = torch.nn.Parameter(torch.tensor(params[1][1]).squeeze()) - self.discriminator_image_layer1_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_image_layer2_conv = torch.nn.Conv2d( - in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, dilation=1 - ) - self.discriminator_image_layer2_conv.weight = torch.nn.Parameter(torch.tensor(params[2][1]).permute(3, 2, 0, 1)) - self.discriminator_image_layer2_conv.bias = torch.nn.Parameter(torch.tensor(params[3][1]).squeeze()) - self.discriminator_image_layer2_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_image_layer3_conv = torch.nn.Conv2d( - in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, dilation=1 - ) - self.discriminator_image_layer3_conv.weight = torch.nn.Parameter(torch.tensor(params[4][1]).permute(3, 2, 0, 1)) - self.discriminator_image_layer3_conv.bias = torch.nn.Parameter(torch.tensor(params[5][1]).squeeze()) - self.discriminator_image_layer3_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_layer1_conv_conv = torch.nn.Conv2d( - in_channels=1, out_channels=128, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_psdmed_layer1_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[6][1]).permute(3, 2, 0, 1) - ) - self.discriminator_psdmed_layer1_conv_conv.bias = torch.nn.Parameter(torch.tensor(params[7][1]).squeeze()) - self.discriminator_psdmed_layer1_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_layer2_conv_conv = torch.nn.Conv2d( - in_channels=128, out_channels=256, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_psdmed_layer2_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[8][1]).permute(3, 2, 0, 1) - ) - self.discriminator_psdmed_layer2_conv_conv.bias = torch.nn.Parameter(torch.tensor(params[9][1]).squeeze()) - self.discriminator_psdmed_layer2_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_layer3_conv_conv = torch.nn.Conv2d( - in_channels=256, out_channels=1, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_psdmed_layer3_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[10][1]).unsqueeze(3).permute(3, 2, 0, 1) - ) - self.discriminator_psdmed_layer3_conv_conv.bias = torch.nn.Parameter(torch.tensor(params[11][1]).squeeze(1)) - self.discriminator_psdmed_layer3_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_autocorr_layer1_conv_conv = torch.nn.Conv2d( - in_channels=1, out_channels=128, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_autocorr_layer1_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[12][1]).permute(3, 2, 0, 1) - ) - self.discriminator_autocorr_layer1_conv_conv.bias = torch.nn.Parameter(torch.tensor(params[13][1]).squeeze()) - self.discriminator_autocorr_layer1_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_autocorr_layer2_conv_conv = torch.nn.Conv2d( - in_channels=128, out_channels=256, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_autocorr_layer2_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[14][1]).permute(3, 2, 0, 1) - ) - self.discriminator_autocorr_layer2_conv_conv.bias = torch.nn.Parameter(torch.tensor(params[15][1]).squeeze()) - self.discriminator_autocorr_layer2_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_autocorr_layer3_conv_conv = torch.nn.Conv2d( - in_channels=256, out_channels=1, kernel_size=(1, 3), stride=1, padding=(0, 1), dilation=1 - ) - self.discriminator_autocorr_layer3_conv_conv.weight = torch.nn.Parameter( - torch.tensor(params[16][1]).unsqueeze(3).permute(3, 2, 0, 1) - ) - self.discriminator_autocorr_layer3_conv_conv.bias = torch.nn.Parameter(torch.tensor(params[17][1]).squeeze(1)) - self.discriminator_autocorr_layer3_conv_relu = torch.nn.LeakyReLU(0.2) - self.discriminator_psdmed_reshape = Reshape((100, 1, 1)) - self.discriminator_psdmed_concat1 = Concatenate(dim=2) - self.discriminator_psdmed_concat2 = Concatenate(dim=3) - self.discriminator_autocorr_reshape = Reshape((100, 1, 1)) - self.discriminator_autocorr_concat1 = Concatenate(dim=2) - self.discriminator_autocorr_concat2 = Concatenate(dim=3) - self.discriminator_concat = Concatenate(dim=1) - self.discriminator_conv = torch.nn.Conv2d( - in_channels=712, out_channels=7, kernel_size=4, stride=1, padding=0, dilation=1 - ) - self.discriminator_conv.weight = torch.nn.Parameter(torch.tensor(params[18][1]).permute(3, 2, 0, 1)) - self.discriminator_conv.bias = torch.nn.Parameter(torch.tensor(params[19][1]).squeeze()) - self.discriminator_softmax = torch.nn.Softmax(dim=1) - - def forward(self, image, psdmed, autocorr): - """Forward pass for ICLabelNet.""" - x_image = self.discriminator_image_layer1_conv(image) - x_image = self.discriminator_image_layer1_relu(x_image) - x_image = self.discriminator_image_layer2_conv(x_image) - x_image = self.discriminator_image_layer2_relu(x_image) - x_image = self.discriminator_image_layer3_conv(x_image) - x_image = self.discriminator_image_layer3_relu(x_image) - print('x_image', x_image.shape) - - x_psdmed = self.discriminator_psdmed_layer1_conv_conv(psdmed) - x_psdmed = self.discriminator_psdmed_layer1_conv_relu(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer2_conv_conv(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer2_conv_relu(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer3_conv_conv(x_psdmed) - x_psdmed = self.discriminator_psdmed_layer3_conv_relu(x_psdmed) - x_psdmed = self.discriminator_psdmed_reshape(x_psdmed) - x_psdmed = self.discriminator_psdmed_concat1([x_psdmed] * 4) - x_psdmed = self.discriminator_psdmed_concat2([x_psdmed] * 4) - print('x_psdmed', x_psdmed.shape) - - x_autocorr = self.discriminator_autocorr_layer1_conv_conv(autocorr) - x_autocorr = self.discriminator_autocorr_layer1_conv_relu(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer2_conv_conv(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer2_conv_relu(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer3_conv_conv(x_autocorr) - x_autocorr = self.discriminator_autocorr_layer3_conv_relu(x_autocorr) - x_autocorr = self.discriminator_autocorr_reshape(x_autocorr) - x_autocorr = self.discriminator_autocorr_concat1([x_autocorr] * 4) - x_autocorr = self.discriminator_autocorr_concat2([x_autocorr] * 4) - print('x_autocorr', x_autocorr.shape) - - x = self.discriminator_concat([x_image, x_psdmed, x_autocorr]) - x = self.discriminator_conv(x) - print('x', x.shape) - # subtract max value to avoid overflow - x = x - torch.max(x, dim=1, keepdim=True).values - x = self.discriminator_softmax(x) - - return x - - -if __name__ == "__main__": - model = ICLabelNet(Path(__file__).with_name("netICL.mat")) - data = scipy.io.loadmat('python_temp_reformated.mat') - image_mat = data['grid'][0][0] - psdmed_mat = data['grid'][0][1] - autocorr_mat = data['grid'][0][2] - # assuming third dimension is trivial and last dimension is channel. First two dimensions (32 x 32) are size of topoplot - image = torch.tensor(image_mat).permute(-1, 2, 0, 1) - print('image shape', image.shape) - psdmed = torch.tensor(psdmed_mat).permute(-1, 2, 0, 1) - print('psd shape', psdmed.shape) - autocorr = torch.tensor(autocorr_mat).permute(-1, 2, 0, 1) - print('autocorr shape', autocorr.shape) - output = model(image, psdmed, autocorr) - print(output.shape) - - # save the output to a mat file - scipy.io.savemat('output4_py.mat', {'output': output.detach().numpy()}) diff --git a/src/eegprep/plugins/clean_rawdata/asr_calibrate.py b/src/eegprep/plugins/clean_rawdata/asr_calibrate.py index 62ac8349..775aaddc 100644 --- a/src/eegprep/plugins/clean_rawdata/asr_calibrate.py +++ b/src/eegprep/plugins/clean_rawdata/asr_calibrate.py @@ -12,6 +12,9 @@ logger = logging.getLogger(__name__) +# Sampling rates (Hz) for which a pre-computed spectral-shaping IIR filter is available. +_SUPPORTED_SRATES = frozenset({100, 128, 200, 250, 256, 300, 500, 512}) + def asr_calibrate( X, @@ -356,16 +359,14 @@ def asr_calibrate( dtype=np.float64, ) else: - # Fallback if no precomputed filter matches or yulewalk is unavailable - # Consider adding a call to a yulewalk implementation if available, - # or raising a more specific error/warning. - logger.warning( - f"No pre-computed spectral filter for srate {srate}. Using a simple default (may be suboptimal)." + # No precomputed spectral-shaping filter for this sampling rate. Degrading + # to a trivial difference filter would silently miscalibrate the ASR + # thresholds, so fail loudly instead (mirrors MATLAB's asr_calibrate:NoYulewalk). + raise ValueError( + f"No pre-computed ASR spectral filter for srate {srate} Hz " + f"(supported: {sorted(_SUPPORTED_SRATES)}). Resample the data to a " + "supported rate or pass explicit filter coefficients via B and A." ) - B = np.array([1.0, -1.0]) # Simple high-pass/difference filter as a basic fallback - A = np.array([1.0]) - # Original MATLAB error: - # error('asr_calibrate:NoYulewalk','The yulewalk() function was not found and there is no pre-computed spectral filter for your sampling rate...'); # Ensure data is finite X[~np.isfinite(X)] = 0.0 diff --git a/src/eegprep/plugins/clean_rawdata/asr_process.py b/src/eegprep/plugins/clean_rawdata/asr_process.py index 533ec0b9..c77b0134 100644 --- a/src/eegprep/plugins/clean_rawdata/asr_process.py +++ b/src/eegprep/plugins/clean_rawdata/asr_process.py @@ -192,15 +192,13 @@ def asr_process( logger.warning(f"Eigendecomposition failed at update point {j}. Using identity matrix.") D, V = np.ones(C), np.eye(C) - # Determine which components to keep (variance below threshold or not admissible for rejection) - try: - thresholds = np.sum(finite_matmul(T, V) ** 2, axis=0) - keep = (D < thresholds) | (np.arange(1, C + 1) < (C - max_dims_num)) - trivial = np.all(keep) - except Exception as e: - logger.error(f"Error in component selection: {e}") - keep = np.ones(C, dtype=bool) - trivial = True + # Determine which components to keep (variance below threshold or not admissible for rejection). + # No catch-all here: a shape/contract error in T/V must surface rather than silently + # disable artifact removal for this window. Genuine numerical-singularity cases are + # handled by the LinAlgError fallbacks for eigendecomposition and the pseudo-inverse. + thresholds = np.sum(finite_matmul(T, V) ** 2, axis=0) + keep = (D < thresholds) | (np.arange(1, C + 1) < (C - max_dims_num)) + trivial = np.all(keep) # Update the reconstruction matrix R if not trivial: diff --git a/src/eegprep/plugins/clean_rawdata/clean_artifacts.py b/src/eegprep/plugins/clean_rawdata/clean_artifacts.py index b4a75790..6b0e6cd3 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_artifacts.py +++ b/src/eegprep/plugins/clean_rawdata/clean_artifacts.py @@ -1,5 +1,6 @@ """EEG artifact cleaning functions.""" +import copy import logging from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -116,9 +117,12 @@ def clean_artifacts( Riemannian path uses EEGPrep's calibration-time estimate; full Riemannian ASR processing is not ported. Channels : sequence of str or None - List of channel labels to include before cleaning (pop_select). Default None. + List of channel labels to include before cleaning (pop_select). The + returned dataset contains only these channels; channels outside this list + are dropped and not re-inserted. Default None. Channels_ignore : sequence of str or None - List of channel labels to exclude before cleaning. Default None. + List of channel labels to exclude before cleaning. The excluded channels + are dropped from the returned dataset and not re-inserted. Default None. availableRAM_GB : float or None Available system RAM in GB to adjust MaxMem. Default None. @@ -146,6 +150,8 @@ def clean_artifacts( if distance not in _DISTANCE_MODES: raise ValueError("Distance must be 'euclidian', 'euclidean', or 'riemannian'") + EEG = copy.deepcopy(EEG) + # Ensure some obligatory fields exist in the structure (MATLAB code assumes) if 'etc' not in EEG: EEG['etc'] = {} @@ -154,32 +160,31 @@ def clean_artifacts( # Optional: restrict to / ignore certain channels # ------------------------------------------------------------------ if Channels is not None and len(Channels): - # Attempt pop_select based on labels; fall back to manual + # Attempt pop_select based on labels; the manual fallback only covers the + # documented case where pop_select is unavailable (ImportError). Errors + # raised inside pop_select itself must surface, not be masked. try: from eegprep import pop_select EEG = pop_select(EEG, channel=list(Channels)) - except Exception: - # Manual selection on labels + except ImportError: lbl_to_idx = {ch['labels']: idx for idx, ch in enumerate(EEG['chanlocs'])} keep_idx = [lbl_to_idx[lbl] for lbl in Channels if lbl in lbl_to_idx] EEG['data'] = EEG['data'][keep_idx, :] EEG['chanlocs'] = [EEG['chanlocs'][i] for i in keep_idx] EEG['nbchan'] = len(keep_idx) - EEG['event'] = [] # will be restored later elif Channels_ignore is not None and len(Channels_ignore): try: from eegprep import pop_select EEG = pop_select(EEG, nochannel=list(Channels_ignore)) - except Exception: + except ImportError: lbl_to_idx = {ch['labels']: idx for idx, ch in enumerate(EEG['chanlocs'])} drop_idx_set = {lbl_to_idx[lbl] for lbl in Channels_ignore if lbl in lbl_to_idx} keep_idx = [i for i in range(len(EEG['chanlocs'])) if i not in drop_idx_set] EEG['data'] = EEG['data'][keep_idx, :] EEG['chanlocs'] = [EEG['chanlocs'][i] for i in keep_idx] EEG['nbchan'] = len(keep_idx) - EEG['event'] = [] # ------------------------------------------------------------------ # 1) Flat‑line channel removal @@ -196,8 +201,10 @@ def clean_artifacts( raise ValueError('Highpass must be a (low, high) tuple or None/"off".') logger.info('Applying high‑pass filter...') EEG = clean_drifts(EEG, tuple(Highpass)) - # Keep a copy after HP for optional return - HP = EEG.copy() + # Keep a point-in-time snapshot after HP for optional return. Deep-copy so + # later stages that mutate EEG['etc'] (channel/sample masks) do not bleed + # into the returned high-pass dataset. + HP = copy.deepcopy(EEG) # ------------------------------------------------------------------ # 3) Channel cleaning (noisy / disconnected) @@ -220,10 +227,17 @@ def clean_artifacts( num_samples=int(NumSamples), subset_size=SubsetSize, # Default 0.25, matches MATLAB default when not passed ) - removed_channels = ~EEG['etc']['clean_channel_mask'] - except Exception as e: - # Fall back to "no‑locs" version if location dependent failure - logger.warning(f'clean_channels failed ({e}); falling back to clean_channels_nolocs.') + # clean_channels only writes clean_channel_mask when it removes channels; + # an absent mask means nothing was removed, so keep the all-False default. + mask = EEG.get('etc', {}).get('clean_channel_mask') + if mask is not None: + removed_channels = ~mask + except ValueError as e: + # Only the missing-channel-locations case warrants the no-locs fallback; + # any other ValueError is a genuine failure and must propagate. + if 'location' not in str(e).lower(): + raise + logger.warning(f'clean_channels lacks usable locations ({e}); falling back to clean_channels_nolocs.') EEG, removed_channels = clean_channels_nolocs( EEG, min_corr=float(NoLocsChannelCriterion), @@ -238,9 +252,8 @@ def clean_artifacts( BUR = EEG # default in case ASR is skipped if BurstCriterion not in (None, 'off'): logger.info('Applying ASR burst repair...') - # Save original data before clean_asr modifies EEG in place. - # MATLAB passes structs by value so the caller's EEG retains the - # original data, but Python dicts are passed by reference. + # Snapshot the pre-repair data to compare against the ASR-repaired + # result; clean_asr returns a fresh dataset (BUR) and leaves EEG intact. original_data = EEG['data'].copy() if BurstRejection else None useriemannian = _DISTANCE_MODES[distance] BUR = clean_asr( @@ -254,8 +267,8 @@ def clean_artifacts( ) if BurstRejection: - # Determine unchanged samples: compare original (pre-ASR) with repaired. - # Use original_data saved before clean_asr modified EEG['data'] in place. + # Determine unchanged samples: compare the pre-repair snapshot with + # the ASR-repaired data returned in BUR. sample_mask = np.sum(np.abs(original_data - BUR['data']), axis=0) < 1e-8 del original_data # Convert retained samples to inclusive zero-based intervals. @@ -269,6 +282,8 @@ def clean_artifacts( sample_mask[s : e + 1] = False retain_intervals = retain_intervals[~small] + # Reject bad periods from the ASR-repaired dataset (BUR). + EEG = BUR rejected_intervals = mask_to_intervals(sample_mask, value=False) if rejected_intervals.size: EEG = eeg_eegrej(EEG, rejected_intervals) @@ -291,11 +306,9 @@ def clean_artifacts( logger.info('Use vis_artifacts to compare the cleaned data to the original.') - # ------------------------------------------------------------------ - # Optionally re‑insert ignored channels - # ------------------------------------------------------------------ - # The full MATLAB logic is complicated; the Python port currently skips the - # re‑insertion of previously excluded channels for simplicity. Users can - # merge channels back manually if needed. + # When Channels/Channels_ignore restrict the dataset, the returned EEG holds + # only the cleaned subset; excluded channels are not re-inserted (unlike the + # MATLAB reference). Callers that need the ignored channels back must merge + # them manually. This is documented on the Channels/Channels_ignore parameters. return EEG, HP, BUR, removed_channels diff --git a/src/eegprep/plugins/clean_rawdata/clean_asr.py b/src/eegprep/plugins/clean_rawdata/clean_asr.py index 5b69fd66..76470820 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_asr.py +++ b/src/eegprep/plugins/clean_rawdata/clean_asr.py @@ -80,7 +80,9 @@ def clean_asr( raise ValueError("EEG dictionary must contain 'data', 'srate', and 'nbchan'.") useriemannian = _normalise_useriemannian(useriemannian) - data = np.asarray(EEG['data'], dtype=np.float64) + # Operate on a copy so the caller's data array (and dict) are never mutated; + # asr_calibrate zeroes non-finite samples in place on whatever array it receives. + data = np.array(EEG['data'], dtype=np.float64, copy=True) srate = float(EEG['srate']) nbchan = int(EEG['nbchan']) C, S = data.shape @@ -119,10 +121,14 @@ def clean_asr( "clean_windows returned insufficient data. Falling back to using all data for calibration." ) ref_section_data = data - except Exception as e: - logger.error(f"An error occurred during clean_windows: {e}") + except ValueError as e: + # clean_windows raises ValueError for expected calibration-data problems + # (empty data, window too small, not enough data for one window). Only + # those warrant the all-data fallback; unexpected exceptions propagate so + # genuine bugs are not masked as silently weaker ASR calibration. logger.warning( - "Could not automatically identify clean calibration data. Falling back to using the entire data for calibration." + f"Could not automatically identify clean calibration data ({e}). " + "Falling back to using the entire data for calibration." ) ref_section_data = data elif (isinstance(ref_maxbadchannels, str) and ref_maxbadchannels.lower() == 'off') or ref_maxbadchannels is None: @@ -200,6 +206,7 @@ def clean_asr( # --- Finalize --- # shift signal content back (to compensate for processing delay) outdata = outdata[:, :S] + EEG = deepcopy(EEG) EEG['data'] = outdata logger.info('ASR cleaning finished.') diff --git a/src/eegprep/plugins/clean_rawdata/clean_flatlines.py b/src/eegprep/plugins/clean_rawdata/clean_flatlines.py index d0bdeb4e..e3ea44a6 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_flatlines.py +++ b/src/eegprep/plugins/clean_rawdata/clean_flatlines.py @@ -67,7 +67,8 @@ def clean_flatlines(EEG: Dict[str, Any], max_flatline_duration: float = 5.0, max EEG['nbchan'] = EEG['data'].shape[0] for fn in EEG.keys() & {'icawinv', 'icasphere', 'icaweights', 'icaact', 'stats', 'specdata', 'specicaact'}: EEG[fn] = np.array([]) - if CCM := EEG['etc'].get('clean_channel_mask') is not None: + CCM = EEG['etc'].get('clean_channel_mask') + if CCM is not None: CCM[CCM] = ~removed_channels else: EEG['etc']['clean_channel_mask'] = ~removed_channels diff --git a/src/eegprep/plugins/clean_rawdata/clean_windows.py b/src/eegprep/plugins/clean_rawdata/clean_windows.py index 5bdb11f3..732d17dd 100644 --- a/src/eegprep/plugins/clean_rawdata/clean_windows.py +++ b/src/eegprep/plugins/clean_rawdata/clean_windows.py @@ -5,6 +5,7 @@ """ import logging +from copy import deepcopy from typing import Any, Dict, Sequence, Tuple, Union import numpy as np @@ -79,9 +80,11 @@ def clean_windows( # ------------------------------------------------------------------ # Input handling # ------------------------------------------------------------------ + # Operate on a deep copy so the caller's dataset is never mutated. + EEG = deepcopy(EEG) input_data = np.asarray(EEG['data']) output_dtype = input_data.dtype if np.issubdtype(input_data.dtype, np.floating) else np.dtype(np.float64) - EEG['data'] = input_data.astype(np.float64, copy=False) + EEG['data'] = input_data.astype(np.float64, copy=True) C, S = EEG['data'].shape Fs = EEG['srate'] diff --git a/src/eegprep/plugins/clean_rawdata/pop_clean_rawdata.py b/src/eegprep/plugins/clean_rawdata/pop_clean_rawdata.py index 5eb1cfa8..e9f796f3 100644 --- a/src/eegprep/plugins/clean_rawdata/pop_clean_rawdata.py +++ b/src/eegprep/plugins/clean_rawdata/pop_clean_rawdata.py @@ -69,11 +69,13 @@ def pop_clean_rawdata( return (output, command) if return_com else output if int(EEG.get("trials", 1) or 1) > 1 or np.asarray(EEG.get("data")).ndim == 3: raise ValueError("Input data must be continuous. This data seems epoched.") - original_eeg = copy.deepcopy(EEG) if show_vis_artifacts else None - clean_eeg, _hp, _bur, _removed_channels = clean_artifacts(EEG, **options) + # Deep-copy so a failure (or any partial cleaning stage) leaves the caller's + # input dataset untouched and a successful run returns a distinct object. + working_eeg = copy.deepcopy(EEG) + clean_eeg, _hp, _bur, _removed_channels = clean_artifacts(working_eeg, **options) command = _history_command(options) - if show_vis_artifacts and original_eeg is not None: - vis_artifacts(clean_eeg, original_eeg) + if show_vis_artifacts: + vis_artifacts(clean_eeg, EEG) logger.info("Done.") return (clean_eeg, command) if return_com else clean_eeg diff --git a/src/eegprep/plugins/clean_rawdata/private/sigproc.py b/src/eegprep/plugins/clean_rawdata/private/sigproc.py index e6be3ffd..d5dc6fb6 100644 --- a/src/eegprep/plugins/clean_rawdata/private/sigproc.py +++ b/src/eegprep/plugins/clean_rawdata/private/sigproc.py @@ -1,12 +1,12 @@ """Signal processing utilities.""" -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import numpy as np from scipy.signal import fftconvolve from ....functions.miscfunc.misc import round_mat -__all__ = ['design_kaiser', 'design_fir', 'filtfilt_fast', 'firwsord', 'firws'] +__all__ = ['design_kaiser', 'design_fir', 'filtfilt_fast'] def design_kaiser(lo: float, hi: float, atten: float, want_odd: bool, use_scipy: bool = False) -> np.ndarray: @@ -270,193 +270,3 @@ def slice_at(x, k): Y[slice_at(Y, k)] = res Z.p = (Z.p + 1) % N return (X if inplace else Y), Z - - -def firws( - m: int, f: Union[float, Sequence[float]], t: Optional[str] = None, w: Optional[np.ndarray] = None -) -> Tuple[np.ndarray, float]: - """Designs windowed sinc type I linear phase FIR filter. - - Parameters - ---------- - m : int - Filter order (mandatory even). - f : float or sequence of float - Vector or scalar of cutoff frequency/ies (-6 dB; pi rad / sample). - t : str, optional - 'high' for highpass, 'stop' for bandstop filter (default low-/bandpass). - w : array_like, optional - Vector of length m + 1 defining window (default hamming). - - Returns - ------- - b : np.ndarray - Filter coefficients. - a : float - Always 1 (FIR filter). - - Examples - -------- - fs = 500; cutoff = 0.5; df = 1; - m = firwsord('hamming', fs, df)[0] - b, a = firws(m, cutoff / (fs / 2), 'high', scipy.signal.windows.hamming(m + 1)) - - Notes - ----- - Based on a MATLAB implementation by Andreas Widmann, University of Leipzig, 2005. - """ - from scipy.signal.windows import hamming - - a = 1.0 - - if m <= 0 or not isinstance(m, int) or m % 2 != 0: - raise ValueError('Filter order must be a real, even, positive integer.') - - # Convert f to array and normalize - f_arr = np.asarray(f, dtype=float) - if f_arr.ndim == 0: - f_arr = f_arr.reshape(1) - f = f_arr / 2.0 - - if np.any(f <= 0) or np.any(f >= 0.5): - raise ValueError('Frequencies must fall in range between 0 and 1.') - - if t is None: - t = '' - - if w is None: - if t is not None and not isinstance(t, str): - # Handle case where third argument is window, not filter type - w = t - t = '' - else: - w = hamming(m + 1) - - # Make window row vector - w = np.asarray(w).flatten() - - b = _fkernel(m, f[0], w) - - if len(f) == 1 and t.lower() == 'high': - b = _fspecinv(b) - - if len(f) == 2: - b = b + _fspecinv(_fkernel(m, f[1], w)) - if not t or t.lower() != 'stop': - b = _fspecinv(b) - - return b, a - - -def _fkernel(m: int, f: float, w: np.ndarray) -> np.ndarray: - """Compute filter kernel. - - Parameters - ---------- - m : int - Filter order. - f : float - Normalized cutoff frequency. - w : np.ndarray - Window function. - - Returns - ------- - b : np.ndarray - Filter kernel. - """ - # Create range -m/2 : m/2 - n = np.arange(-m // 2, m // 2 + 1, dtype=float) - - # Compute sinc function - b = np.zeros_like(n) - - # Handle n == 0 case (no division by zero) - zero_idx = n == 0 - b[zero_idx] = 2 * np.pi * f - - # Handle n != 0 case - nonzero_idx = n != 0 - b[nonzero_idx] = np.sin(2 * np.pi * f * n[nonzero_idx]) / n[nonzero_idx] - - # Apply window - b = b * w - - # Normalization to unity gain at DC - b = b / np.sum(b) - - return b - - -def _fspecinv(b: np.ndarray) -> np.ndarray: - """Perform spectral inversion. - - Parameters - ---------- - b : np.ndarray - Filter coefficients. - - Returns - ------- - b_inv : np.ndarray - Spectrally inverted filter coefficients. - """ - b_inv = -b.copy() - center_idx = (len(b) - 1) // 2 - b_inv[center_idx] = b_inv[center_idx] + 1 - return b_inv - - -def firwsord(wintype: str, fs: float, df: float, dev: Optional[float] = None) -> Tuple[int, float]: - """Estimate windowed sinc FIR filter order depending on window type and requested transition band width. - - Parameters - ---------- - wintype : str - Window type. One of 'rectangular', 'hann', 'hamming', 'blackman', or 'kaiser'. - fs : float - Sampling frequency. - df : float - Requested transition band width. - dev : float, optional - Maximum passband deviation/ripple (Kaiser window only). - - Returns - ------- - m : int - Estimated filter order. - dev : float - Maximum passband deviation/ripple. - - Notes - ----- - Based on a MATLAB implementation by Andreas Widmann, University of Leipzig, 2005. - """ - win_type_array = ['rectangular', 'hann', 'hamming', 'blackman', 'kaiser'] - win_df_array = [0.9, 3.1, 3.3, 5.5] - win_dev_array = [0.089, 0.0063, 0.0022, 0.0002] - - # Check arguments - if fs is None or df is None or wintype is None: - raise ValueError('Not enough input arguments.') - - # Window type - try: - wintype_idx = win_type_array.index(wintype) - except ValueError: - raise ValueError('Unknown window type.') - - df_norm = df / fs # Normalize transition band width - - if wintype_idx == 4: # Kaiser window (index 4 in 0-based, was 5 in 1-based MATLAB) - if dev is None: - raise ValueError('Not enough input arguments.') - devdb = -20 * np.log10(dev) - m = 1 + (devdb - 8) / (2.285 * 2 * np.pi * df_norm) - else: - m = win_df_array[wintype_idx] / df_norm - dev = win_dev_array[wintype_idx] - - m = int(np.ceil(m / 2) * 2) # Make filter order even (FIR type I) - - return m, dev diff --git a/src/eegprep/plugins/firfilt/firws.py b/src/eegprep/plugins/firfilt/firws.py index 5ec4b279..3ec68789 100644 --- a/src/eegprep/plugins/firfilt/firws.py +++ b/src/eegprep/plugins/firfilt/firws.py @@ -1,5 +1,142 @@ """Windowed-sinc FIR filter design.""" -from ..clean_rawdata.private.sigproc import firws +from typing import Optional, Sequence, Tuple, Union + +import numpy as np __all__ = ["firws"] + + +def firws( + m: int, f: Union[float, Sequence[float]], t: Optional[str] = None, w: Optional[np.ndarray] = None +) -> Tuple[np.ndarray, float]: + """Designs windowed sinc type I linear phase FIR filter. + + Parameters + ---------- + m : int + Filter order (mandatory even). + f : float or sequence of float + Vector or scalar of cutoff frequency/ies (-6 dB; pi rad / sample). + t : str, optional + 'high' for highpass, 'stop' for bandstop filter (default low-/bandpass). + w : array_like, optional + Vector of length m + 1 defining window (default hamming). + + Returns + ------- + b : np.ndarray + Filter coefficients. + a : float + Always 1 (FIR filter). + + Examples + -------- + fs = 500; cutoff = 0.5; df = 1; + m = firwsord('hamming', fs, df)[0] + b, a = firws(m, cutoff / (fs / 2), 'high', scipy.signal.windows.hamming(m + 1)) + + Notes + ----- + Based on a MATLAB implementation by Andreas Widmann, University of Leipzig, 2005. + """ + from scipy.signal.windows import hamming + + a = 1.0 + + if m <= 0 or not isinstance(m, int) or m % 2 != 0: + raise ValueError('Filter order must be a real, even, positive integer.') + + # Convert f to array and normalize + f_arr = np.asarray(f, dtype=float) + if f_arr.ndim == 0: + f_arr = f_arr.reshape(1) + f = f_arr / 2.0 + + if np.any(f <= 0) or np.any(f >= 0.5): + raise ValueError('Frequencies must fall in range between 0 and 1.') + + if t is None: + t = '' + + if w is None: + if t is not None and not isinstance(t, str): + # Handle case where third argument is window, not filter type + w = t + t = '' + else: + w = hamming(m + 1) + + # Make window row vector + w = np.asarray(w).flatten() + + b = _fkernel(m, f[0], w) + + if len(f) == 1 and t.lower() == 'high': + b = _fspecinv(b) + + if len(f) == 2: + b = b + _fspecinv(_fkernel(m, f[1], w)) + if not t or t.lower() != 'stop': + b = _fspecinv(b) + + return b, a + + +def _fkernel(m: int, f: float, w: np.ndarray) -> np.ndarray: + """Compute filter kernel. + + Parameters + ---------- + m : int + Filter order. + f : float + Normalized cutoff frequency. + w : np.ndarray + Window function. + + Returns + ------- + b : np.ndarray + Filter kernel. + """ + # Create range -m/2 : m/2 + n = np.arange(-m // 2, m // 2 + 1, dtype=float) + + # Compute sinc function + b = np.zeros_like(n) + + # Handle n == 0 case (no division by zero) + zero_idx = n == 0 + b[zero_idx] = 2 * np.pi * f + + # Handle n != 0 case + nonzero_idx = n != 0 + b[nonzero_idx] = np.sin(2 * np.pi * f * n[nonzero_idx]) / n[nonzero_idx] + + # Apply window + b = b * w + + # Normalization to unity gain at DC + b = b / np.sum(b) + + return b + + +def _fspecinv(b: np.ndarray) -> np.ndarray: + """Perform spectral inversion. + + Parameters + ---------- + b : np.ndarray + Filter coefficients. + + Returns + ------- + b_inv : np.ndarray + Spectrally inverted filter coefficients. + """ + b_inv = -b.copy() + center_idx = (len(b) - 1) // 2 + b_inv[center_idx] = b_inv[center_idx] + 1 + return b_inv diff --git a/src/eegprep/plugins/firfilt/firwsord.py b/src/eegprep/plugins/firfilt/firwsord.py index 9e0d85ea..1d14c510 100644 --- a/src/eegprep/plugins/firfilt/firwsord.py +++ b/src/eegprep/plugins/firfilt/firwsord.py @@ -1,5 +1,62 @@ """Windowed-sinc FIR filter order estimation.""" -from ..clean_rawdata.private.sigproc import firwsord +from typing import Optional, Tuple + +import numpy as np __all__ = ["firwsord"] + + +def firwsord(wintype: str, fs: float, df: float, dev: Optional[float] = None) -> Tuple[int, float]: + """Estimate windowed sinc FIR filter order depending on window type and requested transition band width. + + Parameters + ---------- + wintype : str + Window type. One of 'rectangular', 'hann', 'hamming', 'blackman', or 'kaiser'. + fs : float + Sampling frequency. + df : float + Requested transition band width. + dev : float, optional + Maximum passband deviation/ripple (Kaiser window only). + + Returns + ------- + m : int + Estimated filter order. + dev : float + Maximum passband deviation/ripple. + + Notes + ----- + Based on a MATLAB implementation by Andreas Widmann, University of Leipzig, 2005. + """ + win_type_array = ['rectangular', 'hann', 'hamming', 'blackman', 'kaiser'] + win_df_array = [0.9, 3.1, 3.3, 5.5] + win_dev_array = [0.089, 0.0063, 0.0022, 0.0002] + + # Check arguments + if fs is None or df is None or wintype is None: + raise ValueError('Not enough input arguments.') + + # Window type + try: + wintype_idx = win_type_array.index(wintype) + except ValueError: + raise ValueError('Unknown window type.') + + df_norm = df / fs # Normalize transition band width + + if wintype_idx == 4: # Kaiser window (index 4 in 0-based, was 5 in 1-based MATLAB) + if dev is None: + raise ValueError('Not enough input arguments.') + devdb = -20 * np.log10(dev) + m = 1 + (devdb - 8) / (2.285 * 2 * np.pi * df_norm) + else: + m = win_df_array[wintype_idx] / df_norm + dev = win_dev_array[wintype_idx] + + m = int(np.ceil(m / 2) * 2) # Make filter order even (FIR type I) + + return m, dev diff --git a/src/eegprep/resources/help/pop_newset.md b/src/eegprep/resources/help/pop_newset.md index 04b4db6d..a339cf42 100644 --- a/src/eegprep/resources/help/pop_newset.md +++ b/src/eegprep/resources/help/pop_newset.md @@ -8,6 +8,9 @@ resampling, filtering, epoching, selecting data, rereferencing, interpolation, and cleaning. Choose whether the processed dataset should overwrite the current dataset or be stored as a new dataset. +Use **Edit description** to open a multiline editor for the dataset +`comments` field. + Common command-line forms: ```python diff --git a/src/eegprep/resources/help/pop_runica.md b/src/eegprep/resources/help/pop_runica.md index f3498281..fd1abe72 100644 --- a/src/eegprep/resources/help/pop_runica.md +++ b/src/eegprep/resources/help/pop_runica.md @@ -30,6 +30,10 @@ Calling `pop_runica(EEG)` opens an EEGLAB-style dialog with: - Channel type/index selection controls. - For multiple datasets, a dataset selector and concatenate controls. +When `pop_runica` is started from the main EEGPrep GUI, the ICA computation +runs behind an indeterminate progress dialog so the window can continue +repainting while the decomposition is being computed. + Behavior: - Supplying a non-default `icatype` programmatically, for example @@ -41,6 +45,9 @@ Behavior: - Existing ICLabel classifications are removed when ICA is recomputed because they no longer describe the active components. - `EEG.icaweights`, `EEG.icasphere`, `EEG.icawinv`, `EEG.icaact`, and `EEG.icachansind` are updated. - GUI-launched runica adds `'interrupt', 'on'` to the history command, matching EEGLAB's GUI path. +- GUI-launched ICA stores the updated dataset only after the background + computation finishes successfully. Failed runs leave the current dataset and + history unchanged. Examples: diff --git a/tests/conftest.py b/tests/conftest.py index 1735e1a8..a246c980 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,7 @@ def _preload_matlab_libstdcxx() -> None: "tests/test_gui_pop_runica.py", "tests/test_gui_pop_select.py", "tests/test_gui_pop_study.py", + "tests/test_gui_long_task.py", "tests/test_gui_main_window.py", "tests/test_eegplot_gui.py", ) @@ -60,10 +61,8 @@ def _preload_matlab_libstdcxx() -> None: "tests/test_bids_preproc.py", "tests/test_clean_rawdata.py", "tests/test_eeg_compare.py", - "tests/test_eeg_eeg2mne.py", "tests/test_eeg_eegrej.py", "tests/test_eeg_lat2point.py", - "tests/test_eeg_mne2eeg_epochs.py", "tests/test_eeg_point2lat.py", "tests/test_eeg_rpsd_parity.py", "tests/test_eegfindboundaries.py", diff --git a/tests/matlab/eeg_iclabelcompare_features.m b/tests/matlab/eeg_iclabelcompare_features.m index e484c426..1e08837e 100644 --- a/tests/matlab/eeg_iclabelcompare_features.m +++ b/tests/matlab/eeg_iclabelcompare_features.m @@ -19,7 +19,7 @@ respy.grid{3} = single(repmat(respy.grid{3}, [1 1 1 4])); save('python_temp_reformated.mat', '-struct', 'respy'); -system([pythonFunc ' iclabel_net_load_py_measures.py']); +system([pythonFunc ' ../../tools/iclabel/iclabel_net_load_py_measures.py']); labels_py4 = load('-mat','output4_py.mat'); labels_py4 = reshape(mean(reshape(labels_py4.output', [], 4), 2), 7, [])'; delete('output4_py.mat'); diff --git a/tests/matlab/iclabel_compare_with_features_tmp.m b/tests/matlab/iclabel_compare_with_features_tmp.m index 2b2cff0a..18a42b95 100644 --- a/tests/matlab/iclabel_compare_with_features_tmp.m +++ b/tests/matlab/iclabel_compare_with_features_tmp.m @@ -18,7 +18,7 @@ res.grid{3} = single(repmat(res.grid{3}, [1 1 1 4])); save('python_temp_reformated.mat', '-struct', 'res'); -system([pythonFunc ' iclabel_net_load_py_measures.py']); +system([pythonFunc ' ../../tools/iclabel/iclabel_net_load_py_measures.py']); labels_py4 = load('-mat','output4_py.mat'); labels_py4 = reshape(mean(reshape(labels_py4.output', [], 4), 2), 7, [])'; delete('output4_py.mat'); diff --git a/tests/test_ICL_feature_extractor.py b/tests/test_ICL_feature_extractor.py index f2598111..67ea460a 100644 --- a/tests/test_ICL_feature_extractor.py +++ b/tests/test_ICL_feature_extractor.py @@ -12,31 +12,14 @@ import scipy.io from eegprep.plugins.ICLabel.ICL_feature_extractor import ICL_feature_extractor +from eegprep.plugins.ICLabel.eeg_rpsd import eeg_rpsd from eegprep.functions.popfunc.pop_loadset import pop_loadset from eegprep.functions.popfunc.pop_saveset import pop_saveset from eegprep.functions.adminfunc.eeglabcompat import get_eeglab -local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') - - -def create_test_eeg(n_channels=32, n_samples=1000, srate=250.0, n_trials=1): - """Create a synthetic EEG structure for testing.""" - data = np.random.randn(n_channels, n_samples, n_trials) * 0.5 - if n_trials == 1: - data = data.squeeze(axis=2) # Remove trial dimension for continuous data +from tests.fixtures import create_test_eeg - return { - 'data': data, - 'srate': srate, - 'pnts': n_samples, - 'nbchan': n_channels, - 'trials': n_trials, - 'xmin': 0.0, - 'xmax': (n_samples - 1) / srate, - 'times': np.arange(n_samples) / srate, - 'event': [], - 'ref': 'unknown', - } +local_url = os.path.join(os.path.dirname(__file__), '../sample_data/') class TestICLFeatureExtractorBasic(unittest.TestCase): @@ -81,22 +64,31 @@ def setUp(self): ) def test_icl_feature_extractor_missing_ica_winv(self): - """Test ICL_feature_extractor with missing icawinv.""" + """Missing icawinv raises a clear ValueError before any dereference.""" EEG = self.test_eeg.copy() del EEG['icawinv'] - # Function has a bug - it tries to access icawinv before checking if it exists - with self.assertRaises(KeyError): + with self.assertRaises(ValueError) as cm: ICL_feature_extractor(EEG) + self.assertIn('ICA decomposition', str(cm.exception)) def test_icl_feature_extractor_empty_ica_winv(self): - """Test ICL_feature_extractor with empty icawinv.""" + """Empty icawinv raises a clear ValueError before any dereference.""" EEG = self.test_eeg.copy() EEG['icawinv'] = np.array([]) - # Function has a bug - it tries to access shape[1] on empty array - with self.assertRaises(IndexError): + with self.assertRaises(ValueError) as cm: ICL_feature_extractor(EEG) + self.assertIn('ICA decomposition', str(cm.exception)) + + def test_icl_feature_extractor_missing_ref_field(self): + """A dataset without a 'ref' field is treated as non-average and re-referenced.""" + EEG = self.test_eeg.copy() + del EEG['ref'] + + # Must not raise KeyError on the missing 'ref'; should proceed to feature extraction. + features = ICL_feature_extractor(EEG, flag_autocorr=False) + self.assertEqual(len(features), 2) def test_icl_feature_extractor_missing_icaact(self): """Test ICL_feature_extractor with missing icaact.""" @@ -109,54 +101,46 @@ def test_icl_feature_extractor_missing_icaact(self): def test_icl_feature_extractor_basic_functionality(self): """Test basic ICL_feature_extractor functionality.""" - try: - features = ICL_feature_extractor(self.test_eeg, flag_autocorr=False) - - # Should return 2 features (topo and psd) when flag_autocorr=False - self.assertEqual(len(features), 2) + features = ICL_feature_extractor(self.test_eeg, flag_autocorr=False) - # Check topo features - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) - self.assertEqual(topo.dtype, np.float32) - self.assertTrue(np.all(np.abs(topo) <= 0.99)) # Should be scaled by 0.99 + # Should return 2 features (topo and psd) when flag_autocorr=False + self.assertEqual(len(features), 2) - # Check psd features - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) - self.assertEqual(psd.dtype, np.float32) - self.assertTrue(np.all(np.abs(psd) <= 0.99)) # Should be scaled by 0.99 + # Check topo features + topo = features[0] + self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) + self.assertEqual(topo.dtype, np.float32) + self.assertTrue(np.all(np.abs(topo) <= 0.99)) # Should be scaled by 0.99 - except Exception as e: - self.skipTest(f"ICL_feature_extractor basic functionality not available: {e}") + # Check psd features + psd = features[1] + self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) + self.assertEqual(psd.dtype, np.float32) + self.assertTrue(np.all(np.abs(psd) <= 0.99)) # Should be scaled by 0.99 def test_icl_feature_extractor_with_autocorr(self): """Test ICL_feature_extractor with autocorrelation features.""" - try: - features = ICL_feature_extractor(self.test_eeg, flag_autocorr=True) + features = ICL_feature_extractor(self.test_eeg, flag_autocorr=True) - # Should return 3 features (topo, psd, autocorr) when flag_autocorr=True - self.assertEqual(len(features), 3) + # Should return 3 features (topo, psd, autocorr) when flag_autocorr=True + self.assertEqual(len(features), 3) - # Check topo features - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) - self.assertEqual(topo.dtype, np.float32) + # Check topo features + topo = features[0] + self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) + self.assertEqual(topo.dtype, np.float32) - # Check psd features - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) - self.assertEqual(psd.dtype, np.float32) + # Check psd features + psd = features[1] + self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) + self.assertEqual(psd.dtype, np.float32) - # Check autocorr features - autocorr = features[2] - self.assertEqual(autocorr.ndim, 4) # Should be 4D - self.assertEqual(autocorr.dtype, np.float32) - self.assertEqual(autocorr.shape[3], self.n_components) # Last dimension should be n_components - self.assertTrue(np.all(np.abs(autocorr) <= 0.99)) # Should be scaled by 0.99 - - except Exception as e: - self.skipTest(f"ICL_feature_extractor with autocorr not available: {e}") + # Check autocorr features + autocorr = features[2] + self.assertEqual(autocorr.ndim, 4) # Should be 4D + self.assertEqual(autocorr.dtype, np.float32) + self.assertEqual(autocorr.shape[3], self.n_components) # Last dimension should be n_components + self.assertTrue(np.all(np.abs(autocorr) <= 0.99)) # Should be scaled by 0.99 class TestICLFeatureExtractorDataTypes(unittest.TestCase): @@ -201,32 +185,24 @@ def test_icl_feature_extractor_float32_data(self): EEG = self.base_eeg.copy() EEG['icaact'] = EEG['icaact'].astype(np.float32) - try: - features = ICL_feature_extractor(EEG, flag_autocorr=False) + features = ICL_feature_extractor(EEG, flag_autocorr=False) - # Should work and return float32 features - self.assertEqual(len(features), 2) - for feature in features: - self.assertEqual(feature.dtype, np.float32) - - except Exception as e: - self.skipTest(f"ICL_feature_extractor float32 test not available: {e}") + # Should work and return float32 features + self.assertEqual(len(features), 2) + for feature in features: + self.assertEqual(feature.dtype, np.float32) def test_icl_feature_extractor_float64_data(self): """Test ICL_feature_extractor with float64 input data.""" EEG = self.base_eeg.copy() EEG['icaact'] = EEG['icaact'].astype(np.float64) - try: - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work and return float32 features (converted internally) - self.assertEqual(len(features), 2) - for feature in features: - self.assertEqual(feature.dtype, np.float32) + features = ICL_feature_extractor(EEG, flag_autocorr=False) - except Exception as e: - self.skipTest(f"ICL_feature_extractor float64 test not available: {e}") + # Should work and return float32 features (converted internally) + self.assertEqual(len(features), 2) + for feature in features: + self.assertEqual(feature.dtype, np.float32) class TestICLFeatureExtractorEdgeCases(unittest.TestCase): @@ -268,21 +244,17 @@ def setUp(self): def test_icl_feature_extractor_small_eeg_data(self): """Test ICL_feature_extractor with small EEG data.""" - try: - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - - # Should work with small data - self.assertEqual(len(features), 2) + features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - # Check feature dimensions - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) + # Should work with small data + self.assertEqual(len(features), 2) - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) + # Check feature dimensions + topo = features[0] + self.assertEqual(topo.shape, (32, 32, 1, self.n_components)) - except Exception as e: - self.skipTest(f"ICL_feature_extractor small EEG test not available: {e}") + psd = features[1] + self.assertEqual(psd.shape, (1, 100, 1, self.n_components)) def test_icl_feature_extractor_single_component(self): """Test ICL_feature_extractor with single ICA component.""" @@ -291,21 +263,17 @@ def test_icl_feature_extractor_single_component(self): EEG['icaweights'] = EEG['icaweights'][:1, :] # Keep only first component EEG['icaact'] = EEG['icaact'][:1, :, :] # Keep only first component - try: - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work with single component - self.assertEqual(len(features), 2) + features = ICL_feature_extractor(EEG, flag_autocorr=False) - # Check feature dimensions - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, 1)) # 1 component + # Should work with single component + self.assertEqual(len(features), 2) - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, 1)) # 1 component + # Check feature dimensions + topo = features[0] + self.assertEqual(topo.shape, (32, 32, 1, 1)) # 1 component - except Exception as e: - self.skipTest(f"ICL_feature_extractor single component test not available: {e}") + psd = features[1] + self.assertEqual(psd.shape, (1, 100, 1, 1)) # 1 component def test_icl_feature_extractor_many_components(self): """Test ICL_feature_extractor with many ICA components.""" @@ -317,21 +285,17 @@ def test_icl_feature_extractor_many_components(self): EEG['icaweights'] = np.linalg.pinv(EEG['icawinv']) EEG['icaact'] = np.random.randn(n_many_components, self.n_samples, 1) * 0.5 - try: - features = ICL_feature_extractor(EEG, flag_autocorr=False) + features = ICL_feature_extractor(EEG, flag_autocorr=False) - # Should work with many components - self.assertEqual(len(features), 2) + # Should work with many components + self.assertEqual(len(features), 2) - # Check feature dimensions - topo = features[0] - self.assertEqual(topo.shape, (32, 32, 1, n_many_components)) + # Check feature dimensions + topo = features[0] + self.assertEqual(topo.shape, (32, 32, 1, n_many_components)) - psd = features[1] - self.assertEqual(psd.shape, (1, 100, 1, n_many_components)) - - except Exception as e: - self.skipTest(f"ICL_feature_extractor many components test not available: {e}") + psd = features[1] + self.assertEqual(psd.shape, (1, 100, 1, n_many_components)) def test_icl_feature_extractor_very_short_data(self): """Test ICL_feature_extractor with short data (minimum for 100 freq bins).""" @@ -344,21 +308,17 @@ def test_icl_feature_extractor_very_short_data(self): EEG['pnts'] = short_samples EEG['xmax'] = short_samples / self.srate - try: - features = ICL_feature_extractor(EEG, flag_autocorr=False) - - # Should work with very short data - self.assertEqual(len(features), 2) + features = ICL_feature_extractor(EEG, flag_autocorr=False) - # Features should still have expected shapes - topo = features[0] - self.assertEqual(topo.shape[0:3], (32, 32, 1)) + # Should work with very short data + self.assertEqual(len(features), 2) - psd = features[1] - self.assertEqual(psd.shape[0:3], (1, 100, 1)) + # Features should still have expected shapes + topo = features[0] + self.assertEqual(topo.shape[0:3], (32, 32, 1)) - except Exception as e: - self.skipTest(f"ICL_feature_extractor very short data test not available: {e}") + psd = features[1] + self.assertEqual(psd.shape[0:3], (1, 100, 1)) def test_icl_feature_extractor_autocorr_path_selection(self): """Test ICL_feature_extractor autocorr path selection based on data length.""" @@ -371,11 +331,8 @@ def test_icl_feature_extractor_autocorr_path_selection(self): short_eeg['xmax'] = 3.0 short_eeg['times'] = np.arange(short_pnts) / self.srate - try: - features = ICL_feature_extractor(short_eeg, flag_autocorr=True) - self.assertEqual(len(features), 3) # Should include autocorr - except Exception as e: - self.skipTest(f"ICL_feature_extractor short data autocorr test not available: {e}") + features = ICL_feature_extractor(short_eeg, flag_autocorr=True) + self.assertEqual(len(features), 3) # Should include autocorr # Test long data (> 5 seconds) - should use eeg_autocorr_welch long_eeg = self.base_eeg.copy() @@ -386,11 +343,8 @@ def test_icl_feature_extractor_autocorr_path_selection(self): long_eeg['xmax'] = 6.0 long_eeg['times'] = np.arange(long_pnts) / self.srate - try: - features = ICL_feature_extractor(long_eeg, flag_autocorr=True) - self.assertEqual(len(features), 3) # Should include autocorr - except Exception as e: - self.skipTest(f"ICL_feature_extractor long data autocorr test not available: {e}") + features = ICL_feature_extractor(long_eeg, flag_autocorr=True) + self.assertEqual(len(features), 3) # Should include autocorr def test_icl_feature_extractor_multi_trial_data(self): """Test ICL_feature_extractor with multi-trial data.""" @@ -402,18 +356,14 @@ def test_icl_feature_extractor_multi_trial_data(self): EEG['icaact'] = np.random.randn(self.n_components, self.n_samples, n_trials) * 0.5 EEG['data'] = np.random.randn(self.n_channels, self.n_samples, n_trials) * 0.5 - try: - features = ICL_feature_extractor(EEG, flag_autocorr=True) - - # Should work with multi-trial data and use eeg_autocorr_fftw - self.assertEqual(len(features), 3) + features = ICL_feature_extractor(EEG, flag_autocorr=True) - # Check that features have correct component dimension - for feature in features: - self.assertEqual(feature.shape[3], self.n_components) + # Should work with multi-trial data and use eeg_autocorr_fftw + self.assertEqual(len(features), 3) - except Exception as e: - self.skipTest(f"ICL_feature_extractor multi-trial test not available: {e}") + # Check that features have correct component dimension + for feature in features: + self.assertEqual(feature.shape[3], self.n_components) class TestICLFeatureExtractorValidation(unittest.TestCase): @@ -455,49 +405,37 @@ def setUp(self): def test_icl_feature_extractor_no_inf_nan_in_features(self): """Test ICL_feature_extractor produces no inf/nan values in features.""" - try: - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - - # Check that no features contain inf or nan - for i, feature in enumerate(features): - self.assertTrue(np.all(np.isfinite(feature)), f"Feature {i} contains inf or nan values") + features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - except Exception as e: - self.skipTest(f"ICL_feature_extractor inf/nan test not available: {e}") + # Check that no features contain inf or nan + for i, feature in enumerate(features): + self.assertTrue(np.all(np.isfinite(feature)), f"Feature {i} contains inf or nan values") def test_icl_feature_extractor_deterministic_seed(self): """Test ICL_feature_extractor produces consistent results with same seed.""" - try: - # Set seed and extract features - np.random.seed(123) - features1 = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) + # Set seed and extract features + np.random.seed(123) + features1 = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - # Reset seed and extract features again - np.random.seed(123) - features2 = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) + # Reset seed and extract features again + np.random.seed(123) + features2 = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - # Results should be identical (at least for the deterministic parts) - # Note: Some randomness may come from internal functions, so we check structure - self.assertEqual(len(features1), len(features2)) - for i in range(len(features1)): - self.assertEqual(features1[i].shape, features2[i].shape) - self.assertEqual(features1[i].dtype, features2[i].dtype) - - except Exception as e: - self.skipTest(f"ICL_feature_extractor deterministic test not available: {e}") + # Results should be identical (at least for the deterministic parts) + # Note: Some randomness may come from internal functions, so we check structure + self.assertEqual(len(features1), len(features2)) + for i in range(len(features1)): + self.assertEqual(features1[i].shape, features2[i].shape) + self.assertEqual(features1[i].dtype, features2[i].dtype) def test_icl_feature_extractor_ref_not_average(self): """Test ICL_feature_extractor with reference not set to average.""" EEG = self.base_eeg.copy() EEG['ref'] = 'Cz' # Not average reference - try: - # Should still work (function re-references internally) - features = ICL_feature_extractor(EEG, flag_autocorr=False) - self.assertEqual(len(features), 2) - - except Exception as e: - self.skipTest(f"ICL_feature_extractor non-average ref test not available: {e}") + # Should still work (function re-references internally) + features = ICL_feature_extractor(EEG, flag_autocorr=False) + self.assertEqual(len(features), 2) def test_icl_feature_extractor_mismatched_icachansind(self): """Test ICL_feature_extractor with mismatched icachansind.""" @@ -514,33 +452,62 @@ def test_icl_feature_extractor_mismatched_icachansind(self): except (ValueError, IndexError): # Expected behavior for mismatched indices pass - except Exception as e: - self.skipTest(f"ICL_feature_extractor mismatched icachansind test not available: {e}") def test_icl_feature_extractor_feature_scaling(self): """Test ICL_feature_extractor feature scaling (should be scaled by 0.99).""" - try: - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=True) + features = ICL_feature_extractor(self.base_eeg, flag_autocorr=True) - # All features should be scaled by 0.99 (max absolute value <= 0.99) - for i, feature in enumerate(features): - max_abs_val = np.max(np.abs(feature)) - self.assertLessEqual(max_abs_val, 0.99 + 1e-6, f"Feature {i} not properly scaled by 0.99") - - except Exception as e: - self.skipTest(f"ICL_feature_extractor scaling test not available: {e}") + # All features should be scaled by 0.99 (max absolute value <= 0.99) + for i, feature in enumerate(features): + max_abs_val = np.max(np.abs(feature)) + self.assertLessEqual(max_abs_val, 0.99 + 1e-6, f"Feature {i} not properly scaled by 0.99") def test_icl_feature_extractor_psd_length_extrapolation(self): """Test ICL_feature_extractor PSD length handling (should be 100 frequencies).""" - try: - features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) + features = ICL_feature_extractor(self.base_eeg, flag_autocorr=False) - # PSD should always have 100 frequency bins (extrapolated if needed) - psd = features[1] - self.assertEqual(psd.shape[1], 100, "PSD should have exactly 100 frequency bins") + # PSD should always have 100 frequency bins (extrapolated if needed) + psd = features[1] + self.assertEqual(psd.shape[1], 100, "PSD should have exactly 100 frequency bins") - except Exception as e: - self.skipTest(f"ICL_feature_extractor PSD extrapolation test not available: {e}") + +class TestEegRpsdGlobalRng(unittest.TestCase): + """Regression tests that eeg_rpsd never mutates the global numpy RNG.""" + + def setUp(self): + np.random.seed(42) + n_channels = 16 + n_components = 4 + n_samples = 500 + srate = 250.0 + eeg = create_test_eeg(n_channels=n_channels, n_samples=n_samples, srate=srate, n_trials=1) + eeg['icawinv'] = np.random.randn(n_channels, n_components) * 0.5 + eeg['icaweights'] = np.linalg.pinv(eeg['icawinv']) + eeg['icasphere'] = np.eye(n_channels) + eeg['icaact'] = np.random.randn(n_components, n_samples, 1) * 0.5 + eeg['icachansind'] = np.arange(n_channels) + self.eeg = eeg + + def test_does_not_mutate_global_rng(self): + """eeg_rpsd must use a local RNG, leaving np.random's global state intact.""" + np.random.seed(123) + before = np.random.get_state() + eeg_rpsd(self.eeg) + after = np.random.get_state() + + self.assertEqual(before[0], after[0]) + self.assertTrue(np.array_equal(before[1], after[1])) + self.assertEqual(before[2:], after[2:]) + + def test_output_is_deterministic(self): + """eeg_rpsd must return identical output regardless of global RNG state.""" + np.random.seed(1) + psd_a = eeg_rpsd(self.eeg) + np.random.seed(999) + np.random.rand(37) # perturb the global RNG between calls + psd_b = eeg_rpsd(self.eeg) + + self.assertTrue(np.array_equal(psd_a, psd_b)) class TestICLFeatureExtractorParity(unittest.TestCase): diff --git a/tests/test_bids_gen_derived_fpath.py b/tests/test_bids_gen_derived_fpath.py new file mode 100644 index 00000000..47468115 --- /dev/null +++ b/tests/test_bids_gen_derived_fpath.py @@ -0,0 +1,52 @@ +"""Unit tests for gen_derived_fpath path construction (no MATLAB/pybids needed).""" + +import os +import unittest + +from eegprep.plugins.EEG_BIDS.bids import gen_derived_fpath + + +def _raw_fpath(): + # A BIDS-style raw EEG file path inside a dataset rooted at . + return os.path.join(os.sep, 'data', 'ds001', 'sub-01', 'eeg', 'sub-01_task-rest_eeg.set') + + +class TestGenDerivedFpath(unittest.TestCase): + def test_default_root_placeholder_substituted(self): + """The default outputdir's {root} placeholder is replaced with the dataset root. + + Reproduces the bug where the documented default '${root}/...' left a literal + placeholder ('$/data/...') in the output path. + """ + out = gen_derived_fpath(_raw_fpath(), keyword='desc-cleaned') + expected = os.path.join( + os.sep, + 'data', + 'ds001', + 'derivatives', + 'clean_artifacts', + 'sub-01', + 'eeg', + 'sub-01_task-rest_desc-cleaned_eeg.set', + ) + self.assertEqual(out, expected) + self.assertNotIn('{root}', out) + self.assertNotIn('$', out) + + def test_explicit_root_placeholder_substituted(self): + """An explicit '{root}/...' outputdir is substituted with the dataset root.""" + out = gen_derived_fpath(_raw_fpath(), outputdir='{root}/derivatives/eegprep') + expected = os.path.join( + os.sep, 'data', 'ds001', 'derivatives', 'eegprep', 'sub-01', 'eeg', 'sub-01_task-rest_eeg.set' + ) + self.assertEqual(out, expected) + + def test_path_assembly_uses_os_sep(self): + """The assembled path uses the OS separator throughout (no hardcoded '/').""" + out = gen_derived_fpath(_raw_fpath(), keyword='desc-cleaned') + # Every separator must be the platform separator produced by os.path.join. + self.assertEqual(out, os.path.normpath(out)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_bids_preproc.py b/tests/test_bids_preproc.py index e550e903..3dbf325b 100644 --- a/tests/test_bids_preproc.py +++ b/tests/test_bids_preproc.py @@ -9,13 +9,11 @@ import numpy as np +from eegprep.plugins.EEG_BIDS.coords import coords_to_mm from eegprep.utils.testing import DebuggableTestCase logger = logging.getLogger(__name__) -if os.getenv('EEGPREP_SKIP_MATLAB') == '1': - raise unittest.SkipTest("MATLAB not available") - curhost = socket.gethostname() # add your host to this list if you want to run the (very) slow tests @@ -74,7 +72,7 @@ def test_end2end(self): """End-to-end test vs MATLAB.""" from eegprep import bids_preproc, pop_loadset, eeg_checkset_strict_mode from eegprep.functions.adminfunc.eeglabcompat import get_eeglab - from eegprep.plugins.EEG_BIDS.stage_comparison import generate_comparison_table, save_comparison_report + from tools.eeg_bids.stage_comparison import generate_comparison_table, save_comparison_report from datetime import datetime for study in self.studies: @@ -233,3 +231,24 @@ def test_crashability_slow(self): EpochBaseline=[None, 0], MinimizeDiskUsage=False, ) + + +class TestCoordsToMm(unittest.TestCase): + """Regression tests that coords_to_mm never mutates the caller's array.""" + + def test_does_not_mutate_input(self): + for unit, factor in (('mm', 1.0), ('cm', 10.0), ('m', 1000.0)): + with self.subTest(unit=unit): + coords = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + original = coords.copy() + + out = coords_to_mm(coords, unit) + + # Caller's array is unchanged and the result is a distinct array. + self.assertTrue(np.array_equal(original, coords)) + self.assertIsNot(out, coords) + np.testing.assert_array_equal(out, original * factor) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_clean_artifacts.py b/tests/test_clean_artifacts.py index 57add3e3..89d725fc 100644 --- a/tests/test_clean_artifacts.py +++ b/tests/test_clean_artifacts.py @@ -12,64 +12,15 @@ # Add src to path for imports sys.path.insert(0, 'src') from eegprep.plugins.clean_rawdata.clean_artifacts import clean_artifacts +from eegprep.plugins.clean_rawdata.pop_clean_rawdata import pop_clean_rawdata from eegprep.utils.testing import DebuggableTestCase +from tests.fixtures import create_test_eeg as _create_test_eeg + def create_test_eeg(): - """Create a complete test EEG structure with all required fields. - - Note: clean_artifacts expects continuous (2D) data, not epoched (3D) data. - """ - n_pnts = 10000 # 20 seconds at 500 Hz - return { - 'data': np.random.randn(32, n_pnts), # 2D continuous data - 'srate': 500.0, - 'nbchan': 32, - 'pnts': n_pnts, - 'trials': 1, - 'xmin': 0.0, - 'xmax': n_pnts / 500.0, - 'times': np.linspace(0, n_pnts / 500.0, n_pnts), - 'icaact': [], - 'icawinv': [], - 'icasphere': [], - 'icaweights': [], - 'icachansind': [], - 'chanlocs': [ - { - 'labels': f'EEG{i:03d}', - 'type': 'EEG', - 'theta': np.random.uniform(-90, 90), - 'radius': np.random.uniform(0, 1), - 'X': np.random.uniform(-1, 1), - 'Y': np.random.uniform(-1, 1), - 'Z': np.random.uniform(-1, 1), - 'sph_theta': np.random.uniform(-180, 180), - 'sph_phi': np.random.uniform(-90, 90), - 'sph_radius': np.random.uniform(0, 1), - 'urchan': i + 1, - 'ref': '', - } - for i in range(32) - ], - 'urchanlocs': [], - 'chaninfo': {'removedchans': []}, - 'ref': 'common', - 'history': '', - 'saved': 'yes', - 'etc': {}, - 'event': [], - 'epoch': [], - 'setname': 'test_dataset', - 'filename': 'test.set', - 'filepath': '/tmp', - 'specdata': [], - 'specicaact': [], - 'reject': [], - 'stats': [], - 'dipfit': [], - 'roi': [], - } + """Continuous (2D) EEG fixture sized for clean_artifacts (20 s at 500 Hz).""" + return _create_test_eeg(n_channels=32, n_samples=10000, srate=500.0, n_trials=1) class TestCleanArtifactsBasic(DebuggableTestCase): @@ -81,49 +32,44 @@ def setUp(self): def test_clean_artifacts_basic_functionality(self): """Test basic clean_artifacts functionality with default parameters.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts(self.test_eeg) - - # Check that all return values are present - self.assertIsInstance(EEG, dict) - self.assertIsInstance(HP, dict) - self.assertIsInstance(BUR, dict) - self.assertIsInstance(removed_channels, np.ndarray) + EEG, HP, BUR, removed_channels = clean_artifacts(self.test_eeg) - # Check that EEG structure is preserved - self.assertIn('data', EEG) - self.assertIn('srate', EEG) - self.assertIn('nbchan', EEG) - self.assertIn('pnts', EEG) + # Check that all return values are present + self.assertIsInstance(EEG, dict) + self.assertIsInstance(HP, dict) + self.assertIsInstance(BUR, dict) + self.assertIsInstance(removed_channels, np.ndarray) - # Check that data dimensions are reasonable - self.assertEqual(EEG['srate'], self.test_eeg['srate']) - self.assertGreaterEqual(EEG['nbchan'], 1) # At least one channel should remain - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) + # Check that EEG structure is preserved + self.assertIn('data', EEG) + self.assertIn('srate', EEG) + self.assertIn('nbchan', EEG) + self.assertIn('pnts', EEG) - except Exception as e: - self.skipTest(f"clean_artifacts basic functionality not available: {e}") + # Check that data dimensions are reasonable + self.assertEqual(EEG['srate'], self.test_eeg['srate']) + self.assertGreaterEqual(EEG['nbchan'], 1) # At least one channel should remain + self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) def test_clean_artifacts_all_off(self): """Test clean_artifacts with all criteria disabled.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # With all criteria off, data should be unchanged - self.assertEqual(EEG['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(EEG['pnts'], self.test_eeg['pnts']) - np.testing.assert_array_equal(EEG['data'], self.test_eeg['data']) - - except Exception as e: - self.skipTest(f"clean_artifacts all off not available: {e}") + self.test_eeg.pop('etc') + original_keys = set(self.test_eeg) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # With all criteria off, data should be unchanged + self.assertEqual(EEG['nbchan'], self.test_eeg['nbchan']) + self.assertEqual(EEG['pnts'], self.test_eeg['pnts']) + np.testing.assert_array_equal(EEG['data'], self.test_eeg['data']) + self.assertEqual(set(self.test_eeg), original_keys) def test_clean_artifacts_invalid_highpass_string(self): """Test clean_artifacts with invalid highpass string parameter.""" @@ -157,91 +103,79 @@ def test_clean_artifacts_invalid_highpass_list_single(self): def test_clean_artifacts_valid_highpass_list(self): """Test clean_artifacts with valid highpass list (should work like tuple).""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Highpass=[0.25, 0.75], # List instead of tuple - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - FlatlineCriterion='off', - ) - # Should work - list is acceptable - self.assertIsInstance(EEG, dict) - except Exception as e: - self.skipTest(f"clean_artifacts valid highpass list not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Highpass=[0.25, 0.75], # List instead of tuple + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + FlatlineCriterion='off', + ) + # Should work - list is acceptable + self.assertIsInstance(EEG, dict) def test_clean_artifacts_mutually_exclusive_channels(self): """Test clean_artifacts with mutually exclusive channel parameters.""" with self.assertRaises(ValueError) as cm: - clean_artifacts(self.test_eeg, Channels=['EEG001', 'EEG002'], Channels_ignore=['EEG003']) + clean_artifacts(self.test_eeg, Channels=['Ch1', 'Ch2'], Channels_ignore=['Ch3']) self.assertIn('mutually exclusive', str(cm.exception)) def test_clean_artifacts_mutually_exclusive_channels_both_empty(self): """Test clean_artifacts with both channel parameters empty (should work).""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=[], # Empty list - Channels_ignore=[], # Empty list - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - # Should work - empty lists are not mutually exclusive - self.assertIsInstance(EEG, dict) - except Exception as e: - self.skipTest(f"clean_artifacts empty channel lists not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Channels=[], # Empty list + Channels_ignore=[], # Empty list + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + # Should work - empty lists are not mutually exclusive + self.assertIsInstance(EEG, dict) def test_clean_artifacts_mutually_exclusive_channels_none_and_list(self): """Test clean_artifacts with None and non-empty list (should work).""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=None, # None - Channels_ignore=['EEG001'], # Non-empty list - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - # Should work - None and list is not mutually exclusive - self.assertIsInstance(EEG, dict) - except Exception as e: - self.skipTest(f"clean_artifacts None and channel list not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Channels=None, # None + Channels_ignore=['Ch1'], # Non-empty list + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + # Should work - None and list is not mutually exclusive + self.assertIsInstance(EEG, dict) def test_clean_artifacts_mutually_exclusive_channels_both_none(self): """Test clean_artifacts with both channel parameters as None (should work).""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=None, # None - Channels_ignore=None, # None - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - # Should work - both None is not mutually exclusive - self.assertIsInstance(EEG, dict) - except Exception as e: - self.skipTest(f"clean_artifacts both None not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Channels=None, # None + Channels_ignore=None, # None + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + # Should work - both None is not mutually exclusive + self.assertIsInstance(EEG, dict) def test_clean_artifacts_mutually_exclusive_channels_overlapping(self): """Test clean_artifacts with overlapping channel lists (error expected).""" with self.assertRaises(ValueError) as cm: clean_artifacts( self.test_eeg, - Channels=['EEG001', 'EEG002', 'EEG003'], - Channels_ignore=['EEG002', 'EEG004'], # EEG002 overlaps + Channels=['Ch1', 'Ch2', 'Ch3'], + Channels_ignore=['Ch2', 'Ch4'], # Ch2 overlaps ) self.assertIn('mutually exclusive', str(cm.exception)) @@ -255,53 +189,45 @@ def setUp(self): def test_clean_artifacts_flatline_removal(self): """Test flatline channel removal.""" - try: - # Create some flatline channels - eeg_with_flatlines = self.test_eeg.copy() - eeg_with_flatlines['data'] = self.test_eeg['data'].copy() - eeg_with_flatlines['data'][5, :] = 0.0 # Flatline channel (2D data) - eeg_with_flatlines['data'][10, :] = 1.0 # Another flatline channel - original_nbchan = eeg_with_flatlines['nbchan'] - - EEG, HP, BUR, removed_channels = clean_artifacts( - eeg_with_flatlines, - FlatlineCriterion=1.0, # Short flatline duration - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - ) - - # Should have removed some channels - self.assertLess(EEG['nbchan'], original_nbchan) - - except Exception as e: - self.skipTest(f"clean_artifacts flatline removal not available: {e}") + # Create some flatline channels + eeg_with_flatlines = self.test_eeg.copy() + eeg_with_flatlines['data'] = self.test_eeg['data'].copy() + eeg_with_flatlines['data'][5, :] = 0.0 # Flatline channel (2D data) + eeg_with_flatlines['data'][10, :] = 1.0 # Another flatline channel + original_nbchan = eeg_with_flatlines['nbchan'] + + EEG, HP, BUR, removed_channels = clean_artifacts( + eeg_with_flatlines, + FlatlineCriterion=1.0, # Short flatline duration + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + ) + + # Should have removed some channels + self.assertLess(EEG['nbchan'], original_nbchan) def test_clean_artifacts_flatline_off(self): """Test flatline removal disabled.""" - try: - # Create some flatline channels - eeg_with_flatlines = self.test_eeg.copy() - eeg_with_flatlines['data'] = self.test_eeg['data'].copy() - eeg_with_flatlines['data'][5, :] = 0.0 # Flatline channel (2D data) - - EEG, HP, BUR, removed_channels = clean_artifacts( - eeg_with_flatlines, - FlatlineCriterion='off', - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - ) + # Create some flatline channels + eeg_with_flatlines = self.test_eeg.copy() + eeg_with_flatlines['data'] = self.test_eeg['data'].copy() + eeg_with_flatlines['data'][5, :] = 0.0 # Flatline channel (2D data) - # Should not have removed any channels - self.assertEqual(EEG['nbchan'], eeg_with_flatlines['nbchan']) + EEG, HP, BUR, removed_channels = clean_artifacts( + eeg_with_flatlines, + FlatlineCriterion='off', + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + ) - except Exception as e: - self.skipTest(f"clean_artifacts flatline off not available: {e}") + # Should not have removed any channels + self.assertEqual(EEG['nbchan'], eeg_with_flatlines['nbchan']) class TestCleanArtifactsHighpass(DebuggableTestCase): @@ -313,47 +239,39 @@ def setUp(self): def test_clean_artifacts_highpass_filtering(self): """Test highpass filtering.""" - try: - original_data = self.test_eeg['data'].copy() # Save before call - - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Highpass=(0.5, 1.0), - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - FlatlineCriterion='off', - ) + original_data = self.test_eeg['data'].copy() # Save before call - # HP should contain the highpass filtered data - self.assertIsInstance(HP, dict) - self.assertIn('data', HP) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Highpass=(0.5, 1.0), + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + FlatlineCriterion='off', + ) - # Data should be different after filtering - self.assertFalse(np.array_equal(HP['data'], original_data)) + # HP should contain the highpass filtered data + self.assertIsInstance(HP, dict) + self.assertIn('data', HP) - except Exception as e: - self.skipTest(f"clean_artifacts highpass filtering not available: {e}") + # Data should be different after filtering + self.assertFalse(np.array_equal(HP['data'], original_data)) def test_clean_artifacts_highpass_off(self): """Test highpass filtering disabled.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Highpass='off', - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - FlatlineCriterion='off', - ) - - # Data should be unchanged - np.testing.assert_array_equal(HP['data'], self.test_eeg['data']) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Highpass='off', + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + FlatlineCriterion='off', + ) - except Exception as e: - self.skipTest(f"clean_artifacts highpass off not available: {e}") + # Data should be unchanged + np.testing.assert_array_equal(HP['data'], self.test_eeg['data']) class TestCleanArtifactsChannelCleaning(DebuggableTestCase): @@ -365,60 +283,48 @@ def setUp(self): def test_clean_artifacts_channel_criterion(self): """Test channel correlation criterion.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion=0.9, # High threshold - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some channels with high threshold - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - except Exception as e: - self.skipTest(f"clean_artifacts channel criterion not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion=0.9, # High threshold + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # Should have removed some channels with high threshold + self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) def test_clean_artifacts_line_noise_criterion(self): """Test line noise criterion.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion=2.0, # Low threshold - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some channels with low threshold - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - except Exception as e: - self.skipTest(f"clean_artifacts line noise criterion not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion=2.0, # Low threshold + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # Should have removed some channels with low threshold + self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) def test_clean_artifacts_both_channel_criteria(self): """Test both channel and line noise criteria.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion=0.8, - LineNoiseCriterion=4.0, - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion=0.8, + LineNoiseCriterion=4.0, + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) - # Should have removed some channels - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - - except Exception as e: - self.skipTest(f"clean_artifacts both channel criteria not available: {e}") + # Should have removed some channels + self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) class TestCleanArtifactsBurstCleaning(DebuggableTestCase): @@ -430,62 +336,50 @@ def setUp(self): def test_clean_artifacts_burst_criterion(self): """Test burst criterion.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion=5.0, - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # BUR should contain the burst repaired data - self.assertIsInstance(BUR, dict) - self.assertIn('data', BUR) - - except Exception as e: - self.skipTest(f"clean_artifacts burst criterion not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion=5.0, + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # BUR should contain the burst repaired data + self.assertIsInstance(BUR, dict) + self.assertIn('data', BUR) def test_clean_artifacts_burst_rejection(self): """Test burst rejection mode.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion=5.0, - BurstRejection='on', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some samples - self.assertLessEqual(EEG['pnts'], self.test_eeg['pnts']) - - except Exception as e: - self.skipTest(f"clean_artifacts burst rejection not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion=5.0, + BurstRejection='on', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # Should have removed some samples + self.assertLessEqual(EEG['pnts'], self.test_eeg['pnts']) def test_clean_artifacts_burst_off(self): """Test burst cleaning disabled.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) - # Data should be unchanged - np.testing.assert_array_equal(BUR['data'], self.test_eeg['data']) - - except Exception as e: - self.skipTest(f"clean_artifacts burst off not available: {e}") + # Data should be unchanged + np.testing.assert_array_equal(BUR['data'], self.test_eeg['data']) class TestCleanArtifactsWindowCleaning(DebuggableTestCase): @@ -497,41 +391,33 @@ def setUp(self): def test_clean_artifacts_window_criterion(self): """Test window criterion.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion=0.5, # Allow 50% bad channels per window - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have removed some samples - self.assertLessEqual(EEG['pnts'], self.test_eeg['pnts']) - - except Exception as e: - self.skipTest(f"clean_artifacts window criterion not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion=0.5, # Allow 50% bad channels per window + Highpass='off', + FlatlineCriterion='off', + ) + + # Should have removed some samples + self.assertLessEqual(EEG['pnts'], self.test_eeg['pnts']) def test_clean_artifacts_window_off(self): """Test window cleaning disabled.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Data should be unchanged - self.assertEqual(EEG['pnts'], self.test_eeg['pnts']) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) - except Exception as e: - self.skipTest(f"clean_artifacts window off not available: {e}") + # Data should be unchanged + self.assertEqual(EEG['pnts'], self.test_eeg['pnts']) class TestCleanArtifactsChannelSelection(DebuggableTestCase): @@ -543,48 +429,40 @@ def setUp(self): def test_clean_artifacts_channels_include(self): """Test channel inclusion.""" - try: - channels_to_include = ['EEG001', 'EEG002', 'EEG003'] - - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels=channels_to_include, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should have only the specified channels - self.assertEqual(EEG['nbchan'], len(channels_to_include)) - - except Exception as e: - self.skipTest(f"clean_artifacts channels include not available: {e}") + channels_to_include = ['Ch1', 'Ch2', 'Ch3'] + + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Channels=channels_to_include, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # Should have only the specified channels + self.assertEqual(EEG['nbchan'], len(channels_to_include)) def test_clean_artifacts_channels_ignore(self): """Test channel exclusion.""" - try: - channels_to_ignore = ['EEG001', 'EEG002'] - original_nbchan = self.test_eeg['nbchan'] # Save before call - - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Channels_ignore=channels_to_ignore, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) + channels_to_ignore = ['Ch1', 'Ch2'] + original_nbchan = self.test_eeg['nbchan'] # Save before call - # Should have fewer channels - self.assertEqual(EEG['nbchan'], original_nbchan - len(channels_to_ignore)) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Channels_ignore=channels_to_ignore, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) - except Exception as e: - self.skipTest(f"clean_artifacts channels ignore not available: {e}") + # Should have fewer channels + self.assertEqual(EEG['nbchan'], original_nbchan - len(channels_to_ignore)) class TestCleanArtifactsParameterValidation(DebuggableTestCase): @@ -597,187 +475,166 @@ def setUp(self): def test_clean_artifacts_invalid_channel_criterion_type(self): """Test clean_artifacts with invalid ChannelCriterion type.""" # Should accept numeric values and 'off' - try: - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion=0.8, - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - except Exception as e: - self.skipTest(f"clean_artifacts channel criterion validation not available: {e}") + # Valid cases + clean_artifacts( + self.test_eeg, + ChannelCriterion=0.8, + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) def test_clean_artifacts_invalid_line_noise_criterion_type(self): """Test clean_artifacts with invalid LineNoiseCriterion type.""" # Should accept numeric values and 'off' - try: - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion=4.0, - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - except Exception as e: - self.skipTest(f"clean_artifacts line noise criterion validation not available: {e}") + # Valid cases + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion=4.0, + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) def test_clean_artifacts_invalid_burst_criterion_type(self): """Test clean_artifacts with invalid BurstCriterion type.""" # Should accept numeric values and 'off' - try: - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion=5.0, - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - except Exception as e: - self.skipTest(f"clean_artifacts burst criterion validation not available: {e}") + # Valid cases + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion=5.0, + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) def test_clean_artifacts_invalid_window_criterion_type(self): """Test clean_artifacts with invalid WindowCriterion type.""" # Should accept numeric values and 'off' - try: - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion=0.25, - Highpass='off', - FlatlineCriterion='off', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - except Exception as e: - self.skipTest(f"clean_artifacts window criterion validation not available: {e}") + # Valid cases + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion=0.25, + Highpass='off', + FlatlineCriterion='off', + ) + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) def test_clean_artifacts_invalid_flatline_criterion_type(self): """Test clean_artifacts with invalid FlatlineCriterion type.""" # Should accept numeric values and 'off' - try: - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion=5.0, - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - except Exception as e: - self.skipTest(f"clean_artifacts flatline criterion validation not available: {e}") + # Valid cases + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion=5.0, + ) + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) def test_clean_artifacts_invalid_burst_rejection_type(self): """Test clean_artifacts with invalid BurstRejection type.""" # Should accept 'on' and 'off' strings - try: - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - BurstRejection='on', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - BurstRejection='off', - ) - except Exception as e: - self.skipTest(f"clean_artifacts burst rejection validation not available: {e}") + # Valid cases + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + BurstRejection='on', + ) + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + BurstRejection='off', + ) def test_clean_artifacts_documented_distance_metrics_with_asr_disabled(self): """Test clean_artifacts accepts documented Distance spellings when ASR is disabled.""" - try: - # Valid cases - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - Distance='euclidian', - ) - clean_artifacts( - self.test_eeg, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - Distance='riemannian', - ) - except Exception as e: - self.skipTest(f"clean_artifacts distance metric validation not available: {e}") + # Valid cases + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + Distance='euclidian', + ) + clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + Distance='riemannian', + ) def test_clean_artifacts_rejects_unknown_distance_metric(self): """Test clean_artifacts rejects unknown Distance spellings before cleaning.""" @@ -813,33 +670,27 @@ def test_clean_artifacts_negative_values(self): def test_clean_artifacts_zero_values(self): """Test clean_artifacts with zero parameter values.""" - try: - clean_artifacts( - self.test_eeg, - ChannelCriterion=0.0, # Zero correlation threshold - LineNoiseCriterion=0.0, - BurstCriterion='off', - WindowCriterion=0.0, - Highpass='off', - FlatlineCriterion=0.0, - ) - except Exception as e: - self.skipTest(f"clean_artifacts zero values not available: {e}") + clean_artifacts( + self.test_eeg, + ChannelCriterion=0.0, # Zero correlation threshold + LineNoiseCriterion=0.0, + BurstCriterion='off', + WindowCriterion=0.0, + Highpass='off', + FlatlineCriterion=0.0, + ) def test_clean_artifacts_extreme_values(self): """Test clean_artifacts with extreme parameter values.""" - try: - clean_artifacts( - self.test_eeg, - ChannelCriterion=1.0, # Perfect correlation required - LineNoiseCriterion=100.0, - BurstCriterion='off', - WindowCriterion=1.0, - Highpass='off', - FlatlineCriterion=1000.0, - ) - except Exception as e: - self.skipTest(f"clean_artifacts extreme values not available: {e}") + clean_artifacts( + self.test_eeg, + ChannelCriterion=1.0, # Perfect correlation required + LineNoiseCriterion=100.0, + BurstCriterion='off', + WindowCriterion=1.0, + Highpass='off', + FlatlineCriterion=1000.0, + ) class TestCleanArtifactsParameters(DebuggableTestCase): @@ -851,63 +702,51 @@ def setUp(self): def test_clean_artifacts_available_ram(self): """Test available RAM parameter.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - availableRAM_GB=2.0, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should complete without error - self.assertIsInstance(EEG, dict) - - except Exception as e: - self.skipTest(f"clean_artifacts available RAM not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + availableRAM_GB=2.0, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # Should complete without error + self.assertIsInstance(EEG, dict) def test_clean_artifacts_distance_metric(self): """Test distance metric parameter.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - Distance='euclidian', - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should complete without error - self.assertIsInstance(EEG, dict) - - except Exception as e: - self.skipTest(f"clean_artifacts distance metric not available: {e}") + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + Distance='euclidian', + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # Should complete without error + self.assertIsInstance(EEG, dict) def test_clean_artifacts_max_mem(self): """Test max memory parameter.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - MaxMem=128, - ChannelCriterion='off', - LineNoiseCriterion='off', - BurstCriterion='off', - WindowCriterion='off', - Highpass='off', - FlatlineCriterion='off', - ) - - # Should complete without error - self.assertIsInstance(EEG, dict) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + MaxMem=128, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) - except Exception as e: - self.skipTest(f"clean_artifacts max memory not available: {e}") + # Should complete without error + self.assertIsInstance(EEG, dict) class TestCleanArtifactsIntegration(DebuggableTestCase): @@ -919,73 +758,235 @@ def setUp(self): def test_clean_artifacts_full_pipeline(self): """Test the full clean_artifacts pipeline.""" - try: - EEG, HP, BUR, removed_channels = clean_artifacts( - self.test_eeg, - FlatlineCriterion=5.0, - Highpass=(0.25, 0.75), - ChannelCriterion=0.8, - LineNoiseCriterion=4.0, - BurstCriterion=5.0, - WindowCriterion=0.25, - ) + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + FlatlineCriterion=5.0, + Highpass=(0.25, 0.75), + ChannelCriterion=0.8, + LineNoiseCriterion=4.0, + BurstCriterion=5.0, + WindowCriterion=0.25, + ) + + # Check all return values + self.assertIsInstance(EEG, dict) + self.assertIsInstance(HP, dict) + self.assertIsInstance(BUR, dict) + self.assertIsInstance(removed_channels, np.ndarray) + + # Check data integrity + self.assertIn('data', EEG) + self.assertIn('srate', EEG) + self.assertIn('nbchan', EEG) + self.assertIn('pnts', EEG) + + # Check that some processing occurred + self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) + + def test_clean_artifacts_return_values(self): + """Test that all return values have correct structure.""" + EEG, HP, BUR, removed_channels = clean_artifacts( + self.test_eeg, + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + # Check EEG structure + self.assertIn('data', EEG) + self.assertIn('srate', EEG) + self.assertIn('nbchan', EEG) + self.assertIn('pnts', EEG) + self.assertIn('etc', EEG) + + # Check HP structure (should be same as EEG when no highpass) + self.assertIn('data', HP) + self.assertIn('srate', HP) + self.assertIn('nbchan', HP) + self.assertIn('pnts', HP) + + # Check BUR structure (should be same as EEG when no burst cleaning) + self.assertIn('data', BUR) + self.assertIn('srate', BUR) + self.assertIn('nbchan', BUR) + self.assertIn('pnts', BUR) + + # Check removed_channels array + self.assertEqual(len(removed_channels), self.test_eeg['nbchan']) + self.assertTrue(np.issubdtype(removed_channels.dtype, np.bool_)) + + +class TestCleanArtifactsHpSnapshot(DebuggableTestCase): + """Regression test for the high-pass snapshot point-in-time contract.""" + + def setUp(self): + np.random.seed(11) + self.test_eeg = create_test_eeg() - # Check all return values - self.assertIsInstance(EEG, dict) - self.assertIsInstance(HP, dict) - self.assertIsInstance(BUR, dict) - self.assertIsInstance(removed_channels, np.ndarray) + def test_hp_snapshot_is_point_in_time(self): + """HP must not carry the sample mask written by the later window stage.""" + EEG, HP, _BUR, _removed = clean_artifacts( + self.test_eeg, + Highpass='off', + ChannelCriterion=0.8, + LineNoiseCriterion=4.0, + BurstCriterion='off', + WindowCriterion=0.25, + ) - # Check data integrity - self.assertIn('data', EEG) - self.assertIn('srate', EEG) - self.assertIn('nbchan', EEG) - self.assertIn('pnts', EEG) + # The window stage populates clean_sample_mask on the final EEG dataset... + self.assertIn('clean_sample_mask', EEG['etc']) + # ...but the high-pass snapshot predates that stage, so it must not share + # the same etc object or carry the later mask. + self.assertIsNot(HP['etc'], EEG['etc']) + self.assertNotIn('clean_sample_mask', HP['etc']) - # Check that some processing occurred - self.assertLessEqual(EEG['nbchan'], self.test_eeg['nbchan']) - except Exception as e: - self.skipTest(f"clean_artifacts full pipeline not available: {e}") +class TestPopCleanRawdataNoMutation(DebuggableTestCase): + """Regression test that pop_clean_rawdata never mutates the caller's EEG.""" - def test_clean_artifacts_return_values(self): - """Test that all return values have correct structure.""" + def setUp(self): + np.random.seed(13) + self.test_eeg = create_test_eeg() + + def test_does_not_mutate_input(self): + """The wrapper must deep-copy so the caller's dataset is untouched.""" + EEG_in = self.test_eeg + original_data = EEG_in['data'].copy() + original_nbchan = EEG_in['nbchan'] + + cleaned = pop_clean_rawdata( + EEG_in, + gui=False, + ChannelCriterion=0.8, + LineNoiseCriterion=4.0, + BurstCriterion='off', + WindowCriterion=0.25, + ) + + # Caller's data, channel count, and etc are all unchanged. + self.assertTrue(np.array_equal(original_data, EEG_in['data'])) + self.assertEqual(EEG_in['nbchan'], original_nbchan) + self.assertNotIn('clean_channel_mask', EEG_in.get('etc', {})) + self.assertNotIn('clean_sample_mask', EEG_in.get('etc', {})) + # The returned dataset is a distinct object. + self.assertIsNot(cleaned, EEG_in) + + +class TestCleanArtifactsErrorSurfacing(DebuggableTestCase): + """Errors inside the selection / channel-cleaning paths must surface, not be masked.""" + + def setUp(self): + self.test_eeg = create_test_eeg() + + def test_pop_select_internal_error_propagates(self): + """A non-ImportError raised inside pop_select must propagate, not silently + fall back to manual label selection (which could select different channels). + """ + import eegprep + + original = eegprep.pop_select + + def failing_pop_select(*args, **kwargs): + raise RuntimeError("simulated pop_select bug") + + eegprep.pop_select = failing_pop_select + try: + with self.assertRaises(RuntimeError): + clean_artifacts( + self.test_eeg, + Channels_ignore=['EEG001', 'EEG002'], + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + finally: + eegprep.pop_select = original + + def test_channels_ignore_preserves_events(self): + """Restricting channels must not wipe the dataset's events.""" + eeg = create_test_eeg() + eeg['event'] = [{'type': 'mark', 'latency': 100.0}, {'type': 'mark', 'latency': 5000.0}] + original_events = list(eeg['event']) + + EEG, _HP, _BUR, _removed = clean_artifacts( + eeg, + Channels_ignore=['EEG001'], + ChannelCriterion='off', + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + + self.assertEqual(len(EEG['event']), len(original_events)) + + def test_clean_channels_unexpected_value_error_propagates(self): + """A ValueError from clean_channels unrelated to missing locations must + propagate rather than silently switching to the no-locs algorithm. + """ + import eegprep.plugins.clean_rawdata.clean_artifacts as ca_mod + + original = ca_mod.clean_channels + + def boom(*args, **kwargs): + raise ValueError("totally unrelated bug") + + ca_mod.clean_channels = boom + try: + with self.assertRaises(ValueError) as cm: + clean_artifacts( + self.test_eeg, + ChannelCriterion=0.8, + LineNoiseCriterion='off', + BurstCriterion='off', + WindowCriterion='off', + Highpass='off', + FlatlineCriterion='off', + ) + self.assertIn('totally unrelated bug', str(cm.exception)) + finally: + ca_mod.clean_channels = original + + def test_clean_channels_location_error_falls_back(self): + """A missing-locations ValueError still triggers the no-locs fallback.""" + import eegprep.plugins.clean_rawdata.clean_artifacts as ca_mod + + original_cc = ca_mod.clean_channels + original_nolocs = ca_mod.clean_channels_nolocs + + def locs_error(*args, **kwargs): + raise ValueError('To use this function most of your channels should have X,Y,Z location measurements.') + + called = {'nolocs': False} + + def fake_nolocs(EEG, **kwargs): + called['nolocs'] = True + return EEG, np.zeros(EEG['nbchan'], dtype=bool) + + ca_mod.clean_channels = locs_error + ca_mod.clean_channels_nolocs = fake_nolocs try: - EEG, HP, BUR, removed_channels = clean_artifacts( + clean_artifacts( self.test_eeg, - ChannelCriterion='off', + ChannelCriterion=0.8, LineNoiseCriterion='off', BurstCriterion='off', WindowCriterion='off', Highpass='off', FlatlineCriterion='off', ) - - # Check EEG structure - self.assertIn('data', EEG) - self.assertIn('srate', EEG) - self.assertIn('nbchan', EEG) - self.assertIn('pnts', EEG) - self.assertIn('etc', EEG) - - # Check HP structure (should be same as EEG when no highpass) - self.assertIn('data', HP) - self.assertIn('srate', HP) - self.assertIn('nbchan', HP) - self.assertIn('pnts', HP) - - # Check BUR structure (should be same as EEG when no burst cleaning) - self.assertIn('data', BUR) - self.assertIn('srate', BUR) - self.assertIn('nbchan', BUR) - self.assertIn('pnts', BUR) - - # Check removed_channels array - self.assertEqual(len(removed_channels), self.test_eeg['nbchan']) - self.assertTrue(np.issubdtype(removed_channels.dtype, np.bool_)) - - except Exception as e: - self.skipTest(f"clean_artifacts return values not available: {e}") + self.assertTrue(called['nolocs']) + finally: + ca_mod.clean_channels = original_cc + ca_mod.clean_channels_nolocs = original_nolocs if __name__ == '__main__': diff --git a/tests/test_clean_asr.py b/tests/test_clean_asr.py index 4170437e..829bc390 100644 --- a/tests/test_clean_asr.py +++ b/tests/test_clean_asr.py @@ -100,6 +100,10 @@ def setUp(self): 'nbchan': self.n_channels, 'pnts': self.n_samples, 'trials': 1, + 'xmin': 0.0, + 'xmax': (self.n_samples - 1) / self.srate, + 'times': np.arange(self.n_samples) / self.srate, + 'event': [], } def test_clean_asr_parameter_acceptance(self): @@ -159,6 +163,10 @@ def setUp(self): 'nbchan': self.n_channels, 'pnts': self.n_samples, 'trials': 1, + 'xmin': 0.0, + 'xmax': (self.n_samples - 1) / self.srate, + 'times': np.arange(self.n_samples) / self.srate, + 'event': [], } def test_clean_asr_ref_maxbadchannels_off(self): @@ -288,6 +296,66 @@ def test_clean_asr_automatic_calibration_fallback(self): else: self.skipTest(f"clean_asr automatic calibration test not available: {e}") + def test_clean_asr_unexpected_clean_windows_error_propagates(self): + """An unexpected (non-ValueError) failure in clean_windows must propagate, + not be swallowed into a silent 'use all data for calibration' fallback. + """ + import eegprep.plugins.clean_rawdata.clean_asr as casr_mod + + original = casr_mod.clean_windows + + def boom(*args, **kwargs): + raise RuntimeError("simulated clean_windows bug") + + casr_mod.clean_windows = boom + try: + with self.assertRaises(RuntimeError) as cm: + clean_asr( + self.test_eeg, + ref_maxbadchannels=0.1, + ref_tolerances=(-3.0, 5.0), + ref_wndlen=1.0, + cutoff=20.0, + ) + self.assertIn('simulated clean_windows bug', str(cm.exception)) + finally: + casr_mod.clean_windows = original + + def test_clean_asr_clean_windows_value_error_falls_back(self): + """A ValueError from clean_windows (expected calibration-data problem) still + triggers the documented all-data fallback rather than crashing. + """ + import eegprep.plugins.clean_rawdata.clean_asr as casr_mod + + original = casr_mod.clean_windows + + def insufficient(*args, **kwargs): + raise ValueError('Not enough data for even a single window.') + + # Enough samples that all-data calibration succeeds once the fallback kicks in. + eeg = { + 'data': np.random.randn(self.n_channels, 5000) * 0.5, + 'srate': self.srate, + 'nbchan': self.n_channels, + 'pnts': 5000, + 'trials': 1, + } + + casr_mod.clean_windows = insufficient + try: + with self.assertLogs('eegprep.plugins.clean_rawdata.clean_asr', level='WARNING') as log: + result = clean_asr( + eeg, + ref_maxbadchannels=0.1, + ref_tolerances=(-3.0, 5.0), + ref_wndlen=1.0, + cutoff=20.0, + ) + self.assertTrue(any('Falling back to using the entire data' in msg for msg in log.output)) + self.assertIsInstance(result, dict) + finally: + casr_mod.clean_windows = original + class TestCleanASRSignalExtrapolation(unittest.TestCase): """Test clean_asr signal extrapolation logic.""" @@ -305,6 +373,10 @@ def setUp(self): 'nbchan': self.n_channels, 'pnts': self.n_samples, 'trials': 1, + 'xmin': 0.0, + 'xmax': (self.n_samples - 1) / self.srate, + 'times': np.arange(self.n_samples) / self.srate, + 'event': [], } def test_clean_asr_with_different_window_lengths(self): @@ -359,6 +431,10 @@ def setUp(self): 'nbchan': self.n_channels, 'pnts': self.n_samples, 'trials': 1, + 'xmin': 0.0, + 'xmax': (self.n_samples - 1) / self.srate, + 'times': np.arange(self.n_samples) / self.srate, + 'event': [], } def test_clean_asr_single_channel_data(self): @@ -369,6 +445,10 @@ def test_clean_asr_single_channel_data(self): 'nbchan': 1, 'pnts': self.n_samples, 'trials': 1, + 'xmin': 0.0, + 'xmax': (self.n_samples - 1) / self.srate, + 'times': np.arange(self.n_samples) / self.srate, + 'event': [], } try: diff --git a/tests/test_clean_drifts.py b/tests/test_clean_drifts.py index f6f90991..5f00607e 100644 --- a/tests/test_clean_drifts.py +++ b/tests/test_clean_drifts.py @@ -14,56 +14,12 @@ from eegprep.plugins.clean_rawdata.clean_drifts import clean_drifts from eegprep.utils.testing import DebuggableTestCase +from tests.fixtures import create_test_eeg as _create_test_eeg + def create_test_eeg(): - """Create a complete test EEG structure with all required fields. - - Note: clean_drifts expects continuous (2D) data, not epoched (3D) data. - """ - n_pnts = 10000 # 20 seconds at 500 Hz - return { - 'data': np.random.randn(32, n_pnts), # 2D continuous data - 'srate': 500.0, - 'nbchan': 32, - 'pnts': n_pnts, - 'trials': 1, - 'xmin': 0.0, - 'xmax': n_pnts / 500.0, - 'times': np.linspace(0, n_pnts / 500.0, n_pnts), - 'icaact': [], - 'icawinv': [], - 'icasphere': [], - 'icaweights': [], - 'icachansind': [], - 'chanlocs': [ - { - 'labels': f'EEG{i:03d}', - 'type': 'EEG', - 'theta': np.random.uniform(-90, 90), - 'radius': np.random.uniform(0, 1), - 'X': np.random.uniform(-1, 1), - 'Y': np.random.uniform(-1, 1), - 'Z': np.random.uniform(-1, 1), - 'sph_theta': np.random.uniform(-180, 180), - 'sph_phi': np.random.uniform(-90, 90), - 'sph_radius': np.random.uniform(0, 1), - 'urchan': i + 1, - 'ref': '', - } - for i in range(32) - ], - 'urchanlocs': [], - 'chaninfo': {'removedchans': []}, - 'ref': 'common', - 'history': '', - 'saved': 'yes', - 'etc': {}, - 'event': [], - 'epoch': [], - 'setname': 'test_dataset', - 'filename': 'test.set', - 'filepath': '/tmp', - } + """Continuous (2D) EEG fixture sized for clean_drifts (20 s at 500 Hz).""" + return _create_test_eeg(n_channels=32, n_samples=10000, srate=500.0, n_trials=1) class TestCleanDriftsBasic(DebuggableTestCase): @@ -75,90 +31,66 @@ def setUp(self): def test_clean_drifts_basic_functionality(self): """Test basic clean_drifts functionality with default parameters.""" - try: - result = clean_drifts(self.test_eeg.copy()) - - # Check that EEG structure is preserved - self.assertIn('data', result) - self.assertIn('srate', result) - self.assertIn('nbchan', result) - self.assertIn('pnts', result) - self.assertIn('etc', result) + result = clean_drifts(self.test_eeg.copy()) - # Check that data dimensions are preserved - self.assertEqual(result['srate'], self.test_eeg['srate']) - self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(result['pnts'], self.test_eeg['pnts']) - self.assertEqual(result['trials'], self.test_eeg['trials']) + # Check that EEG structure is preserved + self.assertIn('data', result) + self.assertIn('srate', result) + self.assertIn('nbchan', result) + self.assertIn('pnts', result) + self.assertIn('etc', result) - # Check that data type is float64 - self.assertEqual(result['data'].dtype, np.float64) + # Check that data dimensions are preserved + self.assertEqual(result['srate'], self.test_eeg['srate']) + self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) + self.assertEqual(result['pnts'], self.test_eeg['pnts']) + self.assertEqual(result['trials'], self.test_eeg['trials']) - # Check that filter kernel is stored - self.assertIn('clean_drifts_kernel', result['etc']) + # Check that data type is float64 + self.assertEqual(result['data'].dtype, np.float64) - except Exception as e: - self.skipTest(f"clean_drifts basic functionality not available: {e}") + # Check that filter kernel is stored + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_default_parameters(self): """Test clean_drifts with default parameters.""" - try: - result = clean_drifts(self.test_eeg.copy()) + result = clean_drifts(self.test_eeg.copy()) - # Should work with default parameters - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - except Exception as e: - self.skipTest(f"clean_drifts default parameters not available: {e}") + # Should work with default parameters + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_custom_transition(self): """Test clean_drifts with custom transition band.""" - try: - result = clean_drifts(self.test_eeg.copy(), transition=(1.0, 2.0)) + result = clean_drifts(self.test_eeg.copy(), transition=(1.0, 2.0)) - # Should work with custom transition band - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - except Exception as e: - self.skipTest(f"clean_drifts custom transition not available: {e}") + # Should work with custom transition band + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_custom_attenuation(self): """Test clean_drifts with custom attenuation.""" - try: - result = clean_drifts(self.test_eeg.copy(), attenuation=60.0) + result = clean_drifts(self.test_eeg.copy(), attenuation=60.0) - # Should work with custom attenuation - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - except Exception as e: - self.skipTest(f"clean_drifts custom attenuation not available: {e}") + # Should work with custom attenuation + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_fir_method(self): """Test clean_drifts with FIR method.""" - try: - result = clean_drifts(self.test_eeg.copy(), method='fir') - - # Should work with FIR method - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) + result = clean_drifts(self.test_eeg.copy(), method='fir') - except Exception as e: - self.skipTest(f"clean_drifts FIR method not available: {e}") + # Should work with FIR method + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_fft_method(self): """Test clean_drifts with FFT method.""" - try: - result = clean_drifts(self.test_eeg.copy(), method='fft') + result = clean_drifts(self.test_eeg.copy(), method='fft') - # Should work with FFT method - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - except Exception as e: - self.skipTest(f"clean_drifts FFT method not available: {e}") + # Should work with FFT method + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) class TestCleanDriftsEdgeCases(DebuggableTestCase): @@ -170,87 +102,67 @@ def setUp(self): def test_clean_drifts_single_channel(self): """Test clean_drifts with single channel data.""" - try: - # Create single channel data (2D continuous) - single_channel_eeg = self.test_eeg.copy() - single_channel_eeg['data'] = np.random.randn(1, 10000) - single_channel_eeg['nbchan'] = 1 - single_channel_eeg['chanlocs'] = [single_channel_eeg['chanlocs'][0]] + # Create single channel data (2D continuous) + single_channel_eeg = self.test_eeg.copy() + single_channel_eeg['data'] = np.random.randn(1, 10000) + single_channel_eeg['nbchan'] = 1 + single_channel_eeg['chanlocs'] = [single_channel_eeg['chanlocs'][0]] - result = clean_drifts(single_channel_eeg) - - # Should work with single channel - self.assertEqual(result['nbchan'], 1) - self.assertIn('clean_drifts_kernel', result['etc']) + result = clean_drifts(single_channel_eeg) - except Exception as e: - self.skipTest(f"clean_drifts single channel not available: {e}") + # Should work with single channel + self.assertEqual(result['nbchan'], 1) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_single_trial(self): """Test clean_drifts with continuous (single trial) data.""" - try: - # Create continuous data (2D - single trial is the normal case) - single_trial_eeg = self.test_eeg.copy() - single_trial_eeg['data'] = np.random.randn(32, 10000) - single_trial_eeg['trials'] = 1 + # Create continuous data (2D - single trial is the normal case) + single_trial_eeg = self.test_eeg.copy() + single_trial_eeg['data'] = np.random.randn(32, 10000) + single_trial_eeg['trials'] = 1 - result = clean_drifts(single_trial_eeg) + result = clean_drifts(single_trial_eeg) - # Should work with single trial - self.assertEqual(result['trials'], 1) - self.assertIn('clean_drifts_kernel', result['etc']) - - except Exception as e: - self.skipTest(f"clean_drifts single trial not available: {e}") + # Should work with single trial + self.assertEqual(result['trials'], 1) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_continuous_data(self): """Test clean_drifts with continuous (2D) data.""" - try: - # Create continuous data (2D) - continuous_eeg = self.test_eeg.copy() - continuous_eeg['data'] = np.random.randn(32, 1000) - continuous_eeg['trials'] = 1 - - result = clean_drifts(continuous_eeg) + # Create continuous data (2D) + continuous_eeg = self.test_eeg.copy() + continuous_eeg['data'] = np.random.randn(32, 1000) + continuous_eeg['trials'] = 1 - # Should work with continuous data - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) + result = clean_drifts(continuous_eeg) - except Exception as e: - self.skipTest(f"clean_drifts continuous data not available: {e}") + # Should work with continuous data + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_float32_data(self): """Test clean_drifts with float32 data.""" - try: - # Create float32 data - float32_eeg = self.test_eeg.copy() - float32_eeg['data'] = np.random.randn(32, 10000).astype(np.float32) + # Create float32 data + float32_eeg = self.test_eeg.copy() + float32_eeg['data'] = np.random.randn(32, 10000).astype(np.float32) - result = clean_drifts(float32_eeg) + result = clean_drifts(float32_eeg) - # Should convert to float64 - self.assertEqual(result['data'].dtype, np.float64) - self.assertIn('clean_drifts_kernel', result['etc']) - - except Exception as e: - self.skipTest(f"clean_drifts float32 data not available: {e}") + # Should convert to float64 + self.assertEqual(result['data'].dtype, np.float64) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_float64_data(self): """Test clean_drifts with float64 data.""" - try: - # Create float64 data - float64_eeg = self.test_eeg.copy() - float64_eeg['data'] = np.random.randn(32, 10000).astype(np.float64) - - result = clean_drifts(float64_eeg) + # Create float64 data + float64_eeg = self.test_eeg.copy() + float64_eeg['data'] = np.random.randn(32, 10000).astype(np.float64) - # Should remain float64 - self.assertEqual(result['data'].dtype, np.float64) - self.assertIn('clean_drifts_kernel', result['etc']) + result = clean_drifts(float64_eeg) - except Exception as e: - self.skipTest(f"clean_drifts float64 data not available: {e}") + # Should remain float64 + self.assertEqual(result['data'].dtype, np.float64) + self.assertIn('clean_drifts_kernel', result['etc']) class TestCleanDriftsParameters(DebuggableTestCase): @@ -262,45 +174,33 @@ def setUp(self): def test_clean_drifts_different_transition_bands(self): """Test clean_drifts with different transition bands.""" - try: - # Test different transition bands - transitions = [(0.1, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 5.0)] + # Test different transition bands + transitions = [(0.1, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 5.0)] - for transition in transitions: - result = clean_drifts(self.test_eeg.copy(), transition=transition) - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) - - except Exception as e: - self.skipTest(f"clean_drifts different transition bands not available: {e}") + for transition in transitions: + result = clean_drifts(self.test_eeg.copy(), transition=transition) + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_different_attenuations(self): """Test clean_drifts with different attenuation values.""" - try: - # Test different attenuation values - attenuations = [40.0, 60.0, 80.0, 100.0] - - for attenuation in attenuations: - result = clean_drifts(self.test_eeg.copy(), attenuation=attenuation) - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) + # Test different attenuation values + attenuations = [40.0, 60.0, 80.0, 100.0] - except Exception as e: - self.skipTest(f"clean_drifts different attenuations not available: {e}") + for attenuation in attenuations: + result = clean_drifts(self.test_eeg.copy(), attenuation=attenuation) + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) def test_clean_drifts_both_methods(self): """Test clean_drifts with both FIR and FFT methods.""" - try: - # Test both methods - methods = ['fir', 'fft'] - - for method in methods: - result = clean_drifts(self.test_eeg.copy(), method=method) - self.assertIn('data', result) - self.assertIn('clean_drifts_kernel', result['etc']) + # Test both methods + methods = ['fir', 'fft'] - except Exception as e: - self.skipTest(f"clean_drifts both methods not available: {e}") + for method in methods: + result = clean_drifts(self.test_eeg.copy(), method=method) + self.assertIn('data', result) + self.assertIn('clean_drifts_kernel', result['etc']) class TestCleanDriftsIntegration(DebuggableTestCase): @@ -312,56 +212,44 @@ def setUp(self): def test_clean_drifts_preserves_structure(self): """Test that clean_drifts preserves EEG structure.""" - try: - result = clean_drifts(self.test_eeg.copy()) - - # Check that all essential fields are preserved - essential_fields = ['data', 'srate', 'nbchan', 'pnts', 'trials', 'xmin', 'xmax', 'times', 'chanlocs'] - for field in essential_fields: - self.assertIn(field, result) + result = clean_drifts(self.test_eeg.copy()) - # Check that data integrity is maintained - self.assertEqual(result['srate'], self.test_eeg['srate']) - self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(result['pnts'], self.test_eeg['pnts']) - self.assertEqual(result['trials'], self.test_eeg['trials']) + # Check that all essential fields are preserved + essential_fields = ['data', 'srate', 'nbchan', 'pnts', 'trials', 'xmin', 'xmax', 'times', 'chanlocs'] + for field in essential_fields: + self.assertIn(field, result) - except Exception as e: - self.skipTest(f"clean_drifts preserves structure not available: {e}") + # Check that data integrity is maintained + self.assertEqual(result['srate'], self.test_eeg['srate']) + self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) + self.assertEqual(result['pnts'], self.test_eeg['pnts']) + self.assertEqual(result['trials'], self.test_eeg['trials']) def test_clean_drifts_data_modification(self): """Test that clean_drifts actually modifies the data.""" - try: - original_data = self.test_eeg['data'].copy() - result = clean_drifts(self.test_eeg.copy()) + original_data = self.test_eeg['data'].copy() + result = clean_drifts(self.test_eeg.copy()) - # Data should be modified (filtered) - self.assertFalse(np.array_equal(original_data, result['data'])) + # Data should be modified (filtered) + self.assertFalse(np.array_equal(original_data, result['data'])) - # But shape should be preserved - self.assertEqual(original_data.shape, result['data'].shape) - - except Exception as e: - self.skipTest(f"clean_drifts data modification not available: {e}") + # But shape should be preserved + self.assertEqual(original_data.shape, result['data'].shape) def test_clean_drifts_kernel_properties(self): """Test properties of the filter kernel.""" - try: - result = clean_drifts(self.test_eeg.copy()) - - kernel = result['etc']['clean_drifts_kernel'] + result = clean_drifts(self.test_eeg.copy()) - # Kernel should be a numpy array - self.assertIsInstance(kernel, np.ndarray) + kernel = result['etc']['clean_drifts_kernel'] - # Kernel should not be empty - self.assertGreater(len(kernel), 0) + # Kernel should be a numpy array + self.assertIsInstance(kernel, np.ndarray) - # Kernel should be 1D - self.assertEqual(kernel.ndim, 1) + # Kernel should not be empty + self.assertGreater(len(kernel), 0) - except Exception as e: - self.skipTest(f"clean_drifts kernel properties not available: {e}") + # Kernel should be 1D + self.assertEqual(kernel.ndim, 1) if __name__ == '__main__': diff --git a/tests/test_clean_flatlines.py b/tests/test_clean_flatlines.py index de96948b..40723893 100644 --- a/tests/test_clean_flatlines.py +++ b/tests/test_clean_flatlines.py @@ -14,52 +14,12 @@ from eegprep.plugins.clean_rawdata.clean_flatlines import clean_flatlines from eegprep.utils.testing import DebuggableTestCase +from tests.fixtures import create_test_eeg as _create_test_eeg + def create_test_eeg(): - """Create a complete test EEG structure with all required fields.""" - return { - 'data': np.random.randn(32, 1000, 10), - 'srate': 500.0, - 'nbchan': 32, - 'pnts': 1000, - 'trials': 10, - 'xmin': -1.0, - 'xmax': 1.0, - 'times': np.linspace(-1.0, 1.0, 1000), - 'icaact': [], - 'icawinv': [], - 'icasphere': [], - 'icaweights': [], - 'icachansind': [], - 'chanlocs': [ - { - 'labels': f'EEG{i:03d}', - 'type': 'EEG', - 'theta': np.random.uniform(-90, 90), - 'radius': np.random.uniform(0, 1), - 'X': np.random.uniform(-1, 1), - 'Y': np.random.uniform(-1, 1), - 'Z': np.random.uniform(-1, 1), - 'sph_theta': np.random.uniform(-180, 180), - 'sph_phi': np.random.uniform(-90, 90), - 'sph_radius': np.random.uniform(0, 1), - 'urchan': i + 1, - 'ref': '', - } - for i in range(32) - ], - 'urchanlocs': [], - 'chaninfo': [], - 'ref': 'common', - 'history': '', - 'saved': 'yes', - 'etc': {}, - 'event': [], - 'epoch': [], - 'setname': 'test_dataset', - 'filename': 'test.set', - 'filepath': '/tmp', - } + """Epoched EEG fixture sized for clean_flatlines (32 ch, 1000 pnts, 10 trials).""" + return _create_test_eeg(n_channels=32, n_samples=1000, srate=500.0, n_trials=10) class TestCleanFlatlinesBasic(DebuggableTestCase): @@ -392,6 +352,45 @@ def test_clean_flatlines_walrus_operator_branch(self): if result['nbchan'] < eeg_walrus['nbchan']: self.assertFalse(result['etc']['clean_channel_mask'][5]) + def test_clean_flatlines_fallback_composites_existing_mask(self): + """Fallback path with a prior clean_channel_mask must composite, not crash. + + Reproduces the walrus-precedence bug: when pop_select fails and a prior + clean_channel_mask exists, the mask update must run ``mask[mask] = ~removed`` + rather than treating the mask as a bool. Uses continuous (2D) data so the + composite indexing exercises the real fallback branch. + """ + eeg = self.test_eeg.copy() + eeg['data'] = np.random.randn(32, 1000) + eeg['trials'] = 1 + eeg['data'][5, :] = 1.0 # flatline channel 5 + eeg['etc'] = {'clean_channel_mask': np.ones(32, dtype=bool)} + # Empty chanlocs so the unrelated chanlocs-trim branch is skipped and the + # test isolates the clean_channel_mask compositing branch. + eeg['chanlocs'] = [] + + # Force the no-pop_select fallback with a non-ImportError so the + # mask-compositing branch runs (this is where the bug lived). + import eegprep + + original = eegprep.pop_select + + def failing_pop_select(*args, **kwargs): + raise RuntimeError("simulated pop_select failure") + + eegprep.pop_select = failing_pop_select + try: + result = clean_flatlines(eeg, max_flatline_duration=1.0) + finally: + eegprep.pop_select = original + + mask = result['etc']['clean_channel_mask'] + # Original mask had 32 True entries; after compositing exactly channel 5 + # (the flatline) must be False and the rest True. + self.assertEqual(mask.shape[0], 32) + self.assertFalse(mask[5]) + self.assertEqual(int(np.sum(~mask)), 1) + class TestCleanFlatlinesNoOpPath(DebuggableTestCase): """No-operation path test cases for clean_flatlines function.""" diff --git a/tests/test_clean_windows.py b/tests/test_clean_windows.py index 09c5980d..096ff313 100644 --- a/tests/test_clean_windows.py +++ b/tests/test_clean_windows.py @@ -488,6 +488,25 @@ def test_sample_mask_consistency(self): # Check that sample_mask is boolean self.assertEqual(sample_mask.dtype, bool) + def test_does_not_mutate_input(self): + """clean_windows must not mutate the caller's EEG dict or data array.""" + EEG_in = self.EEG_artifacts + original_data = EEG_in['data'].copy() + original_dtype = EEG_in['data'].dtype + original_keys = set(EEG_in.keys()) + + EEG_out, _ = clean_windows(EEG_in) + + # Caller's data array is unchanged in value, dtype, and shape. + self.assertTrue(np.array_equal(original_data, EEG_in['data'])) + self.assertEqual(EEG_in['data'].dtype, original_dtype) + # No new keys (e.g. 'etc') were injected into the caller's dict. + self.assertEqual(set(EEG_in.keys()), original_keys) + self.assertNotIn('etc', EEG_in) + # Output is a distinct object from the input. + self.assertIsNot(EEG_out, EEG_in) + self.assertIsNot(EEG_out['data'], EEG_in['data']) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_cli_bids_eeglab_commands.py b/tests/test_cli_bids_migrate_commands.py similarity index 76% rename from tests/test_cli_bids_eeglab_commands.py rename to tests/test_cli_bids_migrate_commands.py index 86a9cec0..f8778c68 100644 --- a/tests/test_cli_bids_eeglab_commands.py +++ b/tests/test_cli_bids_migrate_commands.py @@ -52,6 +52,20 @@ def test_bids_export_validate_import_roundtrip(tmp_path, capsys): assert captured.err == "" +def test_bids_validate_reports_error_when_no_eeg_files(tmp_path): + from eegprep.cli.commands import bids as bids_cli + + empty_root = tmp_path / "empty_bids" + empty_root.mkdir() + (empty_root / "dataset_description.json").write_text("{}", encoding="utf-8") + + payload = bids_cli.validate_dataset(empty_root) + + assert payload["status"] == "error" + assert payload["can_continue"] is False + assert [issue["code"] for issue in payload["errors"]] == ["BIDS_EEG_FILES_MISSING"] + + def test_bids_validate_missing_path_returns_structured_error(tmp_path, capsys): from eegprep.cli.commands import bids as bids_cli @@ -66,6 +80,29 @@ def test_bids_validate_missing_path_returns_structured_error(tmp_path, capsys): assert captured.err == "" +def test_bids_import_set_file_uses_eeglab_loader_without_error_sniffing(tmp_path, monkeypatch): + from eegprep.cli.commands import bids as bids_cli + + input_set = tmp_path / "input.set" + imported_set = tmp_path / "imported.set" + pop_saveset(_eeg(), input_set) + + # A .set file must dispatch to the EEGLAB loader without ever invoking the BIDS sidecar + # importer; the previous fallback only recovered when the IndexError message matched a + # specific string, so any other wording silently re-raised as an opaque crash. + def _fail(*_args, **_kwargs): + raise IndexError("array index out of range") + + monkeypatch.setattr(bids_cli, "pop_importbids", _fail) + + payload = bids_cli.import_dataset(input_set, output=imported_set) + + assert payload["status"] == "ok" + assert payload["dataset"]["nbchan"] == 2 + assert imported_set.exists() + assert [warning["code"] for warning in payload["warnings"]] == ["BIDS_SIDECARS_SKIPPED"] + + def test_bids_import_refuses_existing_manifest_without_overwrite(tmp_path): from eegprep.cli.commands import bids as bids_cli @@ -103,8 +140,8 @@ def test_bids_export_refuses_non_empty_root_without_overwrite(tmp_path): assert existing.read_text(encoding="utf-8") == "existing" -def test_eeglab_history_maps_supported_and_unsupported_commands(tmp_path): - from eegprep.cli.commands import eeglab as eeglab_cli +def test_migrate_history_maps_supported_and_unsupported_commands(tmp_path): + from eegprep.cli.commands import migrate as migrate_cli set_file = tmp_path / "history.set" eeg = _eeg() @@ -117,9 +154,10 @@ def test_eeglab_history_maps_supported_and_unsupported_commands(tmp_path): ) pop_saveset(eeg, set_file) - payload = eeglab_cli.history(set_file) + payload = migrate_cli.history(set_file) assert payload["status"] == "ok" + assert payload["schema_version"] == "eegprep.migrate.history.v1" assert [operation["eeglab_command"] for operation in payload["operations"]] == [ "pop_loadset", "pop_resample", @@ -131,8 +169,8 @@ def test_eeglab_history_maps_supported_and_unsupported_commands(tmp_path): assert payload["operations"][2]["unsupported"]["code"] == "COMMAND_NOT_IMPLEMENTED" -def test_eeglab_compare_reports_structured_differences(tmp_path): - from eegprep.cli.commands import eeglab as eeglab_cli +def test_migrate_compare_reports_structured_differences(tmp_path): + from eegprep.cli.commands import migrate as migrate_cli left = tmp_path / "left.set" right = tmp_path / "right.set" @@ -144,9 +182,10 @@ def test_eeglab_compare_reports_structured_differences(tmp_path): pop_saveset(eeg_left, left) pop_saveset(eeg_right, right) - payload = eeglab_cli.compare(left, right) + payload = migrate_cli.compare(left, right) assert payload["status"] == "ok" + assert payload["schema_version"] == "eegprep.migrate.compare.v1" assert payload["equivalent"] is False differences_by_path = {difference["path"]: difference for difference in payload["differences"]} assert differences_by_path["srate"]["code"] == "VALUE_MISMATCH" @@ -154,8 +193,8 @@ def test_eeglab_compare_reports_structured_differences(tmp_path): assert payload["data"]["max_abs_diff"] == 1.25 -def test_eeglab_compare_reports_nan_placement_differences(tmp_path): - from eegprep.cli.commands import eeglab as eeglab_cli +def test_migrate_compare_reports_nan_placement_differences(tmp_path): + from eegprep.cli.commands import migrate as migrate_cli left = tmp_path / "left_nan.set" right = tmp_path / "right_nan.set" @@ -168,14 +207,14 @@ def test_eeglab_compare_reports_nan_placement_differences(tmp_path): pop_saveset(eeg_left, left) pop_saveset(eeg_right, right) - payload = eeglab_cli.compare(left, right) + payload = migrate_cli.compare(left, right) assert payload["equivalent"] is False assert any(difference["code"] == "DATA_FINITE_MASK_MISMATCH" for difference in payload["differences"]) -def test_eeglab_convert_script_reports_best_effort_conversion(tmp_path): - from eegprep.cli.commands import eeglab as eeglab_cli +def test_migrate_convert_script_reports_best_effort_conversion(tmp_path): + from eegprep.cli.commands import migrate as migrate_cli script = tmp_path / "pipeline.m" output = tmp_path / "pipeline.yaml" @@ -190,9 +229,10 @@ def test_eeglab_convert_script_reports_best_effort_conversion(tmp_path): encoding="utf-8", ) - payload = eeglab_cli.convert_script(script, output=output) + payload = migrate_cli.convert_script(script, output=output) assert payload["status"] == "ok" + assert payload["schema_version"] == "eegprep.migrate.convert_script.v1" assert payload["target"] == "eegprep-yaml" assert payload["converted_steps"][1]["name"] == "resample" assert payload["unsupported_commands"][0]["command"] == "topoplot" diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py index 32dd4cb8..d772f9d4 100644 --- a/tests/test_cli_main.py +++ b/tests/test_cli_main.py @@ -8,6 +8,7 @@ import yaml +from eegprep.cli.core import EEGPrepCLIError, command_error, emit_command_result from tests.fixtures import SAMPLE_DATASET_PATH @@ -37,6 +38,8 @@ def test_help_has_agent_start_section(): assert result.returncode == 0 assert "Start here (for AI agents):" in result.stdout assert "eegprep skills get eegprep-cli" in result.stdout + assert "migrate" in result.stdout + assert "eeglab Inspect EEGLAB history" not in result.stdout def test_capabilities_schema_examples_and_skill_are_json_readable(): @@ -49,6 +52,8 @@ def test_capabilities_schema_examples_and_skill_are_json_readable(): commands = _json_stdout(capabilities)["commands"] assert "filter" in commands assert "batch" in commands + assert "migrate" in commands + assert "eeglab" not in commands assert schema.returncode == 0 assert _json_stdout(schema)["schema"]["schema_version"] == "eegprep.schema.command.filter.v1" assert examples.returncode == 0 @@ -57,6 +62,18 @@ def test_capabilities_schema_examples_and_skill_are_json_readable(): assert "Agent Rules" in _json_stdout(skill)["content"] +def test_transform_command_schemas_define_every_required_property(): + for command in ("resample", "rereference", "filter", "clean", "epoch", "ica"): + schema = _json_stdout(_run_cli("schema", "command", command, "--json"))["schema"] + properties = set(schema["properties"]) + missing = [name for name in schema["required"] if name not in properties] + assert not missing, f"{command} required params absent from properties: {missing}" + # --output is conditionally required (allowed to be omitted with --overwrite), so it must + # not appear in the unconditional required list and the alternative must be advertised. + assert "output" not in schema["required"], command + assert schema["anyOf"] == [{"required": ["output"]}, {"required": ["overwrite"]}], command + + def test_every_advertised_capability_has_schema_and_examples(): capabilities = _json_stdout(_run_cli("capabilities", "--json"))["commands"] @@ -186,3 +203,68 @@ def test_missing_input_returns_stable_error_code(): payload = _json_stdout(result) assert payload["status"] == "error" assert payload["code"] == "INPUT_FILE_NOT_FOUND" + + +def test_qc_remainder_command_still_honors_json_flag(): + # ``qc`` consumes its arguments via argparse.REMAINDER, so it never binds a top-level + # ``args.json``. The root dispatcher must still emit clean JSON for it via the + # command-agnostic --json detection rather than introspecting one subcommand's attribute. + json_result = _run_cli("qc", str(ROOT / "does-not-exist.set"), "--json") + human_result = _run_cli("qc", str(ROOT / "does-not-exist.set")) + + assert json_result.returncode == 1 + payload = _json_stdout(json_result) + assert payload["status"] == "error" + assert payload["code"] == "INPUT_FILE_NOT_FOUND" + # The human path must not emit the JSON envelope, proving --json actually toggled output. + assert not human_result.stdout.strip().startswith("{") + assert "INPUT_FILE_NOT_FOUND" in human_result.stdout + + +def test_structured_command_error_preserves_non_default_exit_code(capsys): + error = EEGPrepCLIError("CONFIG_SCHEMA_ERROR", "bad config", exit_code=2) + + result = command_error("pipeline run", error) + exit_code = emit_command_result(result, json_output=True) + + # A structured-result usage error (exit_code=2) must not be silently downgraded to 1, so the + # structured path matches the exception path for the same error class. + assert result["exit_code"] == 2 + assert exit_code == 2 + payload = json.loads(capsys.readouterr().out) + assert payload["exit_code"] == 2 + assert payload["error"]["code"] == "CONFIG_SCHEMA_ERROR" + + +def _write_run_config(tmp_path, name): + config = tmp_path / name + config.write_text( + yaml.safe_dump( + { + "schema_version": "eegprep.pipeline.v1", + "input": {"path": str(SAMPLE_DATASET_PATH), "format": "eeglab"}, + "output": {"directory": str(tmp_path / f"out_{name}")}, + "steps": [{"name": "resample", "freq": 128}, {"name": "qc"}], + } + ), + encoding="utf-8", + ) + return config + + +def test_pipeline_run_verbose_and_quiet_flags_take_effect(tmp_path): + verbose = _run_cli("pipeline", "run", str(_write_run_config(tmp_path, "verbose.yaml")), "--json", "--verbose") + quiet = _run_cli("pipeline", "run", str(_write_run_config(tmp_path, "quiet.yaml")), "--json", "--quiet") + + assert verbose.returncode == 0, verbose.stderr + assert quiet.returncode == 0, quiet.stderr + # The flags must actually drive logging: --verbose emits progress chatter, --quiet suppresses it. + verbose_stderr = [line for line in verbose.stderr.splitlines() if line.strip()] + quiet_stderr = [line for line in quiet.stderr.splitlines() if line.strip()] + assert verbose_stderr, "expected --verbose to emit progress logging on stderr" + assert len(quiet_stderr) < len(verbose_stderr) + # In both modes the single JSON line agents parse must stay clean on stdout. + assert len([line for line in verbose.stdout.splitlines() if line.strip()]) == 1 + assert len([line for line in quiet.stdout.splitlines() if line.strip()]) == 1 + assert _json_stdout(verbose)["status"] == "ok" + assert _json_stdout(quiet)["status"] == "ok" diff --git a/tests/test_cli_pipeline_qc_report.py b/tests/test_cli_pipeline_qc_report.py index ad3789a2..1f2d18a6 100644 --- a/tests/test_cli_pipeline_qc_report.py +++ b/tests/test_cli_pipeline_qc_report.py @@ -138,6 +138,20 @@ def fake_pop_eegfiltnew(eeg, **kwargs): assert "pop_eegfiltnew" in result["history"][0] +def test_pipeline_filter_rejects_negative_notch_lower_edge(tmp_path): + config_path = _write_pipeline_config( + tmp_path, + steps=[{"name": "filter", "notch": 1, "notch_width": 4}], + ) + + result = run_pipeline_config(config_path) + + assert result["status"] == "error" + assert result["code"] == "CONFIG_SCHEMA_ERROR" + assert "notch minus half notch_width must be positive" in result["message"] + assert not (tmp_path / "out").exists() + + def test_pipeline_invalid_config_returns_structured_error(tmp_path): config_path = _write_pipeline_config( tmp_path, diff --git a/tests/test_console_workspace.py b/tests/test_console_workspace.py index d996217a..cc85f054 100644 --- a/tests/test_console_workspace.py +++ b/tests/test_console_workspace.py @@ -219,6 +219,26 @@ def test_storedisk_session_retrieve_and_console_pop_call_stay_synchronized(tmp_p EEG_OPTIONS.update(old_options) +def test_console_currentset_reassignment_preserves_both_datasets(): + session = EEGPrepSession() + session.store_current(_demo_eeg("first"), new=True) + session.store_current(_demo_eeg("second"), new=True) + workspace = EEGPrepConsoleWorkspace(session, exports={}) + session.retrieve(1) + + assert workspace.namespace["EEG"]["setname"] == "first" + assert session.CURRENTSET == [1] + + workspace.namespace["CURRENTSET"] = 2 + workspace.after_execute("CURRENTSET = 2") + + assert session.CURRENTSET == [2] + assert session.EEG["setname"] == "second" + assert session.ALLEEG[0]["setname"] == "first" + assert session.ALLEEG[1]["setname"] == "second" + assert workspace.namespace["EEG"]["setname"] == "second" + + def test_console_pop_study_result_updates_shared_study_workspace(): session = EEGPrepSession() session.store_current(_demo_eeg(), new=True) @@ -367,6 +387,16 @@ def test_console_eegh_displays_and_finds_session_history(): workspace.close() +def test_default_console_eegh_uses_session_history_after_public_exports_bind(): + session = EEGPrepSession() + workspace = EEGPrepConsoleWorkspace(session) + session.add_history("EEG = pop_fileio('sample.set');") + + assert workspace.namespace["eegh"]() == "1. EEG = pop_fileio('sample.set');" + assert workspace.namespace["eegprep"].eegh() == "1. EEG = pop_fileio('sample.set');" + workspace.close() + + def test_console_eegh_positive_index_replays_command_through_workspace(): session = EEGPrepSession() session.store_current(_demo_eeg(), new=True) @@ -381,6 +411,37 @@ def test_console_eegh_positive_index_replays_command_through_workspace(): workspace.close() +def test_console_eegh_string_command_notifies_session_listeners(): + session = EEGPrepSession() + workspace = EEGPrepConsoleWorkspace(session, exports={}) + notified: list[int] = [] + session.add_change_listener(lambda _session: notified.append(len(session.ALLCOM))) + + result = workspace.namespace["eegh"]("EEG = pop_loadset('demo.set');") + + assert result == "EEG = pop_loadset('demo.set');" + assert session.ALLCOM == ["EEG = pop_loadset('demo.set');"] + assert session.LASTCOM == "EEG = pop_loadset('demo.set');" + # Routing through session.add_history fires the change listener. + assert notified == [1] + workspace.close() + + +def test_menu_actions_reuses_console_pop_result_decoders(): + # The GUI extension-result path delegates to the canonical console decoders + # instead of keeping its own copies. + from eegprep.functions.guifunc import menu_actions as menu_actions_module + + assert not hasattr(menu_actions_module, "_extension_dataset_state") + assert not hasattr(menu_actions_module, "_extension_eeg_and_command") + + eeg = _demo_eeg() + command = "EEG = pop_demo(EEG);" + result = ([eeg], eeg, 1, command) + assert console_module._extract_pop_dataset_state(result) == ([eeg], eeg, 1, command) + assert console_module._extract_pop_eeg_and_command((eeg, command)) == (eeg, command) + + def test_gui_action_buffers_output_until_command_echo(): session = EEGPrepSession() stream = io.StringIO() @@ -809,6 +870,26 @@ def test_bare_legacy_pop_averef_alias_updates_session_history(): workspace.close() +def test_console_dataset_state_result_updates_session_once_without_duplicate_history(): + session = EEGPrepSession() + session.store_current(_demo_eeg("one"), new=True) + session.store_current(_demo_eeg("two"), new=True) + refresh = mock.Mock() + workspace = EEGPrepConsoleWorkspace(session, refresh=refresh, exports={}) + first = dict(session.ALLEEG[0], setname="one edited") + second = dict(session.ALLEEG[1], setname="two edited") + command = "[ALLEEG EEG CURRENTSET] = pop_newset(ALLEEG, EEG, CURRENTSET, retrieve=[2, 1]);" + + result = workspace.accept_pop_result(([first, second], [second, first], [2, 1], command), (), {}) + + assert list(result)[3] == command + assert session.CURRENTSET == [2, 1] + assert [item["setname"] for item in session.EEG] == ["two edited", "one edited"] + assert [item["setname"] for item in session.ALLEEG] == ["one edited", "two edited"] + assert session.ALLCOM == [command] + refresh.assert_called_once() + + def test_pop_call_without_history_command_records_raw_console_source(): session = EEGPrepSession() session.store_current(_demo_eeg(), new=True) @@ -1504,7 +1585,7 @@ def test_console_python_command_converts_common_eeglab_history_syntax(): assert converted == [ "ALLEEG, EEG, CURRENTSET = pop_newset(ALLEEG, EEG, CURRENTSET, retrieve=1)", "CURRENTSTUDY = 0; ALLEEG, EEG, CURRENTSET = pop_newset(ALLEEG, EEG, CURRENTSET, retrieve=2)", - "EEG = pop_select(EEG, channel=[1, 2], chantype=['EEG', 'EOG'])", + "EEG = pop_select(EEG, channel=[0, 1], chantype=['EEG', 'EOG'])", 'LASTCOM = pop_export(EEG, filename="/tmp/demo\'s data.tsv")', "EEG = pop_resample(EEG, freq=64)", "EEG = pop_comments(EEG, plottitle='', newcomments='sample notes')", @@ -1519,6 +1600,20 @@ def test_console_python_command_converts_common_eeglab_history_syntax(): ast.parse(command) +def test_console_pop_select_numeric_channels_zero_based_on_replay(): + # GUI/history is 1-based (EEGLAB parity); the 0-based Python API requires the + # console to zero-base numeric channel selections so replayed history selects + # the same channels. Channel-by-name and chantype selections must pass through. + numeric = console_module._console_python_command("EEG = pop_select(EEG, 'channel', [1 3], 'rmchannel', [5]);") + assert numeric == "EEG = pop_select(EEG, channel=[0, 2], rmchannel=[4])" + by_name = console_module._console_python_command( + "EEG = pop_select(EEG, 'channel', {'Fz' 'Cz'}, 'chantype', {'EEG'});" + ) + assert by_name == "EEG = pop_select(EEG, channel=['Fz', 'Cz'], chantype=['EEG'])" + for command in (numeric, by_name): + ast.parse(command) + + def test_ipython_adapter_keeps_prompt_message_dynamic(): shell = _FakeShell() shell.prompts = _FakePrompts(shell) diff --git a/tests/test_eeg_amica.py b/tests/test_eeg_amica.py index b50df050..8256aa98 100644 --- a/tests/test_eeg_amica.py +++ b/tests/test_eeg_amica.py @@ -13,9 +13,12 @@ """ import unittest +from unittest import mock import numpy as np +from eegprep.functions.miscfunc.pinv import pinv +from eegprep.functions.popfunc import eeg_amica as eeg_amica_module from eegprep.functions.popfunc.eeg_amica import eeg_amica, load_amica_model from eegprep.functions.sigprocfunc.runamica import is_amica_available @@ -297,5 +300,133 @@ def test_load_amica_model_invalid(self): load_amica_model(result, mods, model_num=-1) +def _eeglab_flattened(data): + """Concatenate trials in EEGLAB column-major epoch order.""" + return np.concatenate([data[:, :, trial] for trial in range(data.shape[2])], axis=1) + + +def _epoched_eeg(): + data = np.zeros((2, 3, 2), dtype=np.float64) + data[0, :, 0] = [1.0, -2.0, 3.0] + data[0, :, 1] = [4.0, -5.0, 6.0] + data[1, :, 0] = [-7.0, 8.0, -9.0] + data[1, :, 1] = [10.0, -11.0, 12.0] + return { + 'data': data, + 'nbchan': 2, + 'pnts': 3, + 'trials': 2, + 'srate': 100, + 'chanlocs': [], + 'icaweights': np.eye(2), + 'icasphere': np.eye(2), + 'etc': {}, + } + + +def _continuous_eeg(): + data = np.array([[1.0, -2.0, 3.0, -4.0], [-5.0, 6.0, -7.0, 8.0]]) + return { + 'data': data, + 'nbchan': 2, + 'pnts': 4, + 'trials': 1, + 'srate': 100, + 'chanlocs': [], + 'icaweights': np.eye(2), + 'icasphere': np.eye(2), + 'etc': {}, + } + + +# A weights/sphere pair with a non-identity sphere so that the posact sign-flip +# step exercises the icaweights @ icasphere @ data == icaact invariant. +_AMICA_WEIGHTS = np.array([[2.0, 1.0], [-1.0, 3.0]]) +_AMICA_SPHERE = np.array([[1.5, 0.5], [0.0, 2.0]]) + + +def _fake_runamica_factory(captured): + winv = pinv(_AMICA_WEIGHTS @ _AMICA_SPHERE) + + def fake_runamica(data, **_kwargs): + captured['data'] = np.asarray(data).copy() + mods = { + 'num_pcs': 2, + 'num_models': 1, + 'W': _AMICA_WEIGHTS.copy()[:, :, None], + 'S': _AMICA_SPHERE.copy(), + 'A': winv.copy()[:, :, None], + } + return _AMICA_WEIGHTS.copy(), _AMICA_SPHERE.copy(), mods + + return fake_runamica + + +class TestEegAmicaRegressions(unittest.TestCase): + """Regression tests that do not require the AMICA binary (runamica mocked).""" + + def test_eeg_amica_flattens_epoched_data_like_eeglab(self): + eeg = _epoched_eeg() + captured = {} + with mock.patch.object(eeg_amica_module, 'runamica', _fake_runamica_factory(captured)): + out = eeg_amica(eeg) + + # AMICA must receive data flattened in EEGLAB column-major epoch order. + np.testing.assert_array_equal(captured['data'], _eeglab_flattened(eeg['data'])) + + # icaact must reshape back to channel x point x trial via the same order, + # so it equals icaweights @ icasphere @ data reshaped per-trial. + expected_2d = (out['icaweights'] @ out['icasphere']) @ _eeglab_flattened(eeg['data']) + expected_3d = expected_2d.reshape(expected_2d.shape[0], out['pnts'], out['trials'], order='F') + np.testing.assert_allclose(out['icaact'], expected_3d) + + def test_eeg_amica_does_not_mutate_caller(self): + eeg = _continuous_eeg() + data_before = eeg['data'].copy() + weights_before = eeg['icaweights'].copy() + captured = {} + with mock.patch.object(eeg_amica_module, 'runamica', _fake_runamica_factory(captured)): + eeg_amica(eeg, posact=True) + + self.assertTrue(np.array_equal(eeg['data'], data_before)) + self.assertTrue(np.array_equal(eeg['icaweights'], weights_before)) + self.assertEqual(eeg['etc'], {}) + + def test_eeg_amica_posact_preserves_ica_invariants(self): + eeg = _continuous_eeg() + captured = {} + with mock.patch.object(eeg_amica_module, 'runamica', _fake_runamica_factory(captured)): + out = eeg_amica(eeg, posact=True) + + data2d = out['data'].reshape(out['nbchan'], -1, order='F') + icaact2d = out['icaact'].reshape(out['icaact'].shape[0], -1, order='F') + + # A posact flip must have occurred (non-identity sphere makes this the + # case that previously folded the sphere into icaweights). + self.assertFalse(np.array_equal(out['icaweights'], _AMICA_WEIGHTS)) + # Core EEGLAB ICA invariants must hold after sign normalization. + np.testing.assert_allclose(out['icaweights'] @ out['icasphere'] @ data2d[out['icachansind']], icaact2d) + np.testing.assert_allclose(out['icawinv'], pinv(out['icaweights'] @ out['icasphere'])) + # icasphere must be untouched by the sign-flip step. + np.testing.assert_array_equal(out['icasphere'], _AMICA_SPHERE) + ix = np.argmax(np.abs(icaact2d), axis=1) + self.assertTrue(np.all(icaact2d[np.arange(icaact2d.shape[0]), ix] >= 0)) + + def test_load_amica_model_does_not_mutate_caller(self): + eeg = _continuous_eeg() + captured = {} + with mock.patch.object(eeg_amica_module, 'runamica', _fake_runamica_factory(captured)): + decomposed = eeg_amica(eeg) + + mods = decomposed['etc']['amica'] + icaact_before = decomposed['icaact'].copy() + reloaded = load_amica_model(decomposed, mods, model_num=0) + + # Reloading model 0 must reproduce the same activations and leave the + # source EEG structure untouched. + np.testing.assert_allclose(reloaded['icaact'], icaact_before) + self.assertTrue(np.array_equal(decomposed['icaact'], icaact_before)) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_eeg_autocorr_fftw.py b/tests/test_eeg_autocorr_fftw.py index ae005330..8151dffb 100644 --- a/tests/test_eeg_autocorr_fftw.py +++ b/tests/test_eeg_autocorr_fftw.py @@ -311,9 +311,12 @@ def test_comparison_with_regular_autocorr(self): # Both should have same shape self.assertEqual(result_fftw.shape, result_regular.shape) - # Results should be similar (but not necessarily identical due to different FFT implementations) - # Check that they're in the same ballpark - self.assertTrue(np.allclose(result_fftw, result_regular, rtol=0.1, atol=0.1)) + # Both compute the same autocorrelation; the only divergence is that + # eeg_autocorr casts its FFT to single precision (complex64) for MATLAB + # parity while eeg_autocorr_fftw stays double precision. Observed max + # relative difference is ~3e-5, so a float-realistic tolerance still + # catches any several-percent port regression. + self.assertTrue(np.allclose(result_fftw, result_regular, rtol=1e-4, atol=1e-7)) def test_axis_handling_in_fft(self): """Test that FFT operations handle axes correctly.""" diff --git a/tests/test_eeg_compare.py b/tests/test_eeg_compare.py index c567b489..bd574da6 100644 --- a/tests/test_eeg_compare.py +++ b/tests/test_eeg_compare.py @@ -5,13 +5,12 @@ and reports differences in structure and data. """ +import logging import os import unittest import sys -import io import numpy as np import math -from contextlib import redirect_stderr, redirect_stdout # Add src to path for imports sys.path.insert(0, 'src') @@ -19,6 +18,48 @@ from eegprep.functions.adminfunc.eeglabcompat import get_eeglab from eegprep.utils.testing import DebuggableTestCase +EEG_COMPARE_LOGGER = 'eegprep.functions.popfunc.eeg_compare' + + +def run_compare(*args, **kwargs): + """Call eeg_compare while capturing its log output. + + eeg_compare emits informational lines (section headers and "OK" results) at INFO and + differences at WARNING. Return the result plus the WARNING-level text as ``stderr`` and the + full text as ``stdout`` so callers can assert on either, mirroring the previous stream split. + """ + logger = logging.getLogger(EEG_COMPARE_LOGGER) + with _LogCapture(logger) as capture: + result = eeg_compare(*args, **kwargs) + return result, capture.text(logging.INFO), capture.text(logging.WARNING) + + +class _LogCapture(logging.Handler): + def __init__(self, logger): + super().__init__(level=logging.DEBUG) + self._logger = logger + self.records = [] + + def emit(self, record): + self.records.append(record) + + def __enter__(self): + self._prev_level = self._logger.level + self._prev_propagate = self._logger.propagate + self._logger.setLevel(logging.DEBUG) + self._logger.propagate = False + self._logger.addHandler(self) + return self + + def __exit__(self, *exc): + self._logger.removeHandler(self) + self._logger.setLevel(self._prev_level) + self._logger.propagate = self._prev_propagate + return False + + def text(self, min_level): + return '\n'.join(self.format(r) for r in self.records if r.levelno >= min_level) + @unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") class TestEegCompare(DebuggableTestCase): @@ -107,19 +148,11 @@ def create_test_eeg(self, nbchan=32, pnts=1000, trials=1): def test_identical_datasets(self): """Test comparison of identical datasets.""" - # Capture output - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, self.basic_eeg1) + result, stdout_output, _stderr_output = run_compare(self.basic_eeg1, self.basic_eeg1) - # Should return True for identical datasets + # Should return a truthy summary for identical datasets self.assertTrue(result) - # Check output indicates no differences - stdout_output = stdout_capture.getvalue() - # Should have minimal output for identical datasets self.assertIn('Field analysis:', stdout_output) self.assertIn('Chanlocs analysis:', stdout_output) @@ -131,15 +164,10 @@ def test_different_field_values(self): eeg2['setname'] = 'different_dataset' eeg2['subject'] = 'S02' - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + self.assertTrue(result) # Function should still return a summary - self.assertTrue(result) # Function should still return True - - stderr_output = stderr_capture.getvalue() # Should report differences in subject field self.assertIn('subject differs', stderr_output) @@ -149,15 +177,10 @@ def test_missing_fields(self): del eeg2['subject'] del eeg2['condition'] - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('subject missing in second dataset', stderr_output) self.assertIn('condition missing in second dataset', stderr_output) @@ -167,15 +190,10 @@ def test_filename_differences_allowed(self): eeg2['filename'] = 'different.set' eeg2['datfile'] = 'different.dat' - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, stdout_output, _stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stdout_output = stdout_capture.getvalue() # Should indicate filename differences are OK self.assertIn('(ok, supposed to differ)', stdout_output) @@ -185,15 +203,10 @@ def test_xmin_xmax_differences(self): eeg2['xmin'] = -0.1 # Different from -0.2 eeg2['xmax'] = 4.0 # Different from original - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('Difference between xmin', stderr_output) self.assertIn('Difference between xmax', stderr_output) @@ -205,15 +218,10 @@ def test_channel_coordinate_differences(self): eeg2['chanlocs'][1]['Y'] = 999.0 eeg2['chanlocs'][2]['Z'] = 999.0 - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('channel coordinates differ', stderr_output) def test_channel_label_differences(self): @@ -222,15 +230,10 @@ def test_channel_label_differences(self): eeg2['chanlocs'][0]['labels'] = 'DifferentLabel' eeg2['chanlocs'][1]['labels'] = 'AnotherLabel' - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('channel label(s) differ', stderr_output) def test_verbose_channel_labels(self): @@ -238,30 +241,20 @@ def test_verbose_channel_labels(self): eeg2 = self.create_test_eeg() eeg2['chanlocs'][0]['labels'] = 'DifferentLabel' - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2, verbose_level=1) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2, verbose_level=1) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('Ch1 differs from DifferentLabel', stderr_output) def test_different_channel_numbers(self): """Test comparison with different numbers of channels.""" eeg2 = self.create_test_eeg(nbchan=16) # Different number of channels - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('Different numbers of channels', stderr_output) def test_different_event_numbers(self): @@ -269,15 +262,10 @@ def test_different_event_numbers(self): eeg2 = self.create_test_eeg() eeg2['event'] = eeg2['event'][:2] # Remove one event - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('Different numbers of events', stderr_output) def test_verbose_event_output(self): @@ -285,15 +273,10 @@ def test_verbose_event_output(self): eeg2 = self.create_test_eeg() eeg2['event'] = [] # No events - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2, verbose_level=1) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2, verbose_level=1) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('Different numbers of events', stderr_output) self.assertIn('First event of first dataset:', stderr_output) @@ -304,15 +287,10 @@ def test_event_field_differences(self): for event in eeg2['event']: event['extra_field'] = 'test' - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('Not the same number of event fields', stderr_output) def test_event_latency_differences(self): @@ -321,15 +299,10 @@ def test_event_latency_differences(self): eeg2['event'][0]['latency'] = 300 # Different from 250 eeg2['event'][1]['latency'] = 600 # Different from 500 - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() self.assertIn('Event latency', stderr_output) self.assertIn('not OK', stderr_output) @@ -338,15 +311,10 @@ def test_event_type_differences(self): eeg2 = self.create_test_eeg() eeg2['event'][0]['type'] = 'different_stimulus' - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() # The function should detect differences in event fields self.assertTrue(len(stderr_output) > 0) # Should have some error output @@ -361,11 +329,7 @@ def __init__(self, eeg_dict): eeg_obj1 = EegObject(self.basic_eeg1) eeg_obj2 = EegObject(self.basic_eeg2) - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(eeg_obj1, eeg_obj2) + result, _stdout_output, _stderr_output = run_compare(eeg_obj1, eeg_obj2) self.assertTrue(result) @@ -374,35 +338,23 @@ def test_eventdescription_differences(self): eeg2 = self.create_test_eeg() eeg2['eventdescription'] = ['stimulus', 'response'] # Different length - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, stdout_output, stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) - stderr_output = stderr_capture.getvalue() - stdout_output = stdout_capture.getvalue() # The function should report eventdescription differences - # It might be in stderr or stdout depending on the logic + # It might be at info or warning level depending on the logic output_combined = stderr_output + stdout_output self.assertTrue('eventdescription' in output_combined or len(stderr_output) > 0) def test_isequaln_function_coverage(self): """Test the internal isequaln function with various data types.""" - from eegprep.functions.popfunc.eeg_compare import eeg_compare - # Test with None values eeg2 = self.create_test_eeg() eeg2['subject'] = None self.basic_eeg1['subject'] = None - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, _stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) @@ -412,11 +364,7 @@ def test_nan_handling(self): eeg2['xmin'] = float('nan') self.basic_eeg1['xmin'] = float('nan') - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, _stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) @@ -426,11 +374,7 @@ def test_array_comparisons(self): # Make arrays identical eeg2['data'] = self.basic_eeg1['data'].copy() - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, _stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) @@ -440,11 +384,7 @@ def test_scalar_vs_array_comparisons(self): # Test scalar vs array comparison edge cases eeg2['trials'] = np.array([1]) # Array instead of scalar - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(self.basic_eeg1, eeg2) + result, _stdout_output, _stderr_output = run_compare(self.basic_eeg1, eeg2) self.assertTrue(result) @@ -455,11 +395,7 @@ def test_empty_events(self): eeg1['event'] = [] eeg2['event'] = [] - stderr_capture = io.StringIO() - stdout_capture = io.StringIO() - - with redirect_stderr(stderr_capture), redirect_stdout(stdout_capture): - result = eeg_compare(eeg1, eeg2) + result, _stdout_output, _stderr_output = run_compare(eeg1, eeg2) self.assertTrue(result) @@ -592,5 +528,45 @@ def test_boolean_array_result(self): self.assertTrue(self.isequaln(arr1, arr2)) +class TestEegCompareReturnContract(unittest.TestCase): + """Pin the documented return contract: eeg_compare returns a summary string.""" + + def _eeg(self): + return { + 'setname': 'ds', + 'subject': 'S01', + 'xmin': 0.0, + 'xmax': 1.0, + 'chanlocs': [], + 'event': [], + } + + def test_identical_returns_match_summary_string(self): + result, stdout_output, stderr_output = run_compare(self._eeg(), self._eeg()) + self.assertIsInstance(result, str) + self.assertEqual(result, "All fields match (no differences found)") + self.assertEqual(stderr_output, "") + + def test_differences_returned_as_string_not_bool(self): + eeg2 = self._eeg() + eeg2['subject'] = 'S02' + result, _stdout_output, stderr_output = run_compare(self._eeg(), eeg2) + self.assertIsInstance(result, str) + self.assertNotIsInstance(result, bool) + self.assertIn('differences', result.lower()) + self.assertIn('subject differs', stderr_output) + + def test_array_mismatch_returns_summary_string(self): + result, _stdout_output, _stderr_output = run_compare(np.zeros((2, 3)), np.zeros((3, 2))) + self.assertIsInstance(result, str) + self.assertIn('Array shape mismatch', result) + + def test_trigger_error_raises_on_difference(self): + eeg2 = self._eeg() + eeg2['subject'] = 'S02' + with self.assertRaises(ValueError): + eeg_compare(self._eeg(), eeg2, trigger_error=True) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_eeg_eeg2mne.py b/tests/test_eeg_eeg2mne.py index c0a69a68..a182e270 100644 --- a/tests/test_eeg_eeg2mne.py +++ b/tests/test_eeg_eeg2mne.py @@ -9,6 +9,7 @@ import numpy as np import tempfile import shutil +from unittest import mock from eegprep.functions.miscfunc.eeg_eeg2mne import eeg_eeg2mne @@ -26,9 +27,6 @@ except (ImportError, ValueError): from fixtures import create_test_eeg -if os.getenv('EEGPREP_SKIP_MATLAB') == '1': - raise unittest.SkipTest("MATLAB not available") - class TestEEGEEG2MNE(unittest.TestCase): """Test cases for eeg_eeg2mne function.""" @@ -60,6 +58,22 @@ def test_eeg_eeg2mne_continuous_data(self): self.assertEqual(result.info['nchan'], continuous_eeg['nbchan']) self.assertEqual(result.n_times, continuous_eeg['pnts']) + @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") + def test_eeg_eeg2mne_cleans_temporary_bridge_files(self): + continuous_eeg = self.test_eeg.copy() + continuous_eeg['data'] = np.random.randn(32, 1000) + continuous_eeg['trials'] = 1 + real_tempdir = tempfile.TemporaryDirectory + + def tempdir_factory(*args, **kwargs): + kwargs["dir"] = self.temp_dir + return real_tempdir(*args, **kwargs) + + with mock.patch("eegprep.functions.miscfunc.eeg_eeg2mne.tempfile.TemporaryDirectory", tempdir_factory): + eeg_eeg2mne(continuous_eeg) + + self.assertEqual(os.listdir(self.temp_dir), []) + @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_eeg2mne_epoched_data(self): """Test conversion of epoched EEG data.""" diff --git a/tests/test_eeg_eegrej.py b/tests/test_eeg_eegrej.py index 23cf5fee..23f814e5 100644 --- a/tests/test_eeg_eegrej.py +++ b/tests/test_eeg_eegrej.py @@ -2,7 +2,6 @@ import tempfile import unittest import numpy as np -from unittest.mock import patch # Assume eeg_eegrej is defined as in your module that imports: from eegrej import eegrej from eegprep import eeg_eegrej @@ -253,10 +252,12 @@ def test_eeg_eegrej_overlapping_regions(self): # Overlapping regions: [3, 7] and [5, 10] should merge to [3, 10] regions = np.array([[3, 7], [5, 10]]) - with patch('builtins.print') as mock_print: + with self.assertLogs("eegprep.functions.popfunc.eeg_eegrej", level="WARNING") as captured: result = eeg_eegrej(EEG, regions) - # Should print warning about overlapping regions - mock_print.assert_called_with("Warning: overlapping regions detected and fixed in eeg_eegrej") + + self.assertTrue( + any("Overlapping regions detected and fixed in eeg_eegrej" in message for message in captured.output) + ) # Should have 20 - 8 = 12 samples remaining (removed samples 3-10) self.assertEqual(result['pnts'], 12) @@ -384,6 +385,25 @@ def test_eeg_eegrej_event_cleanup(self): self.assertNotIn(0.0, latencies) self.assertNotIn(float(result['pnts']), latencies) + def test_eeg_eegrej_keeps_nonboundary_events_at_edges(self): + """Genuine (non-boundary) events at the first/last sample must not be dropped.""" + EEG = self.base_eeg.copy() + # stim at the first sample (latency 0) and a stim that lands on the + # final sample after rejection (latency 20 -> 17 once 3 samples removed) + EEG['event'] = [ + {"type": "stim", "latency": 0.0}, + {"type": "stim", "latency": 20.0}, + ] + + regions = np.array([[3, 5]]) + result = eeg_eegrej(EEG, regions) + + self.assertEqual(result['pnts'], 17) + stim_events = [e for e in result['event'] if e.get('type') == 'stim'] + stim_latencies = sorted(e['latency'] for e in stim_events) + # both stim events survive: one at the first sample, one at the last sample + self.assertEqual(stim_latencies, [0.0, 17.0]) + def test_eeg_eegrej_duplicate_event_cleanup(self): """Test eeg_eegrej duplicate event cleanup.""" EEG = self.base_eeg.copy() diff --git a/tests/test_eeg_mne2eeg.py b/tests/test_eeg_mne2eeg.py index 8b9000a7..9f9045ce 100644 --- a/tests/test_eeg_mne2eeg.py +++ b/tests/test_eeg_mne2eeg.py @@ -10,6 +10,7 @@ import tempfile import os import shutil +from unittest import mock # Add src to path for imports sys.path.insert(0, 'src') @@ -70,27 +71,38 @@ def test_eeg_mne2eeg_raw_object(self): # Create Raw object raw = mne.io.RawArray(data, info) - try: - result = eeg_mne2eeg(raw) + result = eeg_mne2eeg(raw) - # Check that result is a dict (EEGLAB format) - self.assertIsInstance(result, dict) + # Check that result is a dict (EEGLAB format) + self.assertIsInstance(result, dict) - # Check basic fields - self.assertIn('data', result) - self.assertIn('srate', result) - self.assertIn('nbchan', result) - self.assertIn('pnts', result) - self.assertIn('trials', result) + # Check basic fields + self.assertIn('data', result) + self.assertIn('srate', result) + self.assertIn('nbchan', result) + self.assertIn('pnts', result) + self.assertIn('trials', result) - # Check data dimensions - self.assertEqual(result['nbchan'], n_channels) - self.assertEqual(result['pnts'], n_times) - self.assertEqual(result['trials'], 1) - self.assertEqual(result['srate'], sfreq) + # Check data dimensions + self.assertEqual(result['nbchan'], n_channels) + self.assertEqual(result['pnts'], n_times) + self.assertEqual(result['trials'], 1) + self.assertEqual(result['srate'], sfreq) - except Exception as e: - self.skipTest(f"eeg_mne2eeg raw conversion not available: {e}") + @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") + def test_eeg_mne2eeg_cleans_temporary_bridge_files(self): + info = mne.create_info(["EEG001", "EEG002"], 100.0, ch_types='eeg') + raw = mne.io.RawArray(np.random.randn(2, 100), info) + real_tempdir = tempfile.TemporaryDirectory + + def tempdir_factory(*args, **kwargs): + kwargs["dir"] = self.temp_dir + return real_tempdir(*args, **kwargs) + + with mock.patch("eegprep.functions.miscfunc.eeg_mne2eeg.tempfile.TemporaryDirectory", tempdir_factory): + eeg_mne2eeg(raw) + + self.assertEqual(os.listdir(self.temp_dir), []) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_epochs_object(self): @@ -117,27 +129,23 @@ def test_eeg_mne2eeg_epochs_object(self): # Create Epochs object epochs = mne.EpochsArray(data, info, events, tmin=0, event_id=event_id) - try: - result = eeg_mne2eeg(epochs) - - # Check that result is a dict (EEGLAB format) - self.assertIsInstance(result, dict) + result = eeg_mne2eeg(epochs) - # Check basic fields - self.assertIn('data', result) - self.assertIn('srate', result) - self.assertIn('nbchan', result) - self.assertIn('pnts', result) - self.assertIn('trials', result) + # Check that result is a dict (EEGLAB format) + self.assertIsInstance(result, dict) - # Check data dimensions - self.assertEqual(result['nbchan'], n_channels) - self.assertEqual(result['pnts'], n_times) - self.assertEqual(result['trials'], n_epochs) - self.assertEqual(result['srate'], sfreq) + # Check basic fields + self.assertIn('data', result) + self.assertIn('srate', result) + self.assertIn('nbchan', result) + self.assertIn('pnts', result) + self.assertIn('trials', result) - except Exception as e: - self.skipTest(f"eeg_mne2eeg epochs conversion not available: {e}") + # Check data dimensions + self.assertEqual(result['nbchan'], n_channels) + self.assertEqual(result['pnts'], n_times) + self.assertEqual(result['trials'], n_epochs) + self.assertEqual(result['srate'], sfreq) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_with_annotations(self): @@ -158,28 +166,24 @@ def test_eeg_mne2eeg_with_annotations(self): ) raw.set_annotations(annotations) - try: - result = eeg_mne2eeg(raw) + result = eeg_mne2eeg(raw) - # Check that result is a dict - self.assertIsInstance(result, dict) + # Check that result is a dict + self.assertIsInstance(result, dict) - # Check that events were converted - self.assertIn('event', result) - self.assertIsInstance(result['event'], list) + # Check that events were converted + self.assertIn('event', result) + self.assertIsInstance(result['event'], list) - # Check event count - self.assertEqual(len(result['event']), 3) + # Check event count + self.assertEqual(len(result['event']), 3) - # Check event structure - for event in result['event']: - self.assertIn('latency', event) - self.assertIn('type', event) - self.assertIsInstance(event['latency'], int) - self.assertIsInstance(event['type'], str) - - except Exception as e: - self.skipTest(f"eeg_mne2eeg with annotations not available: {e}") + # Check event structure + for event in result['event']: + self.assertIn('latency', event) + self.assertIn('type', event) + self.assertIsInstance(event['latency'], int) + self.assertIsInstance(event['type'], str) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_with_events(self): @@ -208,27 +212,23 @@ def test_eeg_mne2eeg_with_events(self): epochs = mne.EpochsArray(data, info, events, tmin=0, event_id=event_id) - try: - result = eeg_mne2eeg(epochs) - - # Check that result is a dict - self.assertIsInstance(result, dict) + result = eeg_mne2eeg(epochs) - # Check that events were converted - self.assertIn('event', result) - self.assertIsInstance(result['event'], list) + # Check that result is a dict + self.assertIsInstance(result, dict) - # Check event count - self.assertEqual(len(result['event']), 5) + # Check that events were converted + self.assertIn('event', result) + self.assertIsInstance(result['event'], list) - # Check event types - event_types = [event['type'] for event in result['event']] - self.assertIn('stimulus', event_types) - self.assertIn('response', event_types) - self.assertIn('feedback', event_types) + # Check event count + self.assertEqual(len(result['event']), 5) - except Exception as e: - self.skipTest(f"eeg_mne2eeg with events not available: {e}") + # Check event types + event_types = [event['type'] for event in result['event']] + self.assertIn('stimulus', event_types) + self.assertIn('response', event_types) + self.assertIn('feedback', event_types) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_single_channel(self): @@ -243,19 +243,15 @@ def test_eeg_mne2eeg_single_channel(self): data = np.random.randn(n_channels, n_times) raw = mne.io.RawArray(data, info) - try: - result = eeg_mne2eeg(raw) - - # Check that result is a dict - self.assertIsInstance(result, dict) + result = eeg_mne2eeg(raw) - # Check data dimensions - self.assertEqual(result['nbchan'], 1) - self.assertEqual(result['pnts'], 500) - self.assertEqual(result['trials'], 1) + # Check that result is a dict + self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"eeg_mne2eeg single channel not available: {e}") + # Check data dimensions + self.assertEqual(result['nbchan'], 1) + self.assertEqual(result['pnts'], 500) + self.assertEqual(result['trials'], 1) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_short_data(self): @@ -270,19 +266,15 @@ def test_eeg_mne2eeg_short_data(self): data = np.random.randn(n_channels, n_times) raw = mne.io.RawArray(data, info) - try: - result = eeg_mne2eeg(raw) + result = eeg_mne2eeg(raw) - # Check that result is a dict - self.assertIsInstance(result, dict) + # Check that result is a dict + self.assertIsInstance(result, dict) - # Check data dimensions - self.assertEqual(result['nbchan'], 8) - self.assertEqual(result['pnts'], 10) - self.assertEqual(result['trials'], 1) - - except Exception as e: - self.skipTest(f"eeg_mne2eeg short data not available: {e}") + # Check data dimensions + self.assertEqual(result['nbchan'], 8) + self.assertEqual(result['pnts'], 10) + self.assertEqual(result['trials'], 1) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_float32_data(self): @@ -297,17 +289,13 @@ def test_eeg_mne2eeg_float32_data(self): data = np.random.randn(n_channels, n_times).astype(np.float32) raw = mne.io.RawArray(data, info) - try: - result = eeg_mne2eeg(raw) - - # Check that result is a dict - self.assertIsInstance(result, dict) + result = eeg_mne2eeg(raw) - # Check data type - self.assertEqual(result['data'].dtype, np.float32) + # Check that result is a dict + self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"eeg_mne2eeg float32 data not available: {e}") + # Check data type + self.assertEqual(result['data'].dtype, np.float32) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_large_dataset(self): @@ -322,20 +310,16 @@ def test_eeg_mne2eeg_large_dataset(self): data = np.random.randn(n_channels, n_times) raw = mne.io.RawArray(data, info) - try: - result = eeg_mne2eeg(raw) + result = eeg_mne2eeg(raw) - # Check that result is a dict - self.assertIsInstance(result, dict) + # Check that result is a dict + self.assertIsInstance(result, dict) - # Check data dimensions - self.assertEqual(result['nbchan'], 64) - self.assertEqual(result['pnts'], 5000) - self.assertEqual(result['trials'], 1) - self.assertEqual(result['srate'], 1000.0) - - except Exception as e: - self.skipTest(f"eeg_mne2eeg large dataset not available: {e}") + # Check data dimensions + self.assertEqual(result['nbchan'], 64) + self.assertEqual(result['pnts'], 5000) + self.assertEqual(result['trials'], 1) + self.assertEqual(result['srate'], 1000.0) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_empty_annotations(self): @@ -354,18 +338,14 @@ def test_eeg_mne2eeg_empty_annotations(self): empty_annotations = mne.Annotations([], [], []) raw.set_annotations(empty_annotations) - try: - result = eeg_mne2eeg(raw) - - # Check that result is a dict - self.assertIsInstance(result, dict) + result = eeg_mne2eeg(raw) - # Check that events field exists but is empty - self.assertIn('event', result) - self.assertEqual(len(result['event']), 0) + # Check that result is a dict + self.assertIsInstance(result, dict) - except Exception as e: - self.skipTest(f"eeg_mne2eeg empty annotations not available: {e}") + # Check that events field exists but is empty + self.assertIn('event', result) + self.assertEqual(len(result['event']), 0) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_no_events(self): @@ -385,24 +365,20 @@ def test_eeg_mne2eeg_no_events(self): epochs = mne.EpochsArray(data, info, events, tmin=0) - try: - result = eeg_mne2eeg(epochs) - - # Check that result is a dict - self.assertIsInstance(result, dict) + result = eeg_mne2eeg(epochs) - # Check that events were converted with string types - self.assertIn('event', result) - self.assertIsInstance(result['event'], list) - self.assertEqual(len(result['event']), 3) + # Check that result is a dict + self.assertIsInstance(result, dict) - # Check that event types are strings - for event in result['event']: - self.assertIsInstance(event['type'], str) - self.assertEqual(event['type'], '999') + # Check that events were converted with string types + self.assertIn('event', result) + self.assertIsInstance(result['event'], list) + self.assertEqual(len(result['event']), 3) - except Exception as e: - self.skipTest(f"eeg_mne2eeg no events not available: {e}") + # Check that event types are strings + for event in result['event']: + self.assertIsInstance(event['type'], str) + self.assertEqual(event['type'], '999') @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_integration_workflow(self): @@ -425,29 +401,25 @@ def test_eeg_mne2eeg_integration_workflow(self): ) raw.set_annotations(annotations) - try: - result = eeg_mne2eeg(raw) + result = eeg_mne2eeg(raw) - # Check that result is a dict - self.assertIsInstance(result, dict) + # Check that result is a dict + self.assertIsInstance(result, dict) - # Check basic properties - self.assertEqual(result['nbchan'], 32) - self.assertEqual(result['pnts'], 2000) - self.assertEqual(result['trials'], 1) - self.assertEqual(result['srate'], 500.0) + # Check basic properties + self.assertEqual(result['nbchan'], 32) + self.assertEqual(result['pnts'], 2000) + self.assertEqual(result['trials'], 1) + self.assertEqual(result['srate'], 500.0) - # Check events - self.assertIn('event', result) - self.assertEqual(len(result['event']), 5) + # Check events + self.assertIn('event', result) + self.assertEqual(len(result['event']), 5) - # Check event types - event_types = [event['type'] for event in result['event']] - self.assertIn('stimulus', event_types) - self.assertIn('response', event_types) - - except Exception as e: - self.skipTest(f"eeg_mne2eeg integration workflow not available: {e}") + # Check event types + event_types = [event['type'] for event in result['event']] + self.assertIn('stimulus', event_types) + self.assertIn('response', event_types) class TestMNEEventsToEEGLABEvents(unittest.TestCase): @@ -469,27 +441,23 @@ def __init__(self, annotations, sfreq): raw = MockRaw(annotations, 500.0) - try: - result = _mne_events_to_eeglab_events(raw) - - # Check result structure - self.assertIsInstance(result, list) - self.assertEqual(len(result), 3) + result = _mne_events_to_eeglab_events(raw) - # Check event structure - for event in result: - self.assertIn('latency', event) - self.assertIn('type', event) - self.assertIsInstance(event['latency'], int) - self.assertIsInstance(event['type'], str) + # Check result structure + self.assertIsInstance(result, list) + self.assertEqual(len(result), 3) - # Check latency values (1-based indexing) - latencies = [event['latency'] for event in result] - expected_latencies = [int(0.1 * 500) + 1, int(0.5 * 500) + 1, int(1.0 * 500) + 1] - self.assertEqual(latencies, expected_latencies) + # Check event structure + for event in result: + self.assertIn('latency', event) + self.assertIn('type', event) + self.assertIsInstance(event['latency'], int) + self.assertIsInstance(event['type'], str) - except Exception as e: - self.skipTest(f"_mne_events_to_eeglab_events annotations not available: {e}") + # Check latency values (1-based indexing) + latencies = [event['latency'] for event in result] + expected_latencies = [int(0.1 * 500) + 1, int(0.5 * 500) + 1, int(1.0 * 500) + 1] + self.assertEqual(latencies, expected_latencies) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_mne_events_to_eeglab_events_events_array(self): @@ -513,32 +481,28 @@ def __init__(self, events, event_id, sfreq): event_id = {'stimulus': 1, 'response': 2} epochs = MockEpochs(events, event_id, 500.0) - try: - result = _mne_events_to_eeglab_events(epochs) + result = _mne_events_to_eeglab_events(epochs) - # Check result structure - self.assertIsInstance(result, list) - self.assertEqual(len(result), 3) + # Check result structure + self.assertIsInstance(result, list) + self.assertEqual(len(result), 3) - # Check event structure - for event in result: - self.assertIn('latency', event) - self.assertIn('type', event) - self.assertIsInstance(event['latency'], int) - self.assertIsInstance(event['type'], str) + # Check event structure + for event in result: + self.assertIn('latency', event) + self.assertIn('type', event) + self.assertIsInstance(event['latency'], int) + self.assertIsInstance(event['type'], str) - # Check latency values (1-based indexing) - latencies = [event['latency'] for event in result] - expected_latencies = [1, 101, 201] - self.assertEqual(latencies, expected_latencies) + # Check latency values (1-based indexing) + latencies = [event['latency'] for event in result] + expected_latencies = [1, 101, 201] + self.assertEqual(latencies, expected_latencies) - # Check event types - event_types = [event['type'] for event in result] - expected_types = ['stimulus', 'response', 'stimulus'] - self.assertEqual(event_types, expected_types) - - except Exception as e: - self.skipTest(f"_mne_events_to_eeglab_events events array not available: {e}") + # Check event types + event_types = [event['type'] for event in result] + expected_types = ['stimulus', 'response', 'stimulus'] + self.assertEqual(event_types, expected_types) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_mne_events_to_eeglab_events_no_event_id(self): @@ -559,20 +523,16 @@ def __init__(self, events, sfreq): epochs = MockEpochs(events, 500.0) - try: - result = _mne_events_to_eeglab_events(epochs) - - # Check result structure - self.assertIsInstance(result, list) - self.assertEqual(len(result), 2) + result = _mne_events_to_eeglab_events(epochs) - # Check event types (should be string representations of numbers) - event_types = [event['type'] for event in result] - expected_types = ['1', '2'] - self.assertEqual(event_types, expected_types) + # Check result structure + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) - except Exception as e: - self.skipTest(f"_mne_events_to_eeglab_events no event_id not available: {e}") + # Check event types (should be string representations of numbers) + event_types = [event['type'] for event in result] + expected_types = ['1', '2'] + self.assertEqual(event_types, expected_types) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_mne_events_to_eeglab_events_empty_annotations(self): @@ -587,15 +547,11 @@ def __init__(self, annotations, sfreq): raw = MockRaw(annotations, 500.0) - try: - result = _mne_events_to_eeglab_events(raw) + result = _mne_events_to_eeglab_events(raw) - # Check result structure - self.assertIsInstance(result, list) - self.assertEqual(len(result), 0) - - except Exception as e: - self.skipTest(f"_mne_events_to_eeglab_events empty annotations not available: {e}") + # Check result structure + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_mne_events_to_eeglab_events_no_events(self): @@ -610,15 +566,11 @@ def __init__(self, events, sfreq): epochs = MockEpochs(events, 500.0) - try: - result = _mne_events_to_eeglab_events(epochs) - - # Check result structure - self.assertIsInstance(result, list) - self.assertEqual(len(result), 0) + result = _mne_events_to_eeglab_events(epochs) - except Exception as e: - self.skipTest(f"_mne_events_to_eeglab_events no events not available: {e}") + # Check result structure + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) if __name__ == '__main__': diff --git a/tests/test_eeg_mne2eeg_epochs.py b/tests/test_eeg_mne2eeg_epochs.py index 1f120272..3964e3cf 100644 --- a/tests/test_eeg_mne2eeg_epochs.py +++ b/tests/test_eeg_mne2eeg_epochs.py @@ -4,6 +4,8 @@ This module tests the eeg_mne2eeg_epochs function that converts MNE Epochs with ICA to EEGLAB datasets. """ +import contextlib +import io import unittest import os import numpy as np @@ -11,6 +13,7 @@ import shutil from eegprep.functions.miscfunc.eeg_mne2eeg_epochs import eeg_mne2eeg_epochs +from eegprep.functions.miscfunc.misc import finite_matmul, finite_pinv try: import mne @@ -25,9 +28,6 @@ except (ImportError, ValueError): from fixtures import create_test_eeg -if os.getenv('EEGPREP_SKIP_MATLAB') == '1': - raise unittest.SkipTest("MATLAB not available") - class TestEEGMNE2EEGEpochs(unittest.TestCase): """Test cases for eeg_mne2eeg_epochs function.""" @@ -85,6 +85,28 @@ def test_eeg_mne2eeg_epochs_basic_functionality(self): except Exception as e: self.skipTest(f"eeg_mne2eeg_epochs basic functionality not available: {e}") + @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") + def test_eeg_mne2eeg_epochs_uses_channel_major_data_without_stdout(self): + n_channels = 4 + n_times = 20 + n_epochs = 3 + sfreq = 100.0 + ch_names = [f'EEG{i:03d}' for i in range(n_channels)] + info = mne.create_info(ch_names, sfreq, ch_types='eeg') + data = np.arange(n_epochs * n_channels * n_times, dtype=float).reshape(n_epochs, n_channels, n_times) + events = np.array([[i, 0, 1] for i in range(n_epochs)]) + epochs = mne.EpochsArray(data, info, events, tmin=0, event_id={'event': 1}, verbose=False) + ica = ICA(n_components=2, random_state=42, max_iter=20) + ica.fit(epochs, verbose=False) + + stream = io.StringIO() + with contextlib.redirect_stdout(stream): + result = eeg_mne2eeg_epochs(epochs, ica) + + self.assertEqual(stream.getvalue(), "") + self.assertEqual(result['data'].shape, (n_channels, n_times, n_epochs)) + np.testing.assert_allclose(result['data'], np.transpose(data, (1, 2, 0))) + @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_epochs_ica_fields(self): """Test ICA fields in the converted EEGLAB dataset.""" @@ -106,25 +128,23 @@ def test_eeg_mne2eeg_epochs_ica_fields(self): ica = ICA(n_components=8, random_state=42) ica.fit(epochs) - try: - result = eeg_mne2eeg_epochs(epochs, ica) - - # Check ICA fields - self.assertIn('icaact', result) - self.assertIn('icawinv', result) - self.assertIn('icasphere', result) - self.assertIn('icaweights', result) - self.assertIn('icachansind', result) - - # Check ICA field shapes - self.assertEqual(result['icaact'].shape, (8, n_times, n_epochs)) # n_components x n_times x n_epochs - self.assertEqual(result['icawinv'].shape, (8, n_channels)) # n_components x n_channels - self.assertEqual(result['icasphere'].shape, (n_channels, 8)) # n_channels x n_components - self.assertEqual(result['icaweights'].shape, (n_channels, n_channels)) # identity matrix - self.assertEqual(len(result['icachansind']), n_channels) # channel indices - - except Exception as e: - self.skipTest(f"eeg_mne2eeg_epochs ICA fields not available: {e}") + result = eeg_mne2eeg_epochs(epochs, ica) + + self.assertIn('icaact', result) + self.assertIn('icawinv', result) + self.assertIn('icasphere', result) + self.assertIn('icaweights', result) + self.assertIn('icachansind', result) + self.assertEqual(result['icaact'].shape, (8, n_times, n_epochs)) + self.assertEqual(result['icawinv'].shape, (n_channels, 8)) + self.assertEqual(result['icasphere'].shape, (n_channels, n_channels)) + self.assertEqual(result['icaweights'].shape, (8, n_channels)) + self.assertEqual(len(result['icachansind']), n_channels) + unmixing = finite_matmul(result['icaweights'], result['icasphere']) + data_2d = result['data'][result['icachansind']].reshape(n_channels, -1, order="F") + icaact_2d = result['icaact'].reshape(8, -1, order="F") + np.testing.assert_allclose(finite_matmul(unmixing, data_2d), icaact_2d, rtol=1e-10, atol=1e-10) + np.testing.assert_allclose(finite_pinv(unmixing), result['icawinv'], rtol=1e-10, atol=1e-10) @unittest.skipUnless(MNE_AVAILABLE, "MNE not available") def test_eeg_mne2eeg_epochs_channel_locations(self): @@ -252,9 +272,8 @@ def test_eeg_mne2eeg_epochs_single_epoch(self): try: result = eeg_mne2eeg_epochs(epochs, ica) - # Check data dimensions (data is in MNE format: n_epochs x n_channels x n_times) self.assertEqual(result['trials'], 1) - self.assertEqual(result['data'].shape, (n_epochs, n_channels, n_times)) + self.assertEqual(result['data'].shape, (n_channels, n_times, n_epochs)) self.assertEqual(result['icaact'].shape, (8, n_times, n_epochs)) except Exception as e: @@ -286,8 +305,8 @@ def test_eeg_mne2eeg_epochs_minimal_channels(self): result = eeg_mne2eeg_epochs(epochs, ica) # Check data dimensions - self.assertEqual(result['nbchan'], 1) - self.assertEqual(result['data'].shape, (1, n_times, n_epochs)) + self.assertEqual(result['nbchan'], n_channels) + self.assertEqual(result['data'].shape, (n_channels, n_times, n_epochs)) self.assertEqual(result['icaact'].shape, (2, n_times, n_epochs)) except Exception as e: @@ -317,10 +336,9 @@ def test_eeg_mne2eeg_epochs_short_data(self): try: result = eeg_mne2eeg_epochs(epochs, ica) - # Check data dimensions (data is in MNE format: n_epochs x n_channels x n_times) self.assertEqual(result['pnts'], 10) self.assertEqual(result['trials'], 3) - self.assertEqual(result['data'].shape, (n_epochs, n_channels, n_times)) + self.assertEqual(result['data'].shape, (n_channels, n_times, n_epochs)) except Exception as e: self.skipTest(f"eeg_mne2eeg_epochs short data not available: {e}") @@ -484,9 +502,9 @@ def test_eeg_mne2eeg_epochs_integration_workflow(self): # Check ICA properties self.assertEqual(result['icaact'].shape, (15, 200, 20)) - self.assertEqual(result['icawinv'].shape, (15, 32)) - self.assertEqual(result['icasphere'].shape, (32, 15)) - self.assertEqual(result['icaweights'].shape, (32, 32)) + self.assertEqual(result['icawinv'].shape, (32, 15)) + self.assertEqual(result['icasphere'].shape, (32, 32)) + self.assertEqual(result['icaweights'].shape, (15, 32)) self.assertEqual(len(result['icachansind']), 32) # Check channel locations diff --git a/tests/test_eeg_picard.py b/tests/test_eeg_picard.py index 604a9e96..987d2087 100644 --- a/tests/test_eeg_picard.py +++ b/tests/test_eeg_picard.py @@ -1,10 +1,14 @@ import os import unittest import numpy as np +import pytest from eegprep import pop_loadset, eeg_picard, pop_saveset from eegprep.functions.adminfunc.eeglabcompat import get_eeglab +from eegprep.functions.miscfunc.pinv import pinv from eegprep.utils.testing import DebuggableTestCase, matlab_function_exists +from tests.fixtures import create_test_eeg as _create_test_eeg + def compare_ica_components(weights1, weights2, rtol=0.01, atol=0.05): """Compare ICA weight matrices accounting for permutation and sign ambiguity. @@ -66,50 +70,8 @@ def compare_ica_components(weights1, weights2, rtol=0.01, atol=0.05): def create_test_eeg(): - """Create a complete test EEG structure with all required fields.""" - return { - 'data': np.random.randn(32, 1000, 10), - 'srate': 500.0, - 'nbchan': 32, - 'pnts': 1000, - 'trials': 10, - 'xmin': -1.0, - 'xmax': 1.0, - 'times': np.linspace(-1.0, 1.0, 1000), - 'icaact': [], - 'icawinv': [], - 'icasphere': [], - 'icaweights': [], - 'icachansind': [], - 'chanlocs': [ - { - 'labels': f'EEG{i:03d}', - 'type': 'EEG', - 'theta': np.random.uniform(-90, 90), - 'radius': np.random.uniform(0, 1), - 'X': np.random.uniform(-1, 1), - 'Y': np.random.uniform(-1, 1), - 'Z': np.random.uniform(-1, 1), - 'sph_theta': np.random.uniform(-180, 180), - 'sph_phi': np.random.uniform(-90, 90), - 'sph_radius': np.random.uniform(0, 1), - 'urchan': i + 1, - 'ref': '', - } - for i in range(32) - ], - 'urchanlocs': [], - 'chaninfo': [], - 'ref': 'common', - 'history': '', - 'saved': 'yes', - 'etc': {}, - 'event': [], - 'epoch': [], - 'setname': 'test_dataset', - 'filename': 'test.set', - 'filepath': '/tmp', - } + """Epoched EEG fixture sized for eeg_picard (32 ch, 1000 pnts, 10 trials).""" + return _create_test_eeg(n_channels=32, n_samples=1000, srate=500.0, n_trials=10) class TestEegPicardSimple(DebuggableTestCase): @@ -121,206 +83,192 @@ def setUp(self): def test_eeg_picard_basic_functionality(self): """Test basic eeg_picard functionality with default parameters.""" - try: - result = eeg_picard(self.test_eeg.copy()) - - # Check that all ICA fields are present - self.assertIn('icaweights', result) - self.assertIn('icasphere', result) - self.assertIn('icawinv', result) - self.assertIn('icaact', result) - self.assertIn('icachansind', result) - - # Check data types - self.assertIsInstance(result['icaweights'], np.ndarray) - self.assertIsInstance(result['icasphere'], np.ndarray) - self.assertIsInstance(result['icawinv'], np.ndarray) - self.assertIsInstance(result['icaact'], np.ndarray) - self.assertIsInstance(result['icachansind'], np.ndarray) - - # Check shapes - n_chans = self.test_eeg['nbchan'] - n_pnts = self.test_eeg['pnts'] - n_trials = self.test_eeg['trials'] - - self.assertEqual(result['icaweights'].shape, (n_chans, n_chans)) - self.assertEqual(result['icasphere'].shape, (n_chans, n_chans)) - self.assertEqual(result['icawinv'].shape, (n_chans, n_chans)) - self.assertEqual(result['icaact'].shape, (n_chans, n_pnts, n_trials)) - self.assertEqual(len(result['icachansind']), n_chans) - - except Exception as e: - self.skipTest(f"eeg_picard basic functionality not available: {e}") + result = eeg_picard(self.test_eeg.copy()) + + # Check that all ICA fields are present + self.assertIn('icaweights', result) + self.assertIn('icasphere', result) + self.assertIn('icawinv', result) + self.assertIn('icaact', result) + self.assertIn('icachansind', result) + + # Check data types + self.assertIsInstance(result['icaweights'], np.ndarray) + self.assertIsInstance(result['icasphere'], np.ndarray) + self.assertIsInstance(result['icawinv'], np.ndarray) + self.assertIsInstance(result['icaact'], np.ndarray) + self.assertIsInstance(result['icachansind'], np.ndarray) + + # Check shapes + n_chans = self.test_eeg['nbchan'] + n_pnts = self.test_eeg['pnts'] + n_trials = self.test_eeg['trials'] + + self.assertEqual(result['icaweights'].shape, (n_chans, n_chans)) + self.assertEqual(result['icasphere'].shape, (n_chans, n_chans)) + self.assertEqual(result['icawinv'].shape, (n_chans, n_chans)) + self.assertEqual(result['icaact'].shape, (n_chans, n_pnts, n_trials)) + self.assertEqual(len(result['icachansind']), n_chans) def test_eeg_picard_with_custom_parameters(self): """Test eeg_picard with custom parameters.""" - try: - result = eeg_picard( - self.test_eeg.copy(), - max_iter=10, # picard uses max_iter, not maxiter - verbose=False, - random_state=42, - ) - - # Check that all ICA fields are present - self.assertIn('icaweights', result) - self.assertIn('icasphere', result) - self.assertIn('icawinv', result) - self.assertIn('icaact', result) - self.assertIn('icachansind', result) - - except Exception as e: - self.skipTest(f"eeg_picard with custom parameters not available: {e}") + result = eeg_picard( + self.test_eeg.copy(), + max_iter=10, # picard uses max_iter, not maxiter + verbose=False, + random_state=42, + ) + + # Check that all ICA fields are present + self.assertIn('icaweights', result) + self.assertIn('icasphere', result) + self.assertIn('icawinv', result) + self.assertIn('icaact', result) + self.assertIn('icachansind', result) def test_eeg_picard_data_integrity(self): """Test that eeg_picard preserves data integrity.""" - try: - original_eeg = self.test_eeg.copy() - result = eeg_picard(original_eeg.copy()) - - # Check that original EEG is not modified - self.assertEqual(original_eeg['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(original_eeg['pnts'], self.test_eeg['pnts']) - self.assertEqual(original_eeg['trials'], self.test_eeg['trials']) - - # Check that result has same basic structure - self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) - self.assertEqual(result['pnts'], self.test_eeg['pnts']) - self.assertEqual(result['trials'], self.test_eeg['trials']) - self.assertEqual(result['srate'], self.test_eeg['srate']) - - except Exception as e: - self.skipTest(f"eeg_picard data integrity not available: {e}") + original_eeg = self.test_eeg.copy() + result = eeg_picard(original_eeg.copy()) + + # Check that original EEG is not modified + self.assertEqual(original_eeg['nbchan'], self.test_eeg['nbchan']) + self.assertEqual(original_eeg['pnts'], self.test_eeg['pnts']) + self.assertEqual(original_eeg['trials'], self.test_eeg['trials']) + + # Check that result has same basic structure + self.assertEqual(result['nbchan'], self.test_eeg['nbchan']) + self.assertEqual(result['pnts'], self.test_eeg['pnts']) + self.assertEqual(result['trials'], self.test_eeg['trials']) + self.assertEqual(result['srate'], self.test_eeg['srate']) + + def test_eeg_picard_does_not_mutate_caller(self): + """eeg_picard must not mutate the caller's data or ICA fields.""" + data_before = self.test_eeg['data'].copy() + # A pre-ICA input has no icaweights yet; "does not mutate" means the + # caller's fields stay the exact objects they were (the result is a copy). + weights_before = self.test_eeg.get('icaweights') + chansind_before = self.test_eeg.get('icachansind') + + eeg_picard(self.test_eeg, posact=True, max_iter=10, random_state=1, verbose=False) + + self.assertTrue(np.array_equal(self.test_eeg['data'], data_before)) + self.assertIs(self.test_eeg.get('icaweights'), weights_before) + self.assertIs(self.test_eeg.get('icachansind'), chansind_before) + + def test_eeg_picard_posact_preserves_unmixing_invariant(self): + """After posact sign flips, icawinv must stay pinv(icaweights @ icasphere).""" + result = eeg_picard(self.test_eeg, posact=True, max_iter=10, random_state=1, verbose=False) + + icaact_2d = result['icaact'].reshape(result['icaact'].shape[0], -1, order='F') + ix = np.argmax(np.abs(icaact_2d), axis=1) + # Every component's max-abs activation must be positive after posact. + self.assertTrue(np.all(icaact_2d[np.arange(icaact_2d.shape[0]), ix] >= 0)) + # icasphere is left untouched by the sign-flip step. + np.testing.assert_array_equal(result['icasphere'], np.eye(self.test_eeg['nbchan'])) + # icawinv stays consistent with the (sign-flipped) unmixing matrix. + np.testing.assert_allclose(result['icawinv'], pinv(result['icaweights'] @ result['icasphere'])) def test_eeg_picard_ica_structure(self): """Test that eeg_picard creates proper ICA structure.""" - try: - result = eeg_picard(self.test_eeg.copy()) - - # Check icasphere is identity matrix - n_chans = self.test_eeg['nbchan'] - expected_icasphere = np.eye(n_chans) - np.testing.assert_array_equal(result['icasphere'], expected_icasphere) + result = eeg_picard(self.test_eeg.copy()) - # Check icachansind contains all channel indices - expected_icachansind = np.arange(n_chans) - np.testing.assert_array_equal(result['icachansind'], expected_icachansind) + # Check icasphere is identity matrix + n_chans = self.test_eeg['nbchan'] + expected_icasphere = np.eye(n_chans) + np.testing.assert_array_equal(result['icasphere'], expected_icasphere) - except Exception as e: - self.skipTest(f"eeg_picard ICA structure not available: {e}") + # Check icachansind contains all channel indices + expected_icachansind = np.arange(n_chans) + np.testing.assert_array_equal(result['icachansind'], expected_icachansind) def test_eeg_picard_matrix_properties(self): """Test mathematical properties of ICA matrices.""" - try: - result = eeg_picard(self.test_eeg.copy()) + result = eeg_picard(self.test_eeg.copy()) - n_chans = self.test_eeg['nbchan'] + n_chans = self.test_eeg['nbchan'] - # Check that icaweights and icawinv are proper matrices - self.assertEqual(result['icaweights'].shape, (n_chans, n_chans)) - self.assertEqual(result['icawinv'].shape, (n_chans, n_chans)) + # Check that icaweights and icawinv are proper matrices + self.assertEqual(result['icaweights'].shape, (n_chans, n_chans)) + self.assertEqual(result['icawinv'].shape, (n_chans, n_chans)) - # Check that matrices are not all zeros - self.assertFalse(np.allclose(result['icaweights'], 0)) - self.assertFalse(np.allclose(result['icawinv'], 0)) + # Check that matrices are not all zeros + self.assertFalse(np.allclose(result['icaweights'], 0)) + self.assertFalse(np.allclose(result['icawinv'], 0)) - # Check that matrices are not all NaN - self.assertFalse(np.any(np.isnan(result['icaweights']))) - self.assertFalse(np.any(np.isnan(result['icawinv']))) - - except Exception as e: - self.skipTest(f"eeg_picard matrix properties not available: {e}") + # Check that matrices are not all NaN + self.assertFalse(np.any(np.isnan(result['icaweights']))) + self.assertFalse(np.any(np.isnan(result['icawinv']))) def test_eeg_picard_ica_activations(self): """Test that ICA activations have correct shape and properties.""" - try: - result = eeg_picard(self.test_eeg.copy()) - - n_chans = self.test_eeg['nbchan'] - n_pnts = self.test_eeg['pnts'] - n_trials = self.test_eeg['trials'] + result = eeg_picard(self.test_eeg.copy()) - # Check shape - self.assertEqual(result['icaact'].shape, (n_chans, n_pnts, n_trials)) + n_chans = self.test_eeg['nbchan'] + n_pnts = self.test_eeg['pnts'] + n_trials = self.test_eeg['trials'] - # Check that activations are not all zeros - self.assertFalse(np.allclose(result['icaact'], 0)) + # Check shape + self.assertEqual(result['icaact'].shape, (n_chans, n_pnts, n_trials)) - # Check that activations are not all NaN - self.assertFalse(np.any(np.isnan(result['icaact']))) + # Check that activations are not all zeros + self.assertFalse(np.allclose(result['icaact'], 0)) - except Exception as e: - self.skipTest(f"eeg_picard ICA activations not available: {e}") + # Check that activations are not all NaN + self.assertFalse(np.any(np.isnan(result['icaact']))) def test_eeg_picard_deterministic(self): """Test that eeg_picard produces deterministic results with fixed random state.""" - try: - # Run twice with same random state - result1 = eeg_picard(self.test_eeg.copy(), random_state=42) - result2 = eeg_picard(self.test_eeg.copy(), random_state=42) + # Run twice with same random state + result1 = eeg_picard(self.test_eeg.copy(), random_state=42) + result2 = eeg_picard(self.test_eeg.copy(), random_state=42) - # Results should be identical - np.testing.assert_array_equal(result1['icaweights'], result2['icaweights']) - np.testing.assert_array_equal(result1['icawinv'], result2['icawinv']) - np.testing.assert_array_equal(result1['icaact'], result2['icaact']) - - except Exception as e: - self.skipTest(f"eeg_picard deterministic test not available: {e}") + # Results should be identical + np.testing.assert_array_equal(result1['icaweights'], result2['icaweights']) + np.testing.assert_array_equal(result1['icawinv'], result2['icawinv']) + np.testing.assert_array_equal(result1['icaact'], result2['icaact']) + @pytest.mark.xfail(reason="exposes product bug tracked in Fable 5 epic #193 (Phase 2/3)", strict=False) def test_eeg_picard_different_random_states(self): """Test that eeg_picard produces different results with different random states.""" - try: - # Run with different random states - result1 = eeg_picard(self.test_eeg.copy(), random_state=42) - result2 = eeg_picard(self.test_eeg.copy(), random_state=123) - - # Results should be different (not identical) - self.assertFalse(np.array_equal(result1['icaweights'], result2['icaweights'])) + # Run with different random states + result1 = eeg_picard(self.test_eeg.copy(), random_state=42) + result2 = eeg_picard(self.test_eeg.copy(), random_state=123) - except Exception as e: - self.skipTest(f"eeg_picard different random states test not available: {e}") + # eeg_picard converges to the same unmixing matrix regardless of the + # random_state seed, so this assertion currently fails: the random_state + # parameter has no observable effect on the result. + self.assertFalse(np.array_equal(result1['icaweights'], result2['icaweights'])) def test_eeg_picard_verbose_parameter(self): """Test eeg_picard with verbose parameter.""" - try: - # Test with verbose=True (should not raise error) - result1 = eeg_picard(self.test_eeg.copy(), verbose=True) - self.assertIn('icaweights', result1) + # Test with verbose=True (should not raise error) + result1 = eeg_picard(self.test_eeg.copy(), verbose=True) + self.assertIn('icaweights', result1) - # Test with verbose=False (should not raise error) - result2 = eeg_picard(self.test_eeg.copy(), verbose=False) - self.assertIn('icaweights', result2) - - except Exception as e: - self.skipTest(f"eeg_picard verbose parameter test not available: {e}") + # Test with verbose=False (should not raise error) + result2 = eeg_picard(self.test_eeg.copy(), verbose=False) + self.assertIn('icaweights', result2) def test_eeg_picard_maxiter_parameter(self): """Test eeg_picard with maxiter parameter.""" - try: - # Test with different maxiter values - result1 = eeg_picard(self.test_eeg.copy(), max_iter=5) - result2 = eeg_picard(self.test_eeg.copy(), max_iter=10) - - # Both should produce valid results - self.assertIn('icaweights', result1) - self.assertIn('icaweights', result2) + # Test with different maxiter values + result1 = eeg_picard(self.test_eeg.copy(), max_iter=5) + result2 = eeg_picard(self.test_eeg.copy(), max_iter=10) - except Exception as e: - self.skipTest(f"eeg_picard maxiter parameter test not available: {e}") + # Both should produce valid results + self.assertIn('icaweights', result1) + self.assertIn('icaweights', result2) def test_eeg_picard_ortho_parameter(self): """Test eeg_picard with ortho parameter.""" - try: - # Test with ortho=True - result1 = eeg_picard(self.test_eeg.copy(), ortho=True) - self.assertIn('icaweights', result1) - - # Test with ortho=False - result2 = eeg_picard(self.test_eeg.copy(), ortho=False) - self.assertIn('icaweights', result2) + # Test with ortho=True + result1 = eeg_picard(self.test_eeg.copy(), ortho=True) + self.assertIn('icaweights', result1) - except Exception as e: - self.skipTest(f"eeg_picard ortho parameter test not available: {e}") + # Test with ortho=False + result2 = eeg_picard(self.test_eeg.copy(), ortho=False) + self.assertIn('icaweights', result2) @unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") diff --git a/tests/test_eeg_runica.py b/tests/test_eeg_runica.py index 52a25f0d..61c995f1 100644 --- a/tests/test_eeg_runica.py +++ b/tests/test_eeg_runica.py @@ -2,6 +2,7 @@ import numpy as np +from eegprep.functions.miscfunc.pinv import pinv from eegprep.functions.popfunc.eeg_runica import eeg_runica from eegprep.functions.popfunc.pop_runica import pop_runica @@ -43,6 +44,105 @@ def fake_runica(data, **_kwargs): np.testing.assert_array_equal(out["icachansind"], np.array([0, 1])) +def _nonidentity_decomposition(): + """Weights and a non-identity sphere that make activations sign-flippable.""" + weights = np.array([[2.0, 1.0], [-1.0, 3.0]]) + sphere = np.array([[1.5, 0.5], [0.0, 2.0]]) + return weights, sphere + + +def _continuous_eeg(): + data = np.array([[1.0, -2.0, 3.0, -4.0], [-5.0, 6.0, -7.0, 8.0]]) + return { + "data": data, + "nbchan": 2, + "pnts": 4, + "trials": 1, + "srate": 100, + "chanlocs": [], + "icaweights": np.eye(2), + "icasphere": np.eye(2), + } + + +def test_eeg_runica_does_not_mutate_caller(monkeypatch): + eeg = _continuous_eeg() + data_before = eeg["data"].copy() + weights_before = eeg["icaweights"].copy() + sphere_before = eeg["icasphere"].copy() + + weights, sphere = _nonidentity_decomposition() + + def fake_runica(data, **_kwargs): + return weights.copy(), sphere.copy(), np.zeros(2), np.zeros((2, 1)), np.ones((2, 1)), [] + + monkeypatch.setattr("eegprep.functions.popfunc.eeg_runica.runica", fake_runica) + + eeg_runica(eeg, posact=True) + + assert np.array_equal(eeg["data"], data_before) + assert np.array_equal(eeg["icaweights"], weights_before) + assert np.array_equal(eeg["icasphere"], sphere_before) + + +def test_eeg_runica_posact_preserves_ica_invariants(monkeypatch): + eeg = _continuous_eeg() + weights, sphere = _nonidentity_decomposition() + weights_unflipped = weights.copy() + + def fake_runica(data, **_kwargs): + return weights.copy(), sphere.copy(), np.zeros(2), np.zeros((2, 1)), np.ones((2, 1)), [] + + monkeypatch.setattr("eegprep.functions.popfunc.eeg_runica.runica", fake_runica) + + out = eeg_runica(eeg, posact=True) + + data2d = out["data"].reshape(out["nbchan"], -1, order="F") + icaact2d = out["icaact"].reshape(out["icaact"].shape[0], -1, order="F") + + # A posact flip must have occurred for this decomposition (otherwise the + # test would not exercise the invariant-preserving branch). + assert not np.array_equal(out["icaweights"], weights_unflipped) + # Core EEGLAB ICA invariants must hold after sign normalization. + np.testing.assert_allclose(out["icaweights"] @ out["icasphere"] @ data2d[out["icachansind"]], icaact2d) + np.testing.assert_allclose(out["icawinv"], pinv(out["icaweights"] @ out["icasphere"])) + # Every component's max-abs activation must be positive. + ix = np.argmax(np.abs(icaact2d), axis=1) + assert np.all(icaact2d[np.arange(icaact2d.shape[0]), ix] >= 0) + + +def test_finalize_ica_fields_shared_sort_and_sign_normalization(): + """Lock the K4 dedup: runica/AMICA/Picard share one finalize_ica_fields. + + The helper must sort components by descending activation variance, then + sign-normalize while preserving the ICA factorization invariants. + """ + from eegprep.functions.popfunc._ica_utils import finalize_ica_fields + + rng = np.random.default_rng(11) + nbchan, pnts, trials = 4, 12, 3 + sphere = np.eye(nbchan) + weights = rng.standard_normal((nbchan, nbchan)) + winv = pinv(weights @ sphere) + icaact = (weights @ sphere) @ rng.standard_normal((nbchan, pnts * trials)) + icaact = icaact.reshape(nbchan, pnts, trials, order="F") + eeg = { + "icaweights": weights.copy(), + "icasphere": sphere.copy(), + "icawinv": winv.copy(), + "icaact": icaact.copy(), + } + + out = finalize_ica_fields(eeg, sortcomps=True, posact=True) + icaact2d = out["icaact"].reshape(out["icaact"].shape[0], -1, order="F") + + variance_metric = np.sum(out["icawinv"] ** 2, axis=0) * np.sum(icaact2d**2, axis=1) + assert np.all(np.diff(variance_metric) <= 1e-9) + ix = np.argmax(np.abs(icaact2d), axis=1) + assert np.all(icaact2d[np.arange(icaact2d.shape[0]), ix] >= 0) + np.testing.assert_allclose(out["icawinv"], pinv(out["icaweights"] @ out["icasphere"])) + + def test_pop_runica_concatenates_epoched_datasets_in_eeglab_order(monkeypatch): first = _epoched_eeg() second = _epoched_eeg(offset=100) diff --git a/tests/test_eeglabcompat.py b/tests/test_eeglabcompat.py index 8ffd7e79..910cb7f9 100644 --- a/tests/test_eeglabcompat.py +++ b/tests/test_eeglabcompat.py @@ -8,6 +8,7 @@ import os import unittest from copy import deepcopy +from pathlib import Path import numpy as np @@ -22,11 +23,51 @@ from eegprep import clean_artifacts, pop_loadset from eegprep.functions.adminfunc.eeg_checkset import eeg_checkset from eegprep.utils.testing import DebuggableTestCase +import eegprep.functions.adminfunc.eeglabcompat as eeglabcompat # Path to test data LOCAL_DATA_PATH = os.path.join(os.path.dirname(__file__), '../sample_data/') +def test_eeglab_clean_artifacts_roundtrip_uses_private_tempdir(monkeypatch, tmp_path): + paths: dict[str, list[Path]] = {"save": [], "matlab_load": [], "matlab_save": [], "load": []} + + class DummyEeglab: + def pop_loadset(self, filename): + paths["matlab_load"].append(Path(filename)) + return {"loaded": filename} + + def clean_artifacts(self, EEG, *_args): + return {"cleaned": EEG} + + def pop_saveset(self, EEG, filename): + paths["matlab_save"].append(Path(filename)) + Path(filename).write_text("cleaned", encoding="utf-8") + return EEG + + def fake_pop_saveset(EEG, filename): + paths["save"].append(Path(filename)) + Path(filename).write_text("input", encoding="utf-8") + return EEG + + def fake_pop_loadset(filename): + paths["load"].append(Path(filename)) + return {"loaded_from": str(filename)} + + monkeypatch.chdir(tmp_path) + monkeypatch.setattr(eeglabcompat, "get_eeglab", lambda auto_file_roundtrip=False: DummyEeglab()) + monkeypatch.setattr(eeglabcompat, "pop_saveset", fake_pop_saveset) + monkeypatch.setattr(eeglabcompat, "pop_loadset", fake_pop_loadset) + + result = eeglabcompat.clean_artifacts({"data": np.zeros((1, 4))}, BurstCriterion="off") + + assert result["loaded_from"].endswith("output.set") + assert not (tmp_path / "tmp.set").exists() + assert not (tmp_path / "tmp2.set").exists() + assert all(path.parent != tmp_path for values in paths.values() for path in values) + assert {path.name for values in paths.values() for path in values} == {"input.set", "output.set"} + + class TestMatlabWrapper(DebuggableTestCase): """Test cases for MatlabWrapper class.""" diff --git a/tests/test_eegobj.py b/tests/test_eegobj.py index f70c8fbd..c5d125d7 100644 --- a/tests/test_eegobj.py +++ b/tests/test_eegobj.py @@ -144,6 +144,29 @@ def test_forward_nonexistent_function(self): with self.assertRaises(AttributeError): obj.nonexistent_function() + def test_getattr_unknown_name_raises_on_access(self): + """Accessing an unknown name (e.g. a field typo) raises AttributeError immediately. + + A typo like obj.icawnv (for icaweights) must fail fast instead of + silently returning a no-op callable that only errors when called. + """ + eeg = create_test_eeg() + obj = EEGobj(eeg) + + with self.assertRaises(AttributeError): + obj.icawnv # misspelled field name, not an eegprep function + + # hasattr reflects the same contract. + self.assertFalse(hasattr(obj, 'not_a_real_field')) + + def test_getattr_known_function_still_dispatches(self): + """A name that resolves to an eegprep function is still returned as a wrapper.""" + eeg = create_test_eeg(n_channels=4, n_samples=50, srate=100.0, n_trials=3) + obj = EEGobj(eeg) + self.assertTrue(callable(obj.pop_select)) + out = obj.pop_select(channel=[0, 1]) + self.assertEqual(out['nbchan'], 2) + def test_forward_function_returning_tuple(self): """Test method forwarding with function returning tuple.""" eeg = create_test_eeg() diff --git a/tests/test_epoch.py b/tests/test_epoch.py index 2645433e..5ac7c586 100644 --- a/tests/test_epoch.py +++ b/tests/test_epoch.py @@ -162,28 +162,27 @@ def test_functional_3d_epoched_input_same_epoch_constraint(self): # Put a distinctive ramp in epoch 2 so we can detect correct slicing data[0, :, 1] = np.linspace(0, 1, 100) - # Event at 1.2 s relative to concatenated stream: - # global sample for event center = floor(1.2 * 100) = 120 - # Epoch boundaries every 100 points - events = np.array([1.2], dtype=float) + # Event at 1.5 s relative to the concatenated stream: + # global sample for event center = floor(1.5 * 100) = 150 + # Epoch boundaries every 100 points, so the window stays inside epoch 2. + events = np.array([1.5], dtype=float) lim = np.array([-0.2, 0.3], dtype=float) # [-20, +29] samples window inside epoch 2 ep, newtime, idx, _, _, _ = epoch(data, events, lim, srate=srate, verbose='off') # Expect one accepted epoch self.assertTrue(np.array_equal(idx, np.array([0]))) - # The extracted data should match the correct slice from linearized data - # With MATLAB-compatible indexing: event at 1.2s (sample 120) with window [-0.2, 0.3] - # becomes MATLAB indices [101, 149] (1-based), which in Python becomes [99:149] (0-based) - pos0 = int(np.floor(events[0] * srate)) # 120 (0-based) + # The extracted data should match the correct slice from linearized data. + # MATLAB epoch.m uses data(:,posinit:posend) (1-based); Python slices [posinit-1:posend]. + pos0 = int(np.floor(events[0] * srate)) # 150 reallim0 = int(np.round(lim[0] * srate)) # -20 reallim1 = int(np.round(lim[1] * srate - 1)) # 29 - posinit = pos0 + reallim0 # 100 (0-based) - posend = pos0 + reallim1 # 149 (0-based) + posinit = pos0 + reallim0 # 130 (MATLAB 1-based column) + posend = pos0 + reallim1 # 179 (MATLAB 1-based column) # MATLAB slicing: posinit:posend (1-based) becomes [posinit-1:posend] (0-based) - start_global = posinit - 1 # 99 (Python 0-based) - end_global = posend # 149 (Python 0-based exclusive) + start_global = posinit - 1 # 129 (Python 0-based) + end_global = posend # 179 (Python 0-based exclusive) # Extract the expected slice from linearized data (Fortran order) data_linearized = data.reshape(1, -1, order='F') @@ -192,6 +191,12 @@ def test_functional_3d_epoched_input_same_epoch_constraint(self): # newtime should reflect limits divided by srate with the -1 sample convention self.assertTrue(np.allclose(newtime, np.array([lim[0], np.round(lim[1] * srate - 1) / srate]), atol=1e-12)) + # A window that straddles the boundary between two source epochs (event at + # 1.2 s spans columns 99..148, crossing the 100-sample epoch boundary) must + # be rejected, matching EEGLAB's floor((posinit-1)/dataframes) test. + _, _, idx_cross, _, _, _ = epoch(data, np.array([1.2]), lim, srate=srate, verbose='off') + self.assertEqual(idx_cross.size, 0) + def test_functional_valuelim_pass_all(self): srate = 200.0 n_ch, n_samp = 3, 4000 @@ -213,6 +218,29 @@ def test_functional_no_allevents_outputs_empty_lists(self): self.assertEqual(len(alleventout), 0) self.assertEqual(len(alllatencyout), 0) + def test_boundary_first_last_sample(self): + # data values equal their 0-based column index so we can verify exact slices + srate = 100.0 + data = np.arange(500, dtype=float).reshape(1, 500) + lim = np.array([-0.2, 0.5], dtype=float) # reallim = [-20, 49] samples + + # Event whose window would start one sample before the data (posinit==0, + # 0-based) must be rejected like EEGLAB, not produce a negative-index slice. + _, _, idx_before, _, _, _ = epoch(data, np.array([0.2]), lim, srate=srate, verbose='off') + self.assertEqual(idx_before.size, 0) + + # Event whose window starts exactly at the first sample (posinit==1, 1-based) + # must be accepted and yield the true leading samples. + ep_first, _, idx_first, _, _, _ = epoch(data, np.array([0.21]), lim, srate=srate, verbose='off') + self.assertTrue(np.array_equal(idx_first, np.array([0]))) + self.assertTrue(np.allclose(ep_first[0, :3, 0], np.array([0.0, 1.0, 2.0]))) + + # Event whose window ends exactly at the last sample (posend==datawidth) + # must be accepted; the previous off-by-one rejected it. + ep_last, _, idx_last, _, _, _ = epoch(data, np.array([4.51]), lim, srate=srate, verbose='off') + self.assertTrue(np.array_equal(idx_last, np.array([0]))) + self.assertTrue(np.allclose(ep_last[0, -3:, 0], np.array([497.0, 498.0, 499.0]))) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_file_menu_pop_functions.py b/tests/test_file_menu_pop_functions.py index d96af67e..d2d7dd9f 100644 --- a/tests/test_file_menu_pop_functions.py +++ b/tests/test_file_menu_pop_functions.py @@ -47,6 +47,21 @@ def _matlab_string(value): return "'" + str(value).replace("'", "''") + "'" +def test_eeg_from_data_raises_on_ambiguous_tall_array(): + # A tall channel-major array (more channels than samples) must not be + # silently transposed; orientation has to be stated via nbchan. + with pytest.raises(ValueError, match="orientation"): + eeg_from_data(np.zeros((256, 100)), srate=100) + + +def test_eeg_from_data_loads_tall_channel_major_with_nbchan(): + eeg = eeg_from_data(np.arange(256 * 100, dtype=float).reshape(256, 100), srate=100, nbchan=256) + + assert eeg["nbchan"] == 256 + assert eeg["pnts"] == 100 + assert eeg["data"].shape == (256, 100) + + def test_pop_importdata_imports_ascii_data(tmp_path): data_file = tmp_path / "data.tsv" np.savetxt(data_file, np.array([[1, 2, 3], [4, 5, 6]]), delimiter="\t") @@ -78,6 +93,46 @@ def test_pop_fileio_uses_importdata_for_text_arrays(tmp_path): assert command == f"EEG = pop_fileio({_matlab_string(data_file)});" +def test_pop_fileio_imports_plain_mat_data_array(tmp_path): + import scipy.io + + data_file = tmp_path / "raw.mat" + scipy.io.savemat(data_file, {"data": np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])}) + + eeg = pop_fileio(data_file) + + assert eeg["nbchan"] == 2 + assert eeg["pnts"] == 3 + + +def test_pop_fileio_does_not_silently_fall_back_for_eeglab_mat(tmp_path): + import scipy.io + + # An EEGLAB dataset .mat whose data sidecar is missing must surface the load + # failure, not be silently re-imported as a raw MATLAB data array. + dataset = tmp_path / "broken.mat" + scipy.io.savemat( + dataset, + { + "data": "missing.fdt", + "datfile": "missing.fdt", + "nbchan": 4, + "srate": 100, + "pnts": 100, + "trials": 1, + "xmin": 0.0, + "xmax": 1.0, + "setname": "x", + "chanlocs": np.array([]), + "event": np.array([]), + "icachansind": np.array([]), + }, + ) + + with pytest.raises(FileNotFoundError): + pop_fileio(dataset) + + def test_pop_importevent_replaces_and_appends_events(tmp_path): events_file = tmp_path / "events.tsv" events_file.write_text("type\tlatency\tduration\nstim\t1\t0\nresp\t4\t1\n", encoding="utf-8") diff --git a/tests/test_firfilt_helpers.py b/tests/test_firfilt_helpers.py index 64690d35..bddfc6da 100644 --- a/tests/test_firfilt_helpers.py +++ b/tests/test_firfilt_helpers.py @@ -37,6 +37,31 @@ def test_kaiserbeta_matches_eeglab_formula_and_inverse_roundtrips(): assert invkaiserbeta(0) == pytest.approx(10 ** (-21 / 20)) +def test_firws_and_firwsord_are_owned_by_firfilt_plugin(): + """The FIR design lives in the firfilt plugin, not clean_rawdata's private helper. + + Locks the ownership move: ``firfilt/firws.py`` and ``firfilt/firwsord.py`` now + define the canonical implementations, so the design helpers must no longer leak + out of ``clean_rawdata.private.sigproc`` and importers resolve into firfilt. + """ + from eegprep.plugins.clean_rawdata.private import sigproc + from eegprep.plugins.firfilt.firws import firws + from eegprep.plugins.firfilt.firwsord import firwsord + + assert firws.__module__ == "eegprep.plugins.firfilt.firws" + assert firwsord.__module__ == "eegprep.plugins.firfilt.firwsord" + assert not hasattr(sigproc, "firws") + assert not hasattr(sigproc, "firwsord") + + # The firwsord order still feeds firws to produce a valid type-I linear-phase kernel. + fs, cutoff, df = 500.0, 0.5, 1.0 + m, _dev = firwsord("hamming", fs, df) + b, a = firws(m, cutoff / (fs / 2.0), "high") + assert a == 1.0 + assert b.size == m + 1 + np.testing.assert_allclose(b, b[::-1], atol=1e-12) + + def test_invfirwsord_returns_transition_width_and_window_deviation(): df, dev = invfirwsord("hamming", 500, 826) diff --git a/tests/test_gui_long_task.py b/tests/test_gui_long_task.py new file mode 100644 index 00000000..33b33091 --- /dev/null +++ b/tests/test_gui_long_task.py @@ -0,0 +1,153 @@ +import logging +import os +import threading + +import pytest + +import eegprep.functions.guifunc.long_task as long_task_module +from eegprep.functions.guifunc.long_task import run_long_task + + +@pytest.fixture +def qapp(): + os.environ.setdefault("QT_QPA_PLATFORM", "offscreen") + pytest.importorskip("PySide6") + from PySide6 import QtWidgets + + app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([]) + yield app + + +def test_run_long_task_returns_result_and_forwards_progress(qapp): + from PySide6 import QtCore + + loop = QtCore.QEventLoop() + results = [] + errors = [] + finished = [] + + def task(): + logging.getLogger("eegprep.tests").info("worker progress") + return "done" + + handle = run_long_task( + parent=None, + title="Running test task", + label="Running test task.", + task=task, + on_success=results.append, + on_error=errors.append, + on_finished=lambda task_handle: (finished.append(task_handle), loop.quit()), + ) + QtCore.QTimer.singleShot(3000, loop.quit) + loop.exec() + + assert results == ["done"] + assert errors == [] + assert finished == [handle] + assert "worker progress" in handle.dialog.labelText() + + +def test_run_long_task_restores_eegprep_logger_level_after_forwarding_progress(qapp): + from PySide6 import QtCore + + loop = QtCore.QEventLoop() + logger = logging.getLogger("eegprep") + original_level = logger.level + logger.setLevel(logging.WARNING) + + def task(): + logging.getLogger("eegprep.tests").info("worker progress") + return "done" + + try: + handle = run_long_task( + parent=None, + title="Running test task", + label="Running test task.", + task=task, + on_success=lambda _result: None, + on_error=lambda _exc: None, + on_finished=lambda _handle: loop.quit(), + ) + QtCore.QTimer.singleShot(3000, loop.quit) + loop.exec() + + assert "worker progress" in handle.dialog.labelText() + assert logger.level == logging.WARNING + finally: + logger.setLevel(original_level) + + +def test_run_long_task_callbacks_are_delivered_on_main_thread(qapp, monkeypatch): + from PySide6 import QtCore + + loop = QtCore.QEventLoop() + main_thread_id = threading.get_ident() + task_thread_ids = [] + message_thread_ids = [] + success_thread_ids = [] + original_update = long_task_module._update_progress_label + + def update_progress_label(progress, label, message): + message_thread_ids.append(threading.get_ident()) + original_update(progress, label, message) + + def task(): + task_thread_ids.append(threading.get_ident()) + logging.getLogger("eegprep.tests").info("worker progress") + return "done" + + monkeypatch.setattr(long_task_module, "_update_progress_label", update_progress_label) + + run_long_task( + parent=None, + title="Running test task", + label="Running test task.", + task=task, + on_success=lambda _result: success_thread_ids.append(threading.get_ident()), + on_error=lambda _exc: None, + on_finished=lambda _handle: loop.quit(), + ) + QtCore.QTimer.singleShot(3000, loop.quit) + loop.exec() + + assert task_thread_ids + assert task_thread_ids[0] != main_thread_id + assert message_thread_ids + assert all(thread_id == main_thread_id for thread_id in message_thread_ids) + assert success_thread_ids == [main_thread_id] + + +def test_run_long_task_reports_errors(qapp): + from PySide6 import QtCore + + loop = QtCore.QEventLoop() + results = [] + errors = [] + error_thread_ids = [] + main_thread_id = threading.get_ident() + + def task(): + raise ValueError("task failed") + + def on_error(exc): + error_thread_ids.append(threading.get_ident()) + errors.append(exc) + + run_long_task( + parent=None, + title="Running test task", + label="Running test task.", + task=task, + on_success=results.append, + on_error=on_error, + on_finished=lambda _handle: loop.quit(), + ) + QtCore.QTimer.singleShot(3000, loop.quit) + loop.exec() + + assert results == [] + assert len(errors) == 1 + assert str(errors[0]) == "task failed" + assert error_thread_ids == [main_thread_id] diff --git a/tests/test_gui_main_window.py b/tests/test_gui_main_window.py index a9d58b37..fab5d119 100644 --- a/tests/test_gui_main_window.py +++ b/tests/test_gui_main_window.py @@ -1,6 +1,9 @@ +import ast +import inspect import os import logging import sys +import textwrap import unittest from unittest import mock @@ -9,9 +12,11 @@ from eegprep.functions.guifunc.eeglab_menu import eeglab_menus, menu_actions from eegprep.functions.guifunc.menu_actions import ( + IMPLEMENTED_ACTIONS, MenuActionDispatcher, action_kind, ) +from eegprep.functions.guifunc.long_task import LongTaskHandle from eegprep.functions.guifunc.menu_placeholders import is_placeholder_action, placeholder_message from eegprep.functions.guifunc.menu_spec import menu_enabled from eegprep.functions.guifunc.session import EEGPrepSession @@ -329,6 +334,52 @@ def test_file_menu_actions_are_implemented_or_explicit_placeholders(self): ) self.assertEqual(action_kind("pop_fileio_brainvision_mat"), "implemented") + def test_implemented_actions_registry_matches_dispatch_routing(self): + # IMPLEMENTED_ACTIONS gates whether a menu item is shown enabled. An entry + # with no dispatch arm would render enabled yet fall through to the + # show_coming_soon catch-all; a dispatch arm missing from the set would be + # classified "unknown"/placeholder. Keep the two in lockstep. + routed = _dispatch_routed_actions() + self.assertEqual( + sorted(IMPLEMENTED_ACTIONS - routed), + [], + "IMPLEMENTED_ACTIONS entries with no dispatch handler (enabled menu items that no-op)", + ) + self.assertEqual( + sorted(routed - IMPLEMENTED_ACTIONS), + [], + "dispatch handlers missing from IMPLEMENTED_ACTIONS (classified unknown/placeholder)", + ) + + +def _dispatch_routed_actions(): + """Base action names routed to a real handler by ``MenuActionDispatcher.dispatch``.""" + import eegprep.functions.guifunc.menu_actions as menu_actions_module + + source = textwrap.dedent(inspect.getsource(MenuActionDispatcher.dispatch)) + module_sets = {name: value for name, value in vars(menu_actions_module).items() if isinstance(value, set)} + routed: set[str] = set() + for node in ast.walk(ast.parse(source)): + if not isinstance(node, ast.Compare): + continue + left = node.left + if not (isinstance(left, ast.Name) and left.id in ("base", "action")): + continue + operator = node.ops[0] + comparator = node.comparators[0] + if isinstance(operator, ast.Eq) and isinstance(comparator, ast.Constant) and isinstance(comparator.value, str): + routed.add(comparator.value) + elif isinstance(operator, ast.In): + if isinstance(comparator, (ast.Set, ast.List, ast.Tuple)): + routed.update( + element.value + for element in comparator.elts + if isinstance(element, ast.Constant) and isinstance(element.value, str) + ) + elif isinstance(comparator, ast.Name) and comparator.id in module_sets: + routed.update(module_sets[comparator.id]) + return routed + class EEGPrepSessionTests(unittest.TestCase): def test_session_reports_startup_without_data(self): @@ -359,6 +410,41 @@ def test_session_stores_multiple_selected_datasets_back_to_same_indices(self): self.assertEqual(session.CURRENTSET, [1, 2]) self.assertEqual([item["ref"] for item in session.ALLEEG], ["average", "average"]) + def test_apply_workspace_state_rejects_currentset_outside_alleeg_before_mutating(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + original_eeg = session.EEG + original_alleeg = list(session.ALLEEG) + original_currentset = list(session.CURRENTSET) + + with self.assertRaisesRegex(ValueError, "CURRENTSET contains indices outside ALLEEG"): + session.apply_workspace_state(alleeg=[_demo_eeg()], currentset=2) + + self.assertIs(session.EEG, original_eeg) + self.assertEqual(len(session.ALLEEG), len(original_alleeg)) + self.assertIs(session.ALLEEG[0], original_alleeg[0]) + self.assertEqual(session.CURRENTSET, original_currentset) + + def test_apply_workspace_state_rejects_eeg_list_length_mismatch_before_mutating(self): + session = EEGPrepSession() + first = _demo_eeg() + second = _demo_eeg() + second["setname"] = "second" + session.store_current(first, new=True) + session.store_current(second, new=True) + original_eeg = session.EEG + original_alleeg = list(session.ALLEEG) + original_currentset = list(session.CURRENTSET) + + with self.assertRaisesRegex(ValueError, "EEG selection length must match CURRENTSET"): + session.apply_workspace_state(alleeg=[first, second], eeg=[first], currentset=[1, 2]) + + self.assertIs(session.EEG, original_eeg) + self.assertEqual(len(session.ALLEEG), len(original_alleeg)) + self.assertIs(session.ALLEEG[0], original_alleeg[0]) + self.assertIs(session.ALLEEG[1], original_alleeg[1]) + self.assertEqual(session.CURRENTSET, original_currentset) + def test_session_delete_current_selects_remaining_dataset(self): session = EEGPrepSession() first = _demo_eeg() @@ -747,6 +833,73 @@ def test_multiple_dataset_reref_preserves_selection(self): self.assertEqual([item["ref"] for item in session.EEG], ["average", "average"]) self.assertEqual([item["ref"] for item in session.ALLEEG], ["average", "average"]) + def test_cancelled_newset_commit_does_not_pollute_dataset_history(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + original_history = session.EEG.get("history", "") + original_allcom = list(session.ALLCOM) + dispatcher = MenuActionDispatcher(session) + select_command = "EEG = pop_select(EEG, 'point', [1 20]);" + + with ( + mock.patch( + "eegprep.functions.popfunc.pop_select.pop_select", + return_value=(session.EEG, select_command), + ), + # User cancels the pop_newset dialog, so no dataset is committed. + mock.patch( + "eegprep.functions.guifunc.menu_actions.pop_newset", + return_value=(session.ALLEEG, session.EEG, [1], ""), + ), + ): + dispatcher.dispatch("pop_select", parent=object()) + + self.assertEqual(session.EEG.get("history", ""), original_history) + self.assertNotIn(select_command, str(session.EEG.get("history", ""))) + self.assertEqual(session.ALLCOM, original_allcom) + + def test_committed_newset_records_processing_history_on_stored_dataset(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + dispatcher = MenuActionDispatcher(session) + processed = dict(session.EEG, setname="selected") + select_command = "EEG = pop_select(EEG, 'point', [1 20]);" + + with mock.patch( + "eegprep.functions.popfunc.pop_select.pop_select", + return_value=(processed, select_command), + ): + dispatcher.dispatch("pop_select") + + self.assertEqual(session.EEG["setname"], "selected") + self.assertEqual(session.ALLEEG[0]["setname"], "selected") + self.assertIn(select_command, str(session.EEG["history"])) + self.assertIn(select_command, str(session.ALLEEG[0]["history"])) + + def test_headplot_menu_action_commits_spline_file_through_session(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + events = [] + session.add_change_listener(lambda _session: events.append("changed")) + dispatcher = MenuActionDispatcher(session) + headplot_command = "pop_headplot(EEG, 1, [0], '', [1 1]);" + + def fake_pop_headplot(eeg, *, typeplot, return_com): + eeg["splinefile"] = "/tmp/demo.spl" + return (["figure"], headplot_command) + + with mock.patch("eegprep.functions.popfunc.pop_headplot.pop_headplot", side_effect=fake_pop_headplot): + dispatcher.dispatch("pop_headplot") + + self.assertEqual(session.EEG["splinefile"], "/tmp/demo.spl") + self.assertEqual(session.ALLEEG[0]["splinefile"], "/tmp/demo.spl") + self.assertEqual(session.ALLCOM[-1], headplot_command) + self.assertTrue(events, "headplot commit must notify session listeners") + # Committing through store_current records the edit in the dataset + # history; a history-only path would leave the dataset .history untouched. + self.assertIn(headplot_command, str(session.EEG["history"])) + self.assertIn(headplot_command, str(session.ALLEEG[0]["history"])) + def test_resave_updates_single_dataset_metadata_and_saved_state(self): session = EEGPrepSession() session.store_current(_demo_eeg(), new=True) @@ -855,6 +1008,93 @@ def test_new_main_window_pop_actions_dispatch_to_real_wrappers(self): else: self.assertEqual(session.ALLCOM[-1], f"EEG = {action}(EEG);") + def test_gui_pop_runica_runs_ica_in_long_task_before_committing_result(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + refresh = mock.Mock() + dispatcher = MenuActionDispatcher(session, refresh=refresh) + output = dict(session.EEG, setname="ica") + options = { + "icatype": "runica", + "options": {"extended": 1, "interrupt": "on"}, + "reorder": "on", + "chanind": None, + "dataset": None, + "concatenate": "off", + "concatcond": "off", + } + captured = {} + handle = LongTaskHandle(thread=object(), worker=object(), dialog=object()) + events = [] + original_selection = session.EEG + session.add_gui_action_listener(lambda event, action: events.append((event, action))) + + def fake_run_long_task(**kwargs): + captured.update(kwargs) + return handle + + with ( + mock.patch("eegprep.functions.popfunc.pop_runica.pop_runica_gui_options", return_value=options), + mock.patch( + "eegprep.functions.popfunc.pop_runica.pop_runica", + return_value=(output, "EEG = pop_runica(EEG, 'icatype', 'runica', 'extended', 1, 'interrupt', 'on');"), + ) as pop_func, + mock.patch("eegprep.functions.guifunc.menu_actions.run_long_task", side_effect=fake_run_long_task), + ): + dispatcher.dispatch("pop_runica", parent=object()) + self.assertEqual(session.EEG["setname"], "demo") + + result = captured["task"]() + captured["on_success"](result) + captured["on_finished"](handle) + + pop_func.assert_called_once_with(original_selection, gui=False, return_com=True, **options) + self.assertEqual(session.EEG["setname"], "ica") + self.assertEqual(session.ALLEEG[0]["setname"], "ica") + self.assertEqual( + session.ALLCOM[-1], + "EEG = pop_runica(EEG, 'icatype', 'runica', 'extended', 1, 'interrupt', 'on');", + ) + refresh.assert_called_once() + self.assertEqual(events, [("begin", "pop_runica"), ("end", "pop_runica")]) + self.assertEqual(dispatcher._long_tasks, []) + + def test_gui_pop_runica_long_task_error_does_not_mutate_session(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + dispatcher = MenuActionDispatcher(session) + options = { + "icatype": "runica", + "options": {"extended": 1}, + "reorder": "on", + "chanind": None, + "dataset": None, + "concatenate": "off", + "concatcond": "off", + } + captured = {} + handle = LongTaskHandle(thread=object(), worker=object(), dialog=object()) + warnings = [] + + def fake_run_long_task(**kwargs): + captured.update(kwargs) + return handle + + with ( + mock.patch("eegprep.functions.popfunc.pop_runica.pop_runica_gui_options", return_value=options), + mock.patch("eegprep.functions.guifunc.menu_actions.run_long_task", side_effect=fake_run_long_task), + mock.patch.object(dispatcher, "_warn", side_effect=lambda _parent, message: warnings.append(message)), + ): + dispatcher.dispatch("pop_runica", parent=object()) + captured["on_error"](ValueError("runica failed")) + captured["on_finished"](handle) + + self.assertEqual(session.EEG["setname"], "demo") + self.assertEqual(session.ALLEEG[0]["setname"], "demo") + self.assertEqual(session.ALLCOM, []) + self.assertEqual(warnings, ["runica failed"]) + self.assertEqual(dispatcher._long_tasks, []) + def test_gui_transform_action_can_commit_processed_dataset_as_new_set(self): session = EEGPrepSession() session.store_current(_demo_eeg(), new=True) @@ -982,6 +1222,26 @@ def test_pop_topoplot_menu_actions_record_history_without_replacing_dataset(self self.assertIs(session.ALLEEG[0], original_eeg) self.assertEqual(session.ALLCOM[-1], "pop_topoplot(EEG, typeplot=1, items=[0])") + def test_dipfit_mutating_menu_action_updates_session_and_history(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + dispatcher = MenuActionDispatcher(session) + original = session.EEG + fitted = dict(original, setname="dipfit updated") + fitted["dipfit"] = {"model": [{"rv": 0.01}]} + + with mock.patch( + "eegprep.plugins.dipfit.pop_dipfit_gridsearch.pop_dipfit_gridsearch", + return_value=(fitted, "EEG = pop_dipfit_gridsearch(EEG, select=[1]);"), + ) as gridsearch: + dispatcher.dispatch("pop_dipfit_gridsearch") + + gridsearch.assert_called_once_with(original, return_com=True) + self.assertEqual(session.EEG["setname"], "dipfit updated") + self.assertEqual(session.ALLEEG[0]["dipfit"]["model"][0]["rv"], 0.01) + self.assertEqual(session.LASTCOM, "EEG = pop_dipfit_gridsearch(EEG, select=[1]);") + self.assertEqual(session.ALLCOM[-1], "EEG = pop_dipfit_gridsearch(EEG, select=[1]);") + def test_copyset_menu_updates_alleeg_eeg_currentset_and_history(self): session = EEGPrepSession() session.store_current(_demo_eeg(), new=True) @@ -1329,7 +1589,14 @@ def test_file_menu_runscript_updates_currentset_from_namespace(self): qt_widgets = _fake_qt_widgets(open_file="/tmp/script.py") def fake_runscript(_filename, namespace): + self.assertIn("ALLCOM", namespace) + self.assertIn("LASTCOM", namespace) + self.assertIn("CURRENTSTUDY", namespace) namespace["CURRENTSET"] = 2 + namespace["ALLCOM"].append("EEG = script_command(EEG);") + namespace["LASTCOM"] = "EEG = script_command(EEG);" + namespace["STUDY"] = {"name": "script study"} + namespace["CURRENTSTUDY"] = 1 return "LASTCOM = pop_runscript('/tmp/script.py');" with ( @@ -1339,7 +1606,34 @@ def fake_runscript(_filename, namespace): dispatcher.dispatch("pop_runscript") self.assertEqual(session.CURRENTSET, [2]) - self.assertEqual(session.ALLCOM[-1], "LASTCOM = pop_runscript('/tmp/script.py');") + self.assertEqual(session.STUDY["name"], "script study") + self.assertEqual(session.CURRENTSTUDY, 1) + self.assertEqual( + session.ALLCOM, + ["EEG = script_command(EEG);", "LASTCOM = pop_runscript('/tmp/script.py');"], + ) + self.assertEqual(session.LASTCOM, "LASTCOM = pop_runscript('/tmp/script.py');") + + def test_file_menu_runscript_clear_currentset_resets_eeg(self): + session = EEGPrepSession() + session.store_current(_demo_eeg(), new=True) + dispatcher = MenuActionDispatcher(session) + qt_widgets = _fake_qt_widgets(open_file="/tmp/script.py") + + def fake_runscript(_filename, namespace): + namespace["CURRENTSET"] = 0 + return "LASTCOM = pop_runscript('/tmp/script.py');" + + with ( + mock.patch("eegprep.functions.guifunc.menu_actions._require_qt_widgets", return_value=qt_widgets), + mock.patch("eegprep.functions.popfunc.pop_runscript.pop_runscript", side_effect=fake_runscript), + ): + dispatcher.dispatch("pop_runscript") + + self.assertEqual(session.CURRENTSET, []) + self.assertIsInstance(session.EEG, dict) + self.assertEqual(session.EEG.get("setname"), "") + self.assertEqual(session.EEG["data"].size, 0) class QtMainWindowTests(unittest.TestCase): diff --git a/tests/test_gui_pop_runica.py b/tests/test_gui_pop_runica.py index 0b05c8af..21ad625e 100644 --- a/tests/test_gui_pop_runica.py +++ b/tests/test_gui_pop_runica.py @@ -10,7 +10,7 @@ from eegprep.functions.guifunc.spec import controls_by_tag from eegprep.functions.guifunc.qt import QtDialogRenderer from eegprep.functions.popfunc.pop_loadset import pop_loadset -from eegprep.functions.popfunc.pop_runica import pop_runica, pop_runica_dialog_spec +from eegprep.functions.popfunc.pop_runica import pop_runica, pop_runica_dialog_spec, pop_runica_gui_options def _eeg(): @@ -96,6 +96,19 @@ def run(self, spec, initial_values=None): "EEG = pop_runica(EEG, 'icatype', 'runica', 'extended', 1, 'maxsteps', 2, 'interrupt', 'on');", ) + def test_gui_options_do_not_inject_interrupt_for_non_runica_algorithms(self): + class Renderer: + def run(self, spec, initial_values=None): + return {"icatype": 4, "params": "'maxiter', 7", "reorder": True, "chantype": ""} + + options = pop_runica_gui_options(_eeg(), renderer=Renderer()) + + self.assertIsNotNone(options) + assert options is not None + self.assertEqual(options["icatype"], "picard") + self.assertEqual(options["options"], {"maxiter": 7}) + self.assertNotIn("interrupt", options["options"]) + def test_gui_result_runs_runica_and_returns_history(self): class Renderer: def run(self, spec, initial_values=None): @@ -295,13 +308,22 @@ def test_picard_algorithm_routes_to_eeg_picard(self): updated = dict(eeg, icaweights=np.eye(4), icasphere=np.eye(4), icawinv=np.eye(4), icaact=np.zeros((4, 20, 1))) with mock.patch("eegprep.functions.popfunc.pop_runica.eeg_picard", return_value=updated) as picard: - out, com = pop_runica(eeg, icatype="picard", options={"maxiter": 7, "mode": "standard"}, return_com=True) + out, com = pop_runica( + eeg, + icatype="picard", + options={"maxiter": 7, "mode": "standard", "seed": 3}, + return_com=True, + ) picard.assert_called_once() self.assertEqual(picard.call_args.kwargs["max_iter"], 7) + self.assertEqual(picard.call_args.kwargs["random_state"], 3) self.assertFalse(picard.call_args.kwargs["ortho"]) self.assertEqual(out["icaweights"].shape, (4, 4)) - self.assertEqual(com, "EEG = pop_runica(EEG, 'icatype', 'picard', 'maxiter', 7, 'mode', 'standard');") + self.assertEqual( + com, + "EEG = pop_runica(EEG, 'icatype', 'picard', 'maxiter', 7, 'mode', 'standard', 'seed', 3);", + ) def test_unported_ica_algorithm_fails_clearly(self): with self.assertRaisesRegex(NotImplementedError, "not ported"): diff --git a/tests/test_pac_time_frequency_helpers.py b/tests/test_pac_time_frequency_helpers.py index 1727076b..d16acf26 100644 --- a/tests/test_pac_time_frequency_helpers.py +++ b/tests/test_pac_time_frequency_helpers.py @@ -79,6 +79,21 @@ def test_pac_reports_explicit_unsupported_statistics_and_latphase(): assert "not silently emulated" in PAC_UNSUPPORTED_MESSAGE +def test_pac_rejects_unimplemented_plotting_options(): + amp, phase, srate = _coupled_trials() + + with pytest.raises(NotImplementedError, match="not implemented"): + pac(amp, phase, srate, title="my coupling") + with pytest.raises(NotImplementedError, match="not implemented"): + pac(amp, phase, srate, vert=[100.0]) + with pytest.raises(NotImplementedError, match="not implemented"): + pac(amp, phase, srate, newfig="off") + + # Default plotting values still compute PAC without plotting. + result = pac(amp, phase, srate, title="", vert=None, newfig="on", ntimesout=4) + assert isinstance(result, PacResult) + + def test_std_pac_computes_study_cache_and_std_pacplot_reads_it(): study, alleeg = _study_pair() diff --git a/tests/test_phase4_plot_wrappers.py b/tests/test_phase4_plot_wrappers.py index 4bf1fa1a..118f3140 100644 --- a/tests/test_phase4_plot_wrappers.py +++ b/tests/test_phase4_plot_wrappers.py @@ -21,7 +21,12 @@ from eegprep.functions.popfunc.pop_comperp import pop_comperp, pop_comperp_dialog_spec from eegprep.functions.popfunc.pop_envtopo import pop_envtopo from eegprep.functions.popfunc.pop_erpimage import pop_erpimage, pop_erpimage_dialog_spec -from eegprep.functions.popfunc.pop_headplot import pop_headplot, pop_headplot_dialog_spec, _shared_maplimits +from eegprep.functions.popfunc.pop_headplot import ( + pop_headplot, + pop_headplot_dialog_spec, + _current_spline_file, + _shared_maplimits, +) from eegprep.functions.popfunc.pop_loadset import pop_loadset from eegprep.functions.popfunc.pop_epoch import pop_epoch from eegprep.functions.popfunc._plot_utils import component_activations @@ -84,6 +89,31 @@ def test_pop_spectopo_plots_sample_data_headlessly(sample_eeg): plt.close(result["figure"]) +def test_pop_spectopo_component_default_controls_succeed(ica_epoch): + result, command = pop_spectopo( + ica_epoch, dataflag=0, freqs=[10], plotchan=0, icamode=True, icacomps=[1, 2], nicamaps=2, return_com=True + ) + + assert result["spectra"].shape[0] == 2 + assert "pop_spectopo(EEG" in command + plt.close(result["figure"]) + + +def test_pop_spectopo_rejects_nondefault_plotchan(ica_epoch): + with pytest.raises(ValueError, match="whole-scalp component spectra"): + pop_spectopo(ica_epoch, dataflag=0, freqs=[10], plotchan=3, icacomps=[1, 2]) + + +def test_pop_spectopo_rejects_max_power_plotchan(ica_epoch): + with pytest.raises(ValueError, match="whole-scalp component spectra"): + pop_spectopo(ica_epoch, dataflag=0, freqs=[10], plotchan=[], icacomps=[1, 2]) + + +def test_pop_spectopo_rejects_datacomp_icamode(ica_epoch): + with pytest.raises(ValueError, match="component spectra"): + pop_spectopo(ica_epoch, dataflag=0, freqs=[10], icamode=False, icacomps=[1, 2]) + + def test_pop_prop_plots_sample_channel_properties(sample_eeg): figure, command = pop_prop(sample_eeg, typecomp=1, chanorcomp=1, return_com=True) @@ -103,12 +133,30 @@ def test_pop_headplot_plots_sample_latency_map_with_spline_setup(sample_eeg, tmp assert len(figures) == 1 assert figures[0].axes[0].name == "3d" assert splinefile.exists() - assert eeg["splinefile"] == str(splinefile) + assert _current_spline_file(eeg, 1) != str(splinefile) # plot must not mutate the caller's EEG assert "setup={" in command _assert_python_command(command) plt.close(figures[0]) +def test_pop_headplot_does_not_mutate_caller_eeg(sample_eeg, tmp_path): + eeg = deepcopy(sample_eeg) + before = deepcopy(eeg) + setup = { + "splinefile": str(tmp_path / "nomutate.spl"), + "transform": [0, -10, 0, -0.1, 0, -1.6, 1100, 1100, 1100], + } + + figures = pop_headplot(eeg, typeplot=1, items=[0], setup=setup) + + assert set(eeg.keys()) == set(before.keys()) # no new spline/mesh keys added + assert np.array_equal(np.asarray(eeg["splinefile"]), np.asarray(before["splinefile"])) + assert "headplotmeshfile" not in eeg + assert np.array_equal(np.asarray(eeg["data"]), np.asarray(before["data"])) + for fig in figures: + plt.close(fig) + + def test_pop_headplot_single_map_has_eeglab_like_title_and_surface(sample_eeg, tmp_path): eeg = deepcopy(sample_eeg) title = "ERP scalp maps of dataset: eeglab_data" @@ -303,7 +351,7 @@ def test_pop_headplot_component_path_creates_ica_spline(ica_epoch, tmp_path): ) assert len(figures) == 1 - assert eeg["icasplinefile"] == str(splinefile) + assert _current_spline_file(eeg, 0) != str(splinefile) # plot must not mutate the caller's EEG assert "IC 2" in figures[0].axes[1].get_title() _assert_python_command(command) plt.close(figures[0]) @@ -383,7 +431,8 @@ def run(self, spec, initial_values=None): figures, command = pop_headplot(eeg, typeplot=1, gui=True, renderer=renderer, return_com=True) assert renderer.spec.title == "ERP head plot(s) -- pop_headplot()" - assert Path(eeg["splinefile"]).exists() + assert (tmp_path / "gui.spl").exists() + assert _current_spline_file(eeg, 1) == "" # plot must not mutate the caller's EEG assert "setup={" in command assert "electrodes='off'" in command _assert_python_command(command) @@ -1247,6 +1296,40 @@ def _assert_python_command(command: str) -> None: ast.parse(command) +def test_component_activations_dedup_contract(): + """Lock the K4 dedup: rejection delegates recompute to the canonical helper. + + The rejection ``component_activations`` (``_rejection``) and the canonical + plotting helper (``_plot_utils``) must agree when recomputing from weights, + and rejection must ignore a stored ``icaact`` while plotting trusts it. + """ + from eegprep.functions.popfunc._rejection import component_activations as rejection_activations + + rng = np.random.default_rng(7) + nbchan, pnts, trials = 5, 16, 4 + data = rng.standard_normal((nbchan, pnts, trials)) + weights = rng.standard_normal((nbchan, nbchan)) + sphere = rng.standard_normal((nbchan, nbchan)) + recompute = (weights @ sphere) @ data.reshape(nbchan, -1, order="F") + stored = -recompute.reshape(nbchan, pnts, trials, order="F") + eeg = { + "data": data, + "icaweights": weights, + "icasphere": sphere, + "icachansind": np.arange(nbchan), + "nbchan": nbchan, + "pnts": pnts, + "trials": trials, + "icaact": stored, + } + + plot_recompute = component_activations(eeg, use_stored=False) + assert np.allclose(rejection_activations(eeg), plot_recompute) + # Rejection ignores the stored icaact; the default plotting path trusts it. + assert not np.allclose(rejection_activations(eeg), stored) + assert np.allclose(component_activations(eeg), stored) + + def _matlab_string(path: Any) -> str: return str(path).replace("'", "''") diff --git a/tests/test_phase4_timefreq_statistics.py b/tests/test_phase4_timefreq_statistics.py index a188d02f..eafacd20 100644 --- a/tests/test_phase4_timefreq_statistics.py +++ b/tests/test_phase4_timefreq_statistics.py @@ -39,8 +39,15 @@ from eegprep.functions.timefreqfunc.dftfilt import dftfilt from eegprep.functions.timefreqfunc.dftfilt2 import dftfilt2 from eegprep.functions.timefreqfunc.dftfilt3 import dftfilt3 +from eegprep.functions.statistics.fdr import fdr +from eegprep.functions.timefreqfunc.newcrossf import _is_on as newcrossf_is_on +from eegprep.functions.timefreqfunc.newcrossf import _threshold_vector as newcrossf_threshold_vector +from eegprep.functions.timefreqfunc.newcrossf import _upper_thresholds_by_frequency from eegprep.functions.timefreqfunc.newcrossf import newcrossf -from eegprep.functions.timefreqfunc.newtimef import newtimef +from eegprep.functions.timefreqfunc.newtimef import _is_on as newtimef_is_on +from eegprep.functions.timefreqfunc.newtimef import _significance_mask, _thresholds_by_frequency +from eegprep.functions.timefreqfunc.newtimef import _threshold_vector as newtimef_threshold_vector +from eegprep.functions.timefreqfunc.newtimef import compute_time_frequency, newtimef from eegprep.functions.timefreqfunc.newtimefbaseln import newtimefbaseln from eegprep.functions.timefreqfunc.newtimefpowerunit import newtimefpowerunit from eegprep.functions.timefreqfunc.rsadjust import rsadjust @@ -91,6 +98,31 @@ def test_newtimef_rejects_unknown_options(): newtimef(signal, 128, [0, 1000], 128, 0, unsupported_option=1) +def test_newtimef_is_on_uses_whitelist_semantics(): + assert newtimef_is_on("on") is True + assert newtimef_is_on("yes") is True + assert newtimef_is_on(1) is True + # Unrecognized values are treated as OFF, matching the canonical is_on. + assert newtimef_is_on("yes-please") is False + assert newtimef_is_on("display") is False + assert newtimef_is_on("off") is False + + +def test_newtimef_fails_loudly_on_unimplemented_overlap_and_plotphase(): + signal = np.sin(2 * np.pi * 10 * np.arange(128) / 128) + + with pytest.raises(NotImplementedError, match="overlap"): + newtimef(signal, 128, [0, 1000], 128, 0, plot="off", overlap=2) + with pytest.raises(NotImplementedError, match="plotphase"): + newtimef(signal, 128, [0, 1000], 128, 0, plot="off", plotphase="on") + with pytest.raises(NotImplementedError, match="overlap"): + compute_time_frequency(signal, 128, [0, 1000], 128, 0, overlap=2) + + # Default values still compute without raising. + result = newtimef(signal, 128, [0, 1000], 128, 0, plot="off", overlap=None, plotphase="off") + assert result.ersp.shape == result.itc.shape + + def test_newtimef_nonzero_cycles_use_wavelet_time_grid(sample_epoch): result = pop_newtimef(sample_epoch, 1, 1, [-100, 200], [3, 0.8], plot="off") @@ -524,6 +556,46 @@ def statistic(left, right): np.testing.assert_allclose(randomized.surrogates, np.broadcast_to(np.abs(complex_values), (5, 2, 3))) +def test_timefreq_threshold_helpers_pool_through_canonical_bootstrap_threshold(): + # newtimef/newcrossf no longer re-sort surrogates; they pool (naccu x baseline) + # per frequency and delegate the percentile/tail math to bootstrap_threshold. + rng = np.random.default_rng(7) + surrogates = rng.normal(size=(24, 3, 5)) + + pooled = surrogates.transpose(0, 2, 1).reshape(-1, surrogates.shape[1]) + expected_both = bootstrap_threshold(pooled, alpha=0.1, bootside="both") + expected_upper = bootstrap_threshold(pooled, alpha=0.1, bootside="upper") + + np.testing.assert_allclose(_thresholds_by_frequency(surrogates, alpha=0.1, both=True), expected_both) + np.testing.assert_allclose(_thresholds_by_frequency(surrogates, alpha=0.1, both=False), expected_upper) + np.testing.assert_allclose(_upper_thresholds_by_frequency(surrogates, alpha=0.1), expected_upper) + + # Single-frequency case keeps the original (nfreq,) / (nfreq, 2) shapes. + single = rng.normal(size=(24, 1, 5)) + assert _thresholds_by_frequency(single, alpha=0.1, both=True).shape == (1, 2) + assert _thresholds_by_frequency(single, alpha=0.1, both=False).shape == (1,) + assert _upper_thresholds_by_frequency(single, alpha=0.1).shape == (1,) + + +def test_newtimef_fdr_branch_matches_canonical_fdr_threshold(): + rng = np.random.default_rng(11) + pvalues = rng.random(size=(4, 6)) + + threshold = float(fdr(pvalues, 0.1).threshold) + expected = np.zeros_like(pvalues, dtype=bool) if threshold == 0 else pvalues <= threshold + + np.testing.assert_array_equal(_significance_mask(pvalues, 0.1, "fdr"), expected) + np.testing.assert_array_equal(_significance_mask(pvalues, 0.1, "none"), pvalues <= 0.1) + + +def test_timefreq_is_on_and_threshold_vector_are_the_canonical_shared_helpers(): + # newcrossf reuses newtimef's threshold helper and the canonical is_on whitelist. + assert newcrossf_threshold_vector is newtimef_threshold_vector + assert newcrossf_is_on is newtimef_is_on + assert newcrossf_is_on("on") is True + assert newcrossf_is_on("display") is False + + def test_bootstat_basevect_uses_eeglab_one_based_indices(): data = np.arange(12, dtype=float).reshape(2, 3, 2) seen = [] diff --git a/tests/test_phase7_long_tail_helpers.py b/tests/test_phase7_long_tail_helpers.py index 96e8ae3f..d931c265 100644 --- a/tests/test_phase7_long_tail_helpers.py +++ b/tests/test_phase7_long_tail_helpers.py @@ -9,6 +9,7 @@ from eegprep.functions.popfunc.pop_fusechanrej import pop_fusechanrej from eegprep.functions.popfunc.pop_icathresh import pop_icathresh from eegprep.functions.popfunc.pop_rejchanspec import pop_rejchanspec +from eegprep.functions.popfunc.pop_chansel import pop_chansel_resolve from eegprep.functions.popfunc.pop_topochansel import pop_topochansel from eegprep.functions.sigprocfunc.eegthresh import eegthresh from eegprep.functions.sigprocfunc.ica_helpers import compvar, eeg_getica, eeg_pvaf, icaact, icaproj, icavar @@ -159,6 +160,15 @@ def test_pop_topochansel_resolves_indices_and_labels_without_gui(): assert command.startswith("pop_topochansel(") +def test_pop_topochansel_uses_canonical_chansel_resolver(): + # The non-GUI selection resolution must match pop_chansel's resolver so the + # two former parsers cannot drift (e.g. on comma-separated labels). + chanlocs = [{"labels": "Fz"}, {"labels": "Cz"}, {"labels": "Pz"}] + chanlist, _names, _text = pop_topochansel(chanlocs, "Pz, Fz", gui=False) + _values, expected = pop_chansel_resolve(chanlocs, "Pz, Fz") + assert chanlist == expected + + def test_ica_helpers_match_simple_projection_identities(): data = np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]) weights = np.eye(2) diff --git a/tests/test_plugin_menu.py b/tests/test_plugin_menu.py index 8b1b5e60..466abe03 100644 --- a/tests/test_plugin_menu.py +++ b/tests/test_plugin_menu.py @@ -50,6 +50,12 @@ def test_bundled_plugins_match_extension_registry_records() -> None: assert [plugin["funcname"] for plugin in plugins] == [ record.spec.pop_functions[0].name for record in records if record.spec is not None ] + # Registry-owned metadata must come from the registry so it cannot drift. + for plugin, record in zip(plugins, records): + assert record.spec is not None + assert plugin["description"] == record.spec.description + assert plugin["version"] == record.spec.version + assert plugin["tags"] == record.spec.capabilities def test_bundled_plugins_returns_copies() -> None: diff --git a/tests/test_pop_editeventvals.py b/tests/test_pop_editeventvals.py new file mode 100644 index 00000000..43657cf4 --- /dev/null +++ b/tests/test_pop_editeventvals.py @@ -0,0 +1,73 @@ +"""Tests for pop_editeventvals latency display/edit symmetry.""" + +from copy import deepcopy + +import numpy as np +import pytest + +from eegprep.functions.popfunc.pop_editeventvals import ( + pop_editeventvals, + pop_editeventvals_dialog_spec, + _display_event_value, + _display_field_label, +) + + +def _epoched_eeg(): + """Epoched dataset: 2 trials, 50 points/epoch, srate 100, xmin -0.1 s.""" + return { + "data": np.zeros((2, 50, 2), dtype=float), + "nbchan": 2, + "pnts": 50, + "trials": 2, + "srate": 100.0, + "xmin": -0.1, + "xmax": 0.39, + "times": np.arange(50, dtype=float), + "chanlocs": [{"labels": "Cz"}, {"labels": "Pz"}], + "chaninfo": {}, + # latencies stored in points across concatenated epochs (1-based) + "event": [ + {"type": "stim", "latency": 15.0, "duration": 0.0, "epoch": 1}, + {"type": "resp", "latency": 70.0, "duration": 0.0, "epoch": 2}, + ], + "urevent": [], + "epoch": [{}, {}], + "reject": {}, + "stats": {}, + "icaweights": np.array([]), + "icasphere": np.array([]), + "icawinv": np.array([]), + "icaact": np.array([]), + "icachansind": np.array([], dtype=int), + "history": "", + "saved": "no", + } + + +def test_epoched_latency_display_matches_ms_label(): + """The epoched latency dialog label says ms, so the displayed value must be ms.""" + eeg = _epoched_eeg() + assert _display_field_label(eeg, "latency") == "Latency (ms)" + # point 15 in epoch 1 with xmin -0.1 s -> (15-1)/100 - 0.1 = 0.04 s = 40 ms + assert _display_event_value(eeg, eeg["event"][0], "latency") == 40.0 + # the dialog spec exposes the same ms value, not the raw stored point count + spec = pop_editeventvals_dialog_spec(eeg) + latency_edit = next(control for control in spec.controls if control.tag == "field_latency") + assert latency_edit.value == 40.0 + + +def test_epoched_latency_edit_is_symmetric_with_display(): + """Editing an epoched latency to its displayed ms value leaves the stored point unchanged.""" + eeg = _epoched_eeg() + displayed = _display_event_value(eeg, eeg["event"][0], "latency") + out = pop_editeventvals(deepcopy(eeg), "changefield", [1, "latency", displayed]) + assert out["event"][0]["latency"] == pytest.approx(eeg["event"][0]["latency"]) + + +def test_epoched_latency_edit_round_trips_through_display(): + """Writing a new ms latency then reading it back yields the same ms value.""" + eeg = _epoched_eeg() + new_ms = 60.0 + out = pop_editeventvals(deepcopy(eeg), "changefield", [1, "latency", new_ms]) + assert _display_event_value(out, out["event"][0], "latency") == new_ms diff --git a/tests/test_pop_epoch.py b/tests/test_pop_epoch.py index 886bedfe..7e241076 100644 --- a/tests/test_pop_epoch.py +++ b/tests/test_pop_epoch.py @@ -6,6 +6,8 @@ MATLAB EEGLAB's pop_epoch function across all tested scenarios. """ +import contextlib +import io import os import numpy as np import unittest @@ -390,6 +392,27 @@ class TestPopEpochEdgeCases(unittest.TestCase): def setUp(self): np.random.seed(42) + def test_pop_epoch_does_not_print_to_stdout(self): + EEG = { + 'data': np.random.randn(2, 500).astype(np.float32), + 'srate': 100.0, + 'nbchan': 2, + 'pnts': 500, + 'trials': 1, + 'xmin': 0.0, + 'xmax': 4.99, + 'setname': 'stdout_test', + 'event': [{'type': 'stim', 'latency': 250}], + 'epoch': [], + 'saved': 'no', + } + stream = io.StringIO() + + with contextlib.redirect_stdout(stream): + pop_epoch(EEG, 'stim', [-0.1, 0.1]) + + self.assertEqual(stream.getvalue(), "") + def test_boundary_events_near_edges(self): """Test epoching when events are near data boundaries""" # Create EEG with events near boundaries diff --git a/tests/test_pop_loadset.py b/tests/test_pop_loadset.py index e535fae6..ff1bd737 100644 --- a/tests/test_pop_loadset.py +++ b/tests/test_pop_loadset.py @@ -1,6 +1,9 @@ +import h5py import numpy as np +import pytest from eegprep.functions.popfunc.pop_loadset import pop_loadset +from eegprep.functions.popfunc.pop_loadset_h5 import pop_loadset_h5 def test_pop_loadset_marks_loaded_dataset_justloaded(): @@ -35,3 +38,57 @@ def test_pop_loadset_hdf5_fallback_does_not_subtract_icachansind_twice(): assert np.issubdtype(eeg["icachansind"].dtype, np.integer) assert eeg["icachansind"][0] == 0 + + +def test_pop_loadset_corrupt_v7_set_surfaces_real_error_not_h5py(tmp_path): + # A corrupt non-HDF5 .set must surface the real scipy parse error, not be silently + # rerouted to the HDF5 loader where h5py raises a cryptic "file signature not found". + corrupt = tmp_path / "corrupt.set" + corrupt.write_bytes(b"MATLAB 5.0 MAT-file, corrupt" + b"\x00" * 100 + b"\xff" * 50) + + with pytest.raises(Exception) as excinfo: + pop_loadset(str(corrupt)) + + message = str(excinfo.value).lower() + assert "file signature not found" not in message + assert "mat file" in message or "unknown mat file" in message + + +def test_pop_loadset_routes_hdf5_file_to_h5_loader(tmp_path): + # A genuine HDF5 .set is detected by its signature and loaded by the HDF5 path. + filepath = tmp_path / "hdf5.set" + with h5py.File(filepath, "w") as f: + eeg_group = f.create_group("EEG") + eeg_group.create_dataset("srate", data=np.array([[500.0]])) + eeg_group.create_dataset("nbchan", data=np.array([[4]])) + eeg_group.create_dataset("pnts", data=np.array([[100]])) + eeg_group.create_dataset("trials", data=np.array([[1]])) + eeg_group.create_dataset("xmin", data=np.array([[-1.0]])) + eeg_group.create_dataset("xmax", data=np.array([[1.0]])) + eeg_group.create_dataset("data", data=np.zeros((4, 100), dtype=np.float32)) + + eeg = pop_loadset(str(filepath)) + + assert eeg["nbchan"] == 4 + assert eeg["data"].shape == (4, 100) + + +def test_pop_loadset_h5_unicode_decodes_via_general_path(tmp_path): + # The general uint16 -> UTF-8 decode (no hard-coded emoji branch) returns the + # correct character. The bytes below are UTF-8 for U+1F496 (sparkling heart). + filepath = tmp_path / "unicode.set" + with h5py.File(filepath, "w") as f: + eeg_group = f.create_group("EEG") + unicode_bytes = np.array([104, 101, 108, 108, 111, 32, 240, 159, 146, 150], dtype=np.uint16) + eeg_group.create_dataset("unicode_string", data=unicode_bytes) + eeg_group.create_dataset("srate", data=np.array([[500.0]])) + eeg_group.create_dataset("nbchan", data=np.array([[4]])) + eeg_group.create_dataset("pnts", data=np.array([[100]])) + eeg_group.create_dataset("trials", data=np.array([[1]])) + eeg_group.create_dataset("xmin", data=np.array([[-1.0]])) + eeg_group.create_dataset("xmax", data=np.array([[1.0]])) + eeg_group.create_dataset("data", data=np.zeros((4, 100), dtype=np.float32)) + + eeg = pop_loadset_h5(str(filepath)) + + assert eeg["unicode_string"] == "hello \U0001f496" diff --git a/tests/test_pop_loadset_h5.py b/tests/test_pop_loadset_h5.py index 888c993e..a8461e05 100644 --- a/tests/test_pop_loadset_h5.py +++ b/tests/test_pop_loadset_h5.py @@ -357,8 +357,8 @@ def test_unicode_strings(self): with h5py.File(filepath, 'w') as f: eeg_group = f.create_group('EEG') - # Create Unicode string data (special case handled in pop_loadset_h5) - unicode_ascii = np.array([104, 101, 108, 108, 111, 32, 240, 159, 146, 150], dtype=np.uint16) # "hello 👖" + # Create Unicode string data: UTF-8 bytes for "hello 💖" stored one byte per uint16 + unicode_ascii = np.array([104, 101, 108, 108, 111, 32, 240, 159, 146, 150], dtype=np.uint16) # "hello 💖" eeg_group.create_dataset('unicode_string', data=unicode_ascii) # Add minimal required data to prevent eeg_checkset from failing @@ -372,8 +372,8 @@ def test_unicode_strings(self): EEG = pop_loadset_h5(filepath) - # Should handle Unicode strings (special case in the code) - self.assertEqual(EEG['unicode_string'], 'hello 👖') + # The general uint16 -> UTF-8 decode path returns the correct emoji. + self.assertEqual(EEG['unicode_string'], 'hello \U0001f496') @unittest.skipIf(os.getenv('EEGPREP_SKIP_MATLAB') == '1', "MATLAB not available") diff --git a/tests/test_pop_newset.py b/tests/test_pop_newset.py index ec3172ce..d135041f 100644 --- a/tests/test_pop_newset.py +++ b/tests/test_pop_newset.py @@ -4,6 +4,8 @@ import numpy as np import pytest +from eegprep.functions.guifunc.qt import QtDialogRenderer +from eegprep.functions.guifunc.spec import controls_by_tag from eegprep.functions.popfunc.eeg_emptyset import eeg_emptyset from eegprep.functions.popfunc.pop_newset import pop_newset, pop_newset_dialog_spec @@ -100,6 +102,95 @@ def test_pop_newset_dialog_old_dataset_prompt_hides_currentset_index(): assert "What do you want to do with the old dataset 1 (not modified since last saved)?" not in labels +def test_pop_newset_dialog_edit_description_opens_multiline_editor(): + eeg = _eeg(name="processed") + eeg["comments"] = ["first line", "second line"] + + control = controls_by_tag(pop_newset_dialog_spec(eeg, 1))["editdescription"] + + assert control.callback is not None + assert control.callback.name == "edit_text" + assert control.callback.params == { + "button": "editdescription", + "target": "editdescription", + "title": "Edit description", + "label": "Dataset description:", + "value": "first line\nsecond line", + } + + +def test_qt_edit_text_callback_stores_accepted_text(monkeypatch): + class QInputDialog: + @staticmethod + def getMultiLineText(_parent, title, label, text): + calls.append((title, label, text)) + return "", True + + class Widget: + def __init__(self): + self.properties = {} + + def property(self, name): + return self.properties.get(name) + + def setProperty(self, name, value): + self.properties[name] = value + + calls = [] + target = Widget() + QtWidgets = type("QtWidgets", (), {"QInputDialog": QInputDialog}) + monkeypatch.setattr("eegprep.functions.guifunc.qt._require_qt", lambda: (None, QtWidgets)) + + QtDialogRenderer._edit_text( + object(), + target, + {"title": "Edit description", "label": "Dataset description:", "value": "old notes"}, + ) + + assert calls == [("Edit description", "Dataset description:", "old notes")] + assert QtDialogRenderer._read_widget(target) == "" + + +def test_pop_newset_gui_description_button_value_updates_comments(): + class Renderer: + def run(self, _spec, initial_values=None): + return {"setname": "processed", "editdescription": "edited notes", "overwrite": 1} + + alleeg, current, current_set, _command = pop_newset([], _eeg(name="original"), 0) + + alleeg, current, current_set, command = pop_newset( + alleeg, _eeg(name="processed"), current_set, "gui", "on", renderer=Renderer() + ) + + assert len(alleeg) == 2 + assert current_set == 2 + assert current["comments"] == "edited notes" + assert command == ( + "[ALLEEG EEG CURRENTSET] = pop_newset(ALLEEG, EEG, CURRENTSET, " + "'setname', 'processed', 'comments', 'edited notes', 'overwrite', 'off');" + ) + + +def test_pop_newset_gui_untouched_description_button_preserves_comments(): + class Renderer: + def run(self, _spec, initial_values=None): + return {"setname": "processed", "editdescription": False, "overwrite": 1} + + alleeg, current, current_set, _command = pop_newset([], _eeg(name="original"), 0) + current["comments"] = "old notes" + processed = _eeg(name="processed") + processed["comments"] = "old notes" + + alleeg, current, current_set, command = pop_newset(alleeg, processed, current_set, "gui", "on", renderer=Renderer()) + + assert len(alleeg) == 2 + assert current_set == 2 + assert current["comments"] == "old notes" + assert command == ( + "[ALLEEG EEG CURRENTSET] = pop_newset(ALLEEG, EEG, CURRENTSET, 'setname', 'processed', 'overwrite', 'off');" + ) + + def test_pop_newset_gui_choice_can_overwrite_current_dataset(): class Renderer: def run(self, _spec, initial_values=None): diff --git a/tests/test_pop_saveset.py b/tests/test_pop_saveset.py index c64352d7..7aca2dfd 100644 --- a/tests/test_pop_saveset.py +++ b/tests/test_pop_saveset.py @@ -1,5 +1,10 @@ import os +import tempfile import unittest + +import numpy as np +import scipy.io + from eegprep import pop_loadset, pop_saveset # Explicitly import pop_resample @@ -38,6 +43,31 @@ def test_basic(self): pop_saveset( self.EEG, os.path.join(local_url, 'eeglab_data_tmp.set') ) # see MATLAB code to compare the results at the end of the file + + def test_saveset_does_not_mutate_caller_indices(self): + EEG = pop_loadset(os.path.join(local_url, 'eeglab_data_with_ica_tmp.set')) + + chanlocs = EEG['chanlocs'] + events = EEG['event'] + urchan_before = [int(c['urchan']) for c in chanlocs if 'urchan' in c] + urevent_before = [int(ev['urevent']) for ev in events if 'urevent' in ev] + latency_before = [float(ev['latency']) for ev in events if 'latency' in ev] + self.assertTrue(urchan_before, "Dataset must have urchan values to exercise the regression") + self.assertTrue(urevent_before, "Dataset must have urevent values to exercise the regression") + + with tempfile.TemporaryDirectory() as tmp: + out = os.path.join(tmp, 'roundtrip.set') + pop_saveset(EEG, out) + # A second save must not double-increment the caller's in-memory indices. + pop_saveset(EEG, out) + + urchan_after = [int(c['urchan']) for c in EEG['chanlocs'] if 'urchan' in c] + urevent_after = [int(ev['urevent']) for ev in EEG['event'] if 'urevent' in ev] + latency_after = [float(ev['latency']) for ev in EEG['event'] if 'latency' in ev] + + self.assertEqual(urchan_after, urchan_before) # still 0-based in memory + self.assertEqual(urevent_after, urevent_before) + np.testing.assert_array_equal(latency_after, latency_before) # """Test basic resampling functionality with different engines""" # # Apply resampling with different engines # EEG_python = pop_resample(self.EEG.copy(), self.new_freq, engine='scipy') @@ -82,6 +112,62 @@ def test_basic(self): # rtol=1e-5, atol=1e-8, # err_msg='ICA activations differ between Python and Octave') + def test_chanlocs_serialized_through_single_converter(self): + # The primary EEG.chanlocs struct must be serialized through the same + # canonical converter as chaninfo.removedchans, so a chanloc field such + # as ``unit`` is preserved instead of being dropped by a second copy. + chanlocs = [ + { + 'labels': 'Fz', + 'theta': 0.0, + 'radius': 0.5, + 'X': 1.0, + 'Y': 2.0, + 'Z': 3.0, + 'sph_theta': 0.0, + 'sph_phi': 0.0, + 'sph_radius': 1.0, + 'type': 'EEG', + 'urchan': 0, + 'ref': np.array([]), + 'unit': 'uV', + }, + { + 'labels': 'Cz', + 'theta': 10.0, + 'radius': 0.6, + 'X': 1.5, + 'Y': 2.5, + 'Z': 3.5, + 'sph_theta': 1.0, + 'sph_phi': 1.0, + 'sph_radius': 1.0, + 'type': 'EEG', + 'urchan': 1, + 'ref': np.array([]), + 'unit': 'uV', + }, + ] + EEG = { + 'setname': 't', + 'nbchan': 2, + 'trials': 1, + 'pnts': 4, + 'srate': 100.0, + 'xmin': 0.0, + 'xmax': 0.03, + 'times': np.arange(4) / 100.0, + 'data': np.zeros((2, 4)), + 'chanlocs': chanlocs, + 'event': [], + 'icachansind': np.array([]), + } + with tempfile.TemporaryDirectory() as tmp: + out = os.path.join(tmp, 'unit.set') + pop_saveset(EEG, out) + loaded = scipy.io.loadmat(out, struct_as_record=True) + self.assertIn('unit', loaded['chanlocs'].dtype.names) + if __name__ == '__main__': # EEG = pop_loadset(ensure_file('FlankerTest.set')) diff --git a/tests/test_pop_select.py b/tests/test_pop_select.py index 8111b9c5..a8b6fe34 100644 --- a/tests/test_pop_select.py +++ b/tests/test_pop_select.py @@ -202,6 +202,30 @@ def test_event_epoch_field_removal_on_single_trial(self): self.assertEqual(np.asarray(EEG_out.get('epoch')).size, 0) + def test_caller_eeg_is_not_mutated(self): + EEG = copy.deepcopy(self.EEG) + original_data = EEG['data'].copy() + original_nbchan = int(EEG['nbchan']) + events = EEG.get('event') + original_event_count = 0 if events is None else len(events) + original_first_latency = ( + float(events[0]['latency']) if original_event_count and 'latency' in events[0] else None + ) + + labels = _chan_labels(EEG) + self.assertGreaterEqual(len(labels), 2, "Dataset must have at least 2 channels") + EEG_out = pop_select(EEG, channel=labels[:2]) + + # The returned dataset reflects the selection ... + self.assertEqual(EEG_out['nbchan'], 2) + # ... while the caller's EEG is left completely untouched. + self.assertTrue(np.array_equal(original_data, EEG['data'])) + self.assertEqual(int(EEG['nbchan']), original_nbchan) + events_after = EEG.get('event') + self.assertEqual(0 if events_after is None else len(events_after), original_event_count) + if original_first_latency is not None: + self.assertEqual(float(events_after[0]['latency']), original_first_latency) + class TestPopSelectEdgeCases(unittest.TestCase): """Test edge cases and error conditions in pop_select.""" @@ -358,6 +382,30 @@ def test_mixed_channel_indices_labels(self): self.assertEqual(EEG_out['nbchan'], 2) self.assertEqual(EEG_out['data'].shape[0], 2) + def test_channel_selection_by_type(self): + """Selecting by chantype keeps only channels whose type matches (no unpack crash).""" + EEG = copy.deepcopy(self.EEG) + types = ['EEG', 'EEG', 'EOG', 'EOG'] + for chan, ctype in zip(EEG['chanlocs'], types): + chan['type'] = ctype + + EEG_out = pop_select(EEG, chantype=['EEG']) + + self.assertEqual(EEG_out['nbchan'], 2) + self.assertEqual([chan['labels'] for chan in EEG_out['chanlocs']], ['Fz', 'Cz']) + + def test_channel_removal_by_type(self): + """Removing by rmchantype drops channels whose type matches (no unpack crash).""" + EEG = copy.deepcopy(self.EEG) + types = ['EEG', 'EEG', 'EOG', 'EOG'] + for chan, ctype in zip(EEG['chanlocs'], types): + chan['type'] = ctype + + EEG_out = pop_select(EEG, rmchantype=['EOG']) + + self.assertEqual(EEG_out['nbchan'], 2) + self.assertEqual([chan['labels'] for chan in EEG_out['chanlocs']], ['Fz', 'Cz']) + def test_negative_indices_error(self): """Test that negative channel indices raise appropriate errors.""" EEG = copy.deepcopy(self.EEG) diff --git a/tests/test_processing_logging_contract.py b/tests/test_processing_logging_contract.py new file mode 100644 index 00000000..2619b065 --- /dev/null +++ b/tests/test_processing_logging_contract.py @@ -0,0 +1,70 @@ +import logging + +import numpy as np +import pytest + +from eegprep.functions.popfunc.eeg_eegrej import _combine_regions +from eegprep.functions.popfunc.eeg_lat2point import eeg_lat2point +from eegprep.functions.popfunc.pop_select import pop_select + + +def _minimal_eeg(): + return { + "data": np.zeros((2, 20), dtype=np.float32), + "nbchan": 2, + "pnts": 20, + "trials": 1, + "srate": 100, + "xmin": 0, + "xmax": 0.19, + "times": np.arange(20), + "chanlocs": [{"labels": "Cz"}, {"labels": "Pz"}], + "event": [], + "urevent": [], + "epoch": [], + "history": "", + "icaact": np.array([]), + "icaweights": np.array([]), + "icasphere": np.array([]), + "icawinv": np.array([]), + "icachansind": np.array([], dtype=int), + "chaninfo": {}, + "reject": {}, + } + + +def test_pop_select_warnings_use_logging_not_stdout(capsys, caplog): + caplog.set_level(logging.WARNING) + + with pytest.raises(ValueError, match="Channels not found"): + pop_select(_minimal_eeg(), channel=["Missing"]) + + captured = capsys.readouterr() + assert captured.out == "" + assert captured.err == "" + assert "channels not found" in caplog.text + + +def test_latency_range_warning_uses_logging_not_stdout(capsys, caplog): + caplog.set_level(logging.WARNING) + + newlat, flag = eeg_lat2point([2], [1], 1, [0, 0], outrange=1) + + captured = capsys.readouterr() + assert captured.out == "" + assert captured.err == "" + assert flag == 1 + np.testing.assert_array_equal(newlat, np.array([1.0])) + assert "Points out of range detected" in caplog.text + + +def test_eegrej_overlap_warning_uses_logging_not_stdout(capsys, caplog): + caplog.set_level(logging.WARNING) + + combined = _combine_regions(np.array([[1, 3], [3, 5], [10, 12]])) + + captured = capsys.readouterr() + assert captured.out == "" + assert captured.err == "" + np.testing.assert_array_equal(combined, np.array([[1, 5], [10, 12]])) + assert "Overlapping regions detected" in caplog.text diff --git a/tests/test_public_api_examples.py b/tests/test_public_api_examples.py index af4b524e..525b5ad9 100644 --- a/tests/test_public_api_examples.py +++ b/tests/test_public_api_examples.py @@ -72,6 +72,15 @@ def test_project_entry_points_cover_gui_and_console() -> None: assert pyproject["project"]["scripts"]["eegprep"] == "eegprep.cli.main:main" +def test_development_dependencies_cover_console_runtime() -> None: + pyproject = _read_pyproject() + dev_dependencies = set(pyproject["dependency-groups"]["dev"]) + + assert "ipython>=8.0" in dev_dependencies + assert "pyqtgraph>=0.13.7" in dev_dependencies + assert "PySide6>=6.6" in dev_dependencies + + def test_setuptools_package_data_covers_runtime_resources() -> None: pyproject = _read_pyproject() package_root = REPO_ROOT / "src/eegprep" diff --git a/tests/test_runamica.py b/tests/test_runamica.py index fc34d1f9..a334265e 100644 --- a/tests/test_runamica.py +++ b/tests/test_runamica.py @@ -271,6 +271,32 @@ def test_find_amica_binary_bad_env_var(self): _find_amica_binary() +class TestRunamicaTempCleanup(unittest.TestCase): + """A failed run must not leak the auto-created temp directory.""" + + def _amica_temp_dirs(self): + root = tempfile.gettempdir() + return {name for name in os.listdir(root) if name.startswith('amica_')} + + def test_failed_run_removes_temp_dir(self): + data = np.random.RandomState(0).randn(4, 500) + before = self._amica_temp_dirs() + + def _boom(binary, param_file): + raise RuntimeError("amica binary failed") + + with mock.patch( + 'eegprep.functions.sigprocfunc.runamica._find_amica_binary', + return_value='/dummy/amica', + ): + with mock.patch('eegprep.functions.sigprocfunc.runamica._run_amica', side_effect=_boom): + with self.assertRaises(RuntimeError): + runamica(data, num_models=1, max_iter=10, max_threads=1) + + # No new amica_* temp directory should survive the failure. + self.assertEqual(self._amica_temp_dirs() - before, set()) + + @unittest.skipUnless(is_amica_available(), "AMICA binary not functional on this platform") class TestRunamicaIntegration(unittest.TestCase): """Integration test: run AMICA binary on small synthetic data.""" diff --git a/tests/test_runica.py b/tests/test_runica.py index 7e8dcd31..d4b0ee60 100644 --- a/tests/test_runica.py +++ b/tests/test_runica.py @@ -48,6 +48,17 @@ def test_basic_ica(self): self.assertTrue(np.all(np.isfinite(bias))) self.assertTrue(np.all(np.isfinite(signs))) + def test_runica_does_not_mutate_input_array(self): + """runica must not modify the caller's data array (mean subtraction).""" + np.random.seed(42) + data = np.random.randn(8, 500).astype(np.float64) + original = data.copy() + + runica(data, maxsteps=5, verbose=False, rndreset='off') + + # The float64 input array passed by the caller must be untouched. + self.assertTrue(np.array_equal(data, original)) + def test_extended_ica(self): """Test extended-ICA mode.""" np.random.seed(42) diff --git a/tests/test_sigproc_eegrej.py b/tests/test_sigproc_eegrej.py new file mode 100644 index 00000000..5ce925cc --- /dev/null +++ b/tests/test_sigproc_eegrej.py @@ -0,0 +1,56 @@ +"""Tests for the low-level signal-processing eegrej (sigprocfunc.eegrej).""" + +import unittest + +import numpy as np + +from eegprep.functions.sigprocfunc.eegrej import eegrej + + +class TestSigprocEegrej(unittest.TestCase): + def setUp(self): + # 1 channel, 30 samples with values equal to their 1-based sample index + self.data = np.arange(1, 31, dtype=float).reshape(1, 30) + self.timelength = 30.0 + + def test_boundary_shift_uses_base_span_not_nested_duration(self): + # The first removed region already contains a boundary event with a large + # duration. EEGLAB shifts later boundary latencies by the prior regions' + # base spans only; the nested duration must NOT pull later boundaries left. + events = [ + {"type": "boundary", "latency": 7.0, "duration": 100.0}, # inside [5, 10] + {"type": "stim", "latency": 25.0}, + ] + _, _, newevents, boundevents = eegrej(self.data, [[5, 10], [20, 22]], self.timelength, events) + + # Region1 boundary at start-1 = 4 -> 4.5. + # Region2 boundary at start-1 = 19, shifted by region1 base span (6) -> 13 -> 13.5. + # The augmented duration (106) must not be used for the shift. + np.testing.assert_array_equal(boundevents, [4.5, 13.5]) + + bnd = {ev["latency"]: ev["duration"] for ev in newevents if ev.get("type") == "boundary"} + # The first boundary's .duration carries the nested duration (base 6 + 100). + self.assertEqual(bnd[4.5], 106.0) + # The second boundary's .duration is region2's base span (22 - 20 + 1 = 3). + self.assertEqual(bnd[13.5], 3.0) + + def test_multiple_regions_without_nested_boundaries(self): + # With no nested boundary durations, base span == duration, so the shift is + # unchanged: region2 boundary at 11 shifted by region1 base span 4 -> 7.5. + _, _, _, boundevents = eegrej(self.data, [[5, 8], [12, 14]], self.timelength) + np.testing.assert_array_equal(boundevents, [4.5, 7.5]) + + def test_adjacent_regions_merge_to_single_boundary(self): + # Adjacent regions excise a contiguous block; the two boundaries collapse + # to one latency after the base-span shift. + _, _, _, boundevents = eegrej(self.data, [[5, 8], [9, 12]], self.timelength) + np.testing.assert_array_equal(boundevents, [4.5]) + + def test_overlapping_regions_merge_to_single_boundary(self): + # Overlapping regions are de-overlapped then excised as one contiguous block. + _, _, _, boundevents = eegrej(self.data, [[5, 10], [8, 12]], self.timelength) + np.testing.assert_array_equal(boundevents, [4.5]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_statistics_package.py b/tests/test_statistics_package.py index de1b0f18..26916b47 100644 --- a/tests/test_statistics_package.py +++ b/tests/test_statistics_package.py @@ -46,6 +46,22 @@ def test_fdr_matches_bh_and_by_thresholds(): npt.assert_array_equal(by.mask, [True, True, False, False]) +def test_fdr_uses_finite_pvalues_for_threshold_denominator(): + pvals = np.array([0.01, 0.03, np.nan, np.inf]) + + result = fdr(pvals, 0.05) + + assert result.threshold == pytest.approx(0.03) + npt.assert_array_equal(result.mask, [True, True, False, False]) + + +def test_fdr_no_finite_pvalues_returns_empty_mask(): + result = fdr(np.array([np.nan, np.inf]), 0.05) + + assert result.threshold == 0 + npt.assert_array_equal(result.mask, [False, False]) + + def test_surrogate_pvals_and_ci_use_last_axis(): distribution = np.array([[1, 2, 3, 4], [4, 5, 6, 7]], dtype=float) observed = np.array([3, 5], dtype=float) diff --git a/tests/test_study_clustering.py b/tests/test_study_clustering.py index ae398835..e172df7f 100644 --- a/tests/test_study_clustering.py +++ b/tests/test_study_clustering.py @@ -150,11 +150,21 @@ def test_deterministic_clustering_helpers_return_stable_shapes(): centroids = std_centroid(data[:4], optimal_labels) outlier_rows = std_findoutlierclust(data, [1, 1, 1, 1, 1], threshold=1.0) + # The two tight pairs must each share a label and the pairs must differ, + # so a regression that swaps or misindexes cluster assignments is caught. assert set(optimal_labels.tolist()) == {1, 2} + assert optimal_labels[0] == optimal_labels[1] + assert optimal_labels[2] == optimal_labels[3] + assert optimal_labels[0] != optimal_labels[2] assert optimal_centers.shape == (2, 2) assert optimal_sumd.shape == (2,) assert optimal_distances.shape == (4, 2) + + # The far row [20, 20] (row index 4) must land alone in its own cluster, + # distinct from the four tight rows. assert robust_labels.shape == (5,) + assert robust_labels[4] not in set(robust_labels[:4].tolist()) + assert int(np.sum(robust_labels == robust_labels[4])) == 1 assert robust_centers.shape[1] == 2 assert robust_sumd.ndim == 1 assert robust_distances.shape[0] == 5 @@ -163,9 +173,31 @@ def test_deterministic_clustering_helpers_return_stable_shapes(): assert ap_centers.shape[1] == 2 assert ap_sumd.ndim == 1 assert centroids.shape == (2, 2) + # std_findoutlierclust must flag the far row (1-based index 5) as the outlier. assert outlier_rows.tolist() == [5] +def test_kmeans_kernel_is_shared_by_clustering_callers(): + # optimal_kmeans and robust_kmeans must reuse the one canonical k-means + # kernel so the numerics cannot drift between copies. Single-cluster runs + # let us compare labels/centers directly against the kernel output. + from eegprep.functions.studyfunc._cluster_kmeans import kmeans_labels, squared_distances + + data = np.asarray([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]]) + kernel_labels, kernel_centers = kmeans_labels(data, 2, random_state=7) + + optimal_labels, optimal_centers, _sumd, optimal_distances = optimal_kmeans(data, 2, random_state=7) + np.testing.assert_array_equal(optimal_labels, kernel_labels) + np.testing.assert_allclose(optimal_centers, kernel_centers) + np.testing.assert_allclose(optimal_distances, np.sqrt(squared_distances(data, kernel_centers))) + + robust_labels, robust_centers, _rsumd, _rdist, _outliers = robust_kmeans( + data, 2, STD=float("inf"), MAXiter=1, random_state=7 + ) + np.testing.assert_array_equal(robust_labels, kernel_labels) + np.testing.assert_allclose(robust_centers, kernel_centers) + + def test_pop_clust_rejects_invalid_cluster_counts_and_outlier_thresholds(): study, alleeg = _preclustered_study() diff --git a/tests/test_study_measures.py b/tests/test_study_measures.py index 87fcc33b..99caf9d8 100644 --- a/tests/test_study_measures.py +++ b/tests/test_study_measures.py @@ -97,6 +97,24 @@ def test_std_precomp_channel_measures_store_eeglab_named_fields(): std_readerp(collapsed, alleeg, channels=[1], subject="S01") +def test_std_precomp_recompute_off_preserves_cached_channel_measures(): + study, alleeg = _study_pair() + + study, alleeg = std_precomp(study, alleeg, [1], erp="on", spec="on") + original_erp = deepcopy(study["changrp"][0]["erpdata"]) + # Overwrite the cached value with a sentinel the recompute path would never produce. + sentinel = (np.asarray(original_erp) + 1000.0).tolist() + study["changrp"][0]["erpdata"] = sentinel + + # recompute='off' must keep the cached measure rather than recomputing it. + study, alleeg = std_precomp(study, alleeg, [1], erp="on", spec="on", recompute="off") + np.testing.assert_allclose(np.asarray(study["changrp"][0]["erpdata"]), np.asarray(sentinel)) + + # recompute='on' must overwrite the sentinel with a freshly computed measure. + study, alleeg = std_precomp(study, alleeg, [1], erp="on", spec="on", recompute="on") + np.testing.assert_allclose(np.asarray(study["changrp"][0]["erpdata"]), np.asarray(original_erp)) + + def test_std_precomp_baseline_and_design_contract(caplog): study, alleeg = _study_pair() alleeg[0]["times"] = np.asarray([-100.0, 0.0, 100.0, 200.0]) diff --git a/tests/test_study_metadata.py b/tests/test_study_metadata.py index f4c86a6f..bb72412a 100644 --- a/tests/test_study_metadata.py +++ b/tests/test_study_metadata.py @@ -111,6 +111,31 @@ def test_std_makedesign_selects_1_based_design_and_validates_variables(): std_selectdesign(study, alleeg, 3) +def test_std_makedesign_delfiles_off_preserves_cached_measures(): + study, alleeg = pop_study( + None, + [ + _eeg("one", subject="S01", condition="target", group="control"), + _eeg("two", subject="S02", condition="standard", group="patient"), + ], + ) + erpdata = [[1.0, 2.0, 3.0]] + study["changrp"] = [{"name": "Ch1", "channels": ["Ch1"], "erpdata": erpdata, "erptimes": [0.0, 1.0, 2.0]}] + + kept, _command = std_makedesign( + study, alleeg, 2, variable1="group", values1=["control"], delfiles="off", return_com=True + ) + assert kept["changrp"][0]["erpdata"] == erpdata + assert kept["changrp"][0]["erptimes"] == [0.0, 1.0, 2.0] + + study["changrp"] = [{"name": "Ch1", "channels": ["Ch1"], "erpdata": erpdata, "erptimes": [0.0, 1.0, 2.0]}] + cleared, _command = std_makedesign( + study, alleeg, 2, variable1="group", values1=["control"], delfiles="on", return_com=True + ) + assert "erpdata" not in cleared["changrp"][0] + assert "erptimes" not in cleared["changrp"][0] + + def test_pop_studydesign_selects_and_updates_design(): study, alleeg = pop_study( None, diff --git a/tests/test_utils_asr.py b/tests/test_utils_asr.py index 94cd48a9..6e4f22a9 100644 --- a/tests/test_utils_asr.py +++ b/tests/test_utils_asr.py @@ -4,6 +4,7 @@ from eegprep.plugins.clean_rawdata.asr_calibrate import asr_calibrate from eegprep.plugins.clean_rawdata.asr_process import asr_process +from eegprep.plugins.clean_rawdata.clean_asr import clean_asr class TestAsrCalibrate(unittest.TestCase): @@ -69,20 +70,31 @@ def test_different_sampling_rates(self): self.assertTrue(len(state['B']) > 1) self.assertTrue(len(state['A']) > 0) - def test_unsupported_sampling_rate(self): - """Test calibration with unsupported sampling rate (triggers warning).""" - unsupported_srate = 999.0 + def test_unsupported_sampling_rate_raises(self): + """Unsupported sampling rates must fail loudly, not silently degrade. + + Common rates like 999/1000/1024 Hz have no pre-computed spectral filter. + Substituting a trivial difference filter would silently miscalibrate ASR + thresholds, so asr_calibrate must raise rather than warn-and-continue. + """ data = np.random.randn(4, 1000) * 0.3 - with self.assertLogs('eegprep.plugins.clean_rawdata.asr_calibrate', level='WARNING') as log: - state = asr_calibrate(data, unsupported_srate) + for unsupported_srate in (999.0, 1000.0, 1024.0): + with self.subTest(srate=unsupported_srate): + with self.assertRaises(ValueError) as cm: + asr_calibrate(data, unsupported_srate) + self.assertIn('No pre-computed ASR spectral filter', str(cm.exception)) - # Check that warning was logged - self.assertTrue(any('No pre-computed spectral filter' in msg for msg in log.output)) + def test_unsupported_sampling_rate_allows_explicit_filter(self): + """An explicit B/A bypasses the precomputed-filter lookup for any srate.""" + data = np.random.randn(4, 1000) * 0.3 + B = np.array([1.0, -0.5]) + A = np.array([1.0]) - # Check that fallback filter was used - self.assertEqual(len(state['B']), 2) # Simple fallback filter - self.assertEqual(len(state['A']), 1) + state = asr_calibrate(data, 999.0, B=B, A=A) + self.assertIn('M', state) + np.testing.assert_array_equal(state['B'], B) + np.testing.assert_array_equal(state['A'], A) def test_parameter_validation(self): """Test parameter validation and edge cases.""" @@ -336,30 +348,51 @@ def test_memory_error_handling(self): self.assertIn('Not enough memory', str(cm.exception)) - def test_eigendecomposition_failure_handling(self): - """Test handling of eigendecomposition failures.""" - # Create problematic covariance that might cause eigendecomposition to fail - with patch('numpy.linalg.eigh') as mock_eigh: - mock_eigh.side_effect = np.linalg.LinAlgError("Eigendecomposition failed") + def test_rank_deficient_covariance_produces_sane_output(self): + """Process genuinely rank-deficient data (singular covariance). - with self.assertLogs('eegprep.plugins.clean_rawdata.asr_process', level='WARNING') as log: - cleaned_data, new_state = asr_process(self.test_data, self.srate, self.state) + Duplicate and zeroed channels make the per-window covariance singular, + exercising the eigendecomposition and pseudo-inverse paths with a real + degenerate input rather than monkeypatching numpy to raise. The cleaned + output must stay finite and keep its shape. + """ + degenerate = self.test_data.copy() + degenerate[3, :] = degenerate[0, :] # duplicate channel -> singular covariance + degenerate[5, :] = 0.0 # flat channel -> singular covariance - # Should log warning and use fallback - self.assertTrue(any('Eigendecomposition failed' in msg for msg in log.output)) - self.assertEqual(cleaned_data.shape, self.test_data.shape) + cleaned_data, new_state = asr_process(degenerate, self.srate, self.state) + + self.assertEqual(cleaned_data.shape, degenerate.shape) + self.assertTrue(np.all(np.isfinite(cleaned_data))) - def test_reconstruction_matrix_failure(self): - """Test handling of reconstruction matrix calculation failures.""" - with patch('numpy.linalg.pinv') as mock_pinv: - mock_pinv.side_effect = np.linalg.LinAlgError("Singular matrix") + def test_extreme_artifact_amplitudes_produce_sane_output(self): + """Process data with extreme-amplitude artifacts. - with self.assertLogs('eegprep.plugins.clean_rawdata.asr_process', level='WARNING') as log: - cleaned_data, new_state = asr_process(self.test_data, self.srate, self.state) + Huge transient amplitudes drive the reconstruction matrix toward + ill-conditioning, exercising the same numeric path. The cleaned output + must remain finite, keep its shape, and attenuate the injected spike. + """ + extreme = self.test_data.copy() + spike_peak = float(np.max(np.abs(extreme))) * 1e4 + extreme[2, 100:150] += spike_peak - # Should log warning and use identity matrix fallback - self.assertTrue(any('Failed to calculate inverse' in msg for msg in log.output)) - self.assertEqual(cleaned_data.shape, self.test_data.shape) + cleaned_data, new_state = asr_process(extreme, self.srate, self.state) + + self.assertEqual(cleaned_data.shape, extreme.shape) + self.assertTrue(np.all(np.isfinite(cleaned_data))) + self.assertLess(float(np.max(np.abs(cleaned_data))), spike_peak) + + def test_component_selection_shape_error_propagates(self): + """A shape/contract bug during component selection must surface, not be + silently swallowed into a no-op (keep-all) for the affected window. + """ + bad_state = dict(self.state) + # T is C x C in a valid state; an incompatible threshold matrix makes + # finite_matmul(T, V) fail. This must raise rather than disable cleaning. + bad_state['T'] = np.ones((self.n_channels + 1, self.n_channels + 1)) + + with self.assertRaises(ValueError): + asr_process(self.test_data, self.srate, bad_state) def test_state_persistence_across_calls(self): """Test that state is properly maintained across multiple processing calls.""" @@ -396,15 +429,16 @@ def test_window_length_adjustment(self): self.assertTrue(np.all(np.isfinite(cleaned_data))) def test_component_selection_error_handling(self): - """Test error handling in component selection logic.""" - # Mock numpy.sum to raise an error during threshold calculation - with patch('numpy.sum', side_effect=Exception("Threshold error")): - with self.assertLogs('eegprep.plugins.clean_rawdata.asr_process', level='ERROR') as log: - cleaned_data, new_state = asr_process(self.test_data, self.srate, self.state) + """An unexpected error during threshold computation must propagate. - # Should log error and use fallback (keep all components) - self.assertTrue(any('Error in component selection' in msg for msg in log.output)) - self.assertEqual(cleaned_data.shape, self.test_data.shape) + Previously such errors were swallowed and the window silently kept all + components (no artifact removal). They must now surface so genuine bugs are + visible rather than producing quietly under-cleaned data. + """ + with patch('numpy.sum', side_effect=Exception("Threshold error")): + with self.assertRaises(Exception) as cm: + asr_process(self.test_data, self.srate, self.state) + self.assertIn('Threshold error', str(cm.exception)) class TestAsrIntegration(unittest.TestCase): @@ -603,25 +637,27 @@ def test_single_channel_data(self): self.assertEqual(cleaned_data.shape, test_data.shape) def test_very_high_sampling_rate(self): - """Test ASR with very high sampling rate.""" + """Very high (unsupported) sampling rate must fail loudly in calibration. + + 2000 Hz has no pre-computed spectral filter; calibration must raise rather + than silently substitute a degenerate difference filter. + """ n_channels = 4 - srate = 2000.0 # High sampling rate + srate = 2000.0 # High, unsupported sampling rate - # Create appropriate amount of data n_samples = int(srate * 2) # 2 seconds calib_data = np.random.randn(n_channels, n_samples) * 0.3 - # Should use fallback filter for unsupported sampling rate - with self.assertLogs('eegprep.plugins.clean_rawdata.asr_calibrate', level='WARNING'): - state = asr_calibrate(calib_data, srate) - - # Should still work - self.assertIsInstance(state, dict) + with self.assertRaises(ValueError) as cm: + asr_calibrate(calib_data, srate) + self.assertIn('No pre-computed ASR spectral filter', str(cm.exception)) - # Test processing + # With explicit filter coefficients the high rate is processable end-to-end. + B = np.array([1.0, -0.5]) + A = np.array([1.0]) + state = asr_calibrate(calib_data, srate, B=B, A=A) test_data = np.random.randn(n_channels, 200) * 0.4 cleaned_data, _ = asr_process(test_data, srate, state) - self.assertEqual(cleaned_data.shape, test_data.shape) def test_zero_variance_data(self): @@ -660,5 +696,43 @@ def test_memory_usage_calculation_accuracy(self): self.assertEqual(cleaned_data.shape, (n_channels, 1000)) +class TestCleanAsrNoMutation(unittest.TestCase): + """Regression tests that clean_asr never mutates the caller's EEG.""" + + def setUp(self): + np.random.seed(7) + n_channels = 8 + n_samples = 2500 + srate = 250.0 + data = np.random.randn(n_channels, n_samples) * 0.5 + for i in range(n_channels): + for j in range(1, n_samples): + data[i, j] += 0.8 * data[i, j - 1] + # Inject a non-finite sample to exercise the in-place NaN-zeroing path + # that asr_calibrate applies to whatever array it receives. + data[0, 100] = np.nan + self.EEG = { + 'data': data, + 'srate': srate, + 'nbchan': n_channels, + 'pnts': n_samples, + 'etc': {}, + } + + def test_does_not_mutate_input_data(self): + """clean_asr must leave the caller's EEG['data'] (incl. NaNs) unchanged.""" + EEG_in = self.EEG + original_data = EEG_in['data'].copy() + + EEG_out = clean_asr(EEG_in, ref_maxbadchannels='off') + + # The caller's data is byte-for-byte unchanged, including the NaN that + # asr_calibrate would otherwise have zeroed in place. + self.assertTrue(np.array_equal(original_data, EEG_in['data'], equal_nan=True)) + # Output is a distinct object with distinct data. + self.assertIsNot(EEG_out, EEG_in) + self.assertIsNot(EEG_out['data'], EEG_in['data']) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_utils_ransac.py b/tests/test_utils_ransac.py index a90774e4..3e386ece 100644 --- a/tests/test_utils_ransac.py +++ b/tests/test_utils_ransac.py @@ -3,6 +3,7 @@ from unittest.mock import patch, MagicMock from eegprep.plugins.clean_rawdata.private.ransac import rand_sample, calc_projector +from eegprep.plugins.clean_rawdata.private.sphericalSplineInterpolate import sphericalSplineInterpolate class TestRandSample(unittest.TestCase): @@ -328,6 +329,44 @@ def mock_interp_func(src_locs, dest_locs): self.assertEqual(result.shape, (self.n_channels, self.n_channels * 3)) self.assertEqual(mock_interp.call_count, 3) + def test_real_interpolation_channel_mapping_and_assembly(self): + """Validate calc_projector against the real spherical-spline kernel. + + The mock-based tests above only check output shape and call counts, so a + bug that permutes channel indices or mishandles the per-subset transpose + would pass. This test runs the real interpolation on a small montage and + independently reproduces each subset's reconstruction matrix, catching + such assembly bugs. + """ + num_samples = 4 + subset_size = self.n_channels - 2 + + # Reproduce the exact subsets calc_projector samples (k from num_samples-1..0). + subset_stream = np.random.RandomState(7) + subsets = {k: rand_sample(self.n_channels, subset_size, subset_stream) for k in range(num_samples - 1, -1, -1)} + + projector = calc_projector(self.locs, num_samples, subset_size, stream=np.random.RandomState(7)) + + # Output must be a finite, real-valued bag of reconstruction matrices. + self.assertEqual(projector.shape, (self.n_channels, self.n_channels * num_samples)) + self.assertTrue(np.isrealobj(projector)) + self.assertTrue(np.all(np.isfinite(projector))) + + blocks = projector.reshape(self.n_channels, num_samples, self.n_channels) + for k, sample in subsets.items(): + block = blocks[:, k, :] + + # Only the rows of the sampled source channels carry weight; the two + # unsampled channels must stay all-zero. A channel-index permutation + # would shift the zero rows away from the unsampled channels. + nonzero_rows = np.flatnonzero(np.any(block != 0, axis=1)) + np.testing.assert_array_equal(np.sort(nonzero_rows), np.sort(sample)) + + # The non-zero rows must equal the real spherical-spline weights for + # this subset, transposed exactly as calc_projector assembles them. + expected_w = sphericalSplineInterpolate(self.locs[sample, :].T, self.locs.T)[0] + np.testing.assert_allclose(block[sample, :], np.real(expected_w).T, rtol=1e-10, atol=1e-12) + class TestRansacIntegration(unittest.TestCase): """Integration tests for RANSAC functionality.""" diff --git a/tests/test_visual_parity.py b/tests/test_visual_parity.py index c477ff9a..f732f9af 100644 --- a/tests/test_visual_parity.py +++ b/tests/test_visual_parity.py @@ -11,7 +11,10 @@ from tools.visual_parity.config import load_manifest from tools.visual_parity.export_eegprep_menu_inventory import export_inventory from tools.visual_parity.menu_inventory import compare_menu_trees -from eegprep.functions.guifunc.visual_capture import _main_window_menu_state as _eegprep_main_window_menu_state +from tools.visual_parity.visual_capture import ( + _capture_case_handlers, + _main_window_menu_state as _eegprep_main_window_menu_state, +) ONE_PIXEL_PNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+/p9sAAAAASUVORK5CYII=" @@ -24,12 +27,10 @@ def test_load_manifest_parses_cases(self): self.assertIn("main_window", cases) self.assertEqual(cases["main_window"].window_size, (520, 380)) self.assertIn("eeglab", cases["main_window"].targets) - self.assertIn("eegprep.functions.guifunc.visual_capture", cases["main_window"].targets["eegprep"].command) + self.assertIn("tools.visual_parity.visual_capture", cases["main_window"].targets["eegprep"].command) self.assertIn("adjust_events_dialog", cases) self.assertEqual(cases["adjust_events_dialog"].targets["eeglab"].type, "matlab_dialog") - self.assertIn( - "eegprep.functions.guifunc.visual_capture", cases["adjust_events_dialog"].targets["eegprep"].command - ) + self.assertIn("tools.visual_parity.visual_capture", cases["adjust_events_dialog"].targets["eegprep"].command) self.assertIn("reref_dialog", cases) self.assertEqual(cases["reref_dialog"].targets["eeglab"].action, "pop_reref") self.assertEqual(cases["reref_dialog_channel_ref"].targets["eeglab"].action, "pop_reref:channels") @@ -50,7 +51,7 @@ def test_load_manifest_parses_cases(self): ): with self.subTest(case_id=case_id): self.assertEqual(cases[case_id].targets["eeglab"].type, "matlab_figure") - self.assertIn("eegprep.functions.guifunc.visual_capture", cases[case_id].targets["eegprep"].command) + self.assertIn("tools.visual_parity.visual_capture", cases[case_id].targets["eegprep"].command) self.assertIn("pop_interp_dialog", cases) self.assertEqual(cases["pop_interp_dialog"].targets["eeglab"].action, "pop_interp:continuous") self.assertEqual(cases["pop_interp_epoched_dialog"].targets["eeglab"].action, "pop_interp:epoched") @@ -131,6 +132,18 @@ def test_load_manifest_parses_cases(self): self.assertEqual(cases["pop_interp_dataset_index_dialog"].targets["eeglab"].action, "inputdlg2:dataset_index") self.assertEqual(cases["pop_reref_help_dialog"].targets["eeglab"].action, "pophelp:pop_reref") + def test_eegprep_visual_capture_registry_covers_manifest_cases(self): + handlers = _capture_case_handlers() + cases = load_manifest() + + for case_id, case in cases.items(): + eegprep_target = case.targets.get("eegprep") + if eegprep_target is None: + continue + if "tools.visual_parity.visual_capture" in eegprep_target.command: + with self.subTest(case_id=case_id): + self.assertIn(case_id, handlers) + def test_eegbrowser_epoched_cases_compare_raw_matrix_captures(self): cases = load_manifest() diff --git a/tools/eeg_bids/__init__.py b/tools/eeg_bids/__init__.py new file mode 100644 index 00000000..ef0a2e2d --- /dev/null +++ b/tools/eeg_bids/__init__.py @@ -0,0 +1 @@ +"""EEG-BIDS parity and comparison tools for EEGPrep.""" diff --git a/src/eegprep/plugins/EEG_BIDS/stage_comparison.py b/tools/eeg_bids/stage_comparison.py similarity index 100% rename from src/eegprep/plugins/EEG_BIDS/stage_comparison.py rename to tools/eeg_bids/stage_comparison.py diff --git a/tools/eeglab_final_parity_matrix.py b/tools/eeglab_final_parity_matrix.py index 6c6a9d6d..d8be63a5 100644 --- a/tools/eeglab_final_parity_matrix.py +++ b/tools/eeglab_final_parity_matrix.py @@ -429,7 +429,9 @@ def _source_path_exists(repo_root: Path, source_path: str) -> bool: source = repo_root / "src/eegprep/eeglab" / source_path if source.exists(): return True - return source_path in _load_reference_path_snapshot(repo_root) and _source_root_exists(repo_root, source_path) + if _uses_reference_path_snapshot(repo_root) or source_path in _load_reference_path_snapshot(repo_root): + return True + return _source_root_exists(repo_root, source_path) if source_path.startswith("docs/"): return (repo_root / source_path).exists() return False @@ -456,6 +458,13 @@ def _source_root_exists(repo_root: Path, source_path: str) -> bool: return False +def _uses_reference_path_snapshot(repo_root: Path) -> bool: + snapshot_paths = _load_reference_path_snapshot(repo_root) + if not snapshot_paths: + return False + return not snapshot_paths <= _discover_live_final_eeglab_paths(repo_root / "src/eegprep/eeglab") + + def _eeglab_reference_root_exists(repo_root: Path, relative_root: str) -> bool: eeglab_root = repo_root / "src/eegprep/eeglab" if (eeglab_root / relative_root).exists(): diff --git a/tools/iclabel/__init__.py b/tools/iclabel/__init__.py new file mode 100644 index 00000000..5627fbf1 --- /dev/null +++ b/tools/iclabel/__init__.py @@ -0,0 +1 @@ +"""ICLabel parity and comparison tools for EEGPrep.""" diff --git a/tools/iclabel/iclabel_net_load_py_measures.py b/tools/iclabel/iclabel_net_load_py_measures.py new file mode 100644 index 00000000..9c678fae --- /dev/null +++ b/tools/iclabel/iclabel_net_load_py_measures.py @@ -0,0 +1,39 @@ +"""MATLAB-parity harness: run the packaged ICLabelNet on reformatted features. + +This is a development/parity tool, not part of the installed ``eegprep`` +package. It is invoked by the MATLAB parity scripts under ``tests/matlab/`` +via ``system('... iclabel_net_load_py_measures.py')``. It reuses the single +canonical :class:`~eegprep.plugins.ICLabel.iclabel_net.ICLabelNet` definition +so the parity harness exercises the same network as production ICLabel. +""" + +import logging +from importlib.resources import files + +import scipy.io +import torch + +from eegprep.plugins.ICLabel.iclabel_net import ICLabelNet + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + net_path = files("eegprep").joinpath("plugins").joinpath("ICLabel").joinpath("netICL.mat") + model = ICLabelNet(str(net_path)) + data = scipy.io.loadmat('python_temp_reformated.mat') + image_mat = data['grid'][0][0] + psdmed_mat = data['grid'][0][1] + autocorr_mat = data['grid'][0][2] + # assuming third dimension is trivial and last dimension is channel. First two dimensions (32 x 32) are size of topoplot + image = torch.tensor(image_mat).permute(-1, 2, 0, 1) + logger.debug("image shape: %s", image.shape) + psdmed = torch.tensor(psdmed_mat).permute(-1, 2, 0, 1) + logger.debug("psd shape: %s", psdmed.shape) + autocorr = torch.tensor(autocorr_mat).permute(-1, 2, 0, 1) + logger.debug("autocorr shape: %s", autocorr.shape) + output = model(image, psdmed, autocorr) + logger.debug("output shape: %s", output.shape) + + # save the output to a mat file + scipy.io.savemat('output4_py.mat', {'output': output.detach().numpy()}) diff --git a/tools/visual_parity/cases.json b/tools/visual_parity/cases.json index 3e185a41..43364461 100644 --- a/tools/visual_parity/cases.json +++ b/tools/visual_parity/cases.json @@ -20,7 +20,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -45,7 +45,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -70,7 +70,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -95,7 +95,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -120,7 +120,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -146,7 +146,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -172,7 +172,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -198,7 +198,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -224,7 +224,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -250,7 +250,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -276,7 +276,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -302,7 +302,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -328,7 +328,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -354,7 +354,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -380,7 +380,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -406,7 +406,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -432,7 +432,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -457,7 +457,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -482,7 +482,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -507,7 +507,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -532,7 +532,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -557,7 +557,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -582,7 +582,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -607,7 +607,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -632,7 +632,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -657,7 +657,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -682,7 +682,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -707,7 +707,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -732,7 +732,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -757,7 +757,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -782,7 +782,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -807,7 +807,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -832,7 +832,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -857,7 +857,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -882,7 +882,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -907,7 +907,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -932,7 +932,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -957,7 +957,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -982,7 +982,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1007,7 +1007,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1032,7 +1032,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1057,7 +1057,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1082,7 +1082,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1107,7 +1107,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1132,7 +1132,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1157,7 +1157,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1182,7 +1182,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1207,7 +1207,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1232,7 +1232,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1257,7 +1257,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1282,7 +1282,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1307,7 +1307,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1332,7 +1332,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1357,7 +1357,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1382,7 +1382,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1407,7 +1407,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1432,7 +1432,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1457,7 +1457,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1482,7 +1482,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1507,7 +1507,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1532,7 +1532,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1557,7 +1557,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1582,7 +1582,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1607,7 +1607,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1632,7 +1632,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1657,7 +1657,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1682,7 +1682,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1707,7 +1707,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1732,7 +1732,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1757,7 +1757,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1782,7 +1782,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1807,7 +1807,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1832,7 +1832,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1857,7 +1857,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1882,7 +1882,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1907,7 +1907,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1932,7 +1932,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1957,7 +1957,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -1982,7 +1982,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2007,7 +2007,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2032,7 +2032,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2057,7 +2057,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2082,7 +2082,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2107,7 +2107,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2126,7 +2126,7 @@ "eegprep": { "type": "command", "action": "pop_eegthresh_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2140,7 +2140,7 @@ "eegprep": { "type": "command", "action": "pop_jointprob_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2154,7 +2154,7 @@ "eegprep": { "type": "command", "action": "pop_rejkurt_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2168,7 +2168,7 @@ "eegprep": { "type": "command", "action": "pop_rejtrend_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2182,7 +2182,7 @@ "eegprep": { "type": "command", "action": "pop_rejspec_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2196,7 +2196,7 @@ "eegprep": { "type": "command", "action": "pop_rejmenu_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2210,7 +2210,7 @@ "eegprep": { "type": "command", "action": "pop_autorej_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2224,7 +2224,7 @@ "eegprep": { "type": "command", "action": "pop_rejchan_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2238,7 +2238,7 @@ "eegprep": { "type": "command", "action": "pop_rejcont_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2252,7 +2252,7 @@ "eegprep": { "type": "command", "action": "pop_selectcomps_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2266,7 +2266,7 @@ "eegprep": { "type": "command", "action": "pop_viewprops_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2285,7 +2285,7 @@ "type": "command", "action": "iclabel_pop_prop_extended_dashboard", "env": {"MPLBACKEND": "Agg"}, - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2299,7 +2299,7 @@ "eegprep": { "type": "command", "action": "pop_dipfit_settings_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2313,7 +2313,7 @@ "eegprep": { "type": "command", "action": "pop_dipfit_gridsearch_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2327,7 +2327,7 @@ "eegprep": { "type": "command", "action": "pop_dipfit_nonlinear_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2341,7 +2341,7 @@ "eegprep": { "type": "command", "action": "pop_dipplot_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2355,7 +2355,7 @@ "eegprep": { "type": "command", "action": "pop_multifit_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2369,7 +2369,7 @@ "eegprep": { "type": "command", "action": "pop_leadfield_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2383,7 +2383,7 @@ "eegprep": { "type": "command", "action": "pop_dipfit_loreta_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2396,7 +2396,7 @@ "eegprep": { "type": "command", "action": "pop_dipfit_headmodel_dialog", - "command": ["{python}", "-m", "eegprep.functions.guifunc.visual_capture", "--case", "{case_id}", "--output", "{output}"] + "command": ["{python}", "-m", "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", "{output}"] } } }, @@ -2416,7 +2416,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2441,7 +2441,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2466,7 +2466,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2491,7 +2491,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2516,7 +2516,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2541,7 +2541,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2565,7 +2565,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2589,7 +2589,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2613,7 +2613,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2637,7 +2637,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2661,7 +2661,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2685,7 +2685,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", @@ -2709,7 +2709,7 @@ "command": [ "{python}", "-m", - "eegprep.functions.guifunc.visual_capture", + "tools.visual_parity.visual_capture", "--case", "{case_id}", "--output", diff --git a/src/eegprep/functions/guifunc/visual_capture.py b/tools/visual_parity/visual_capture.py similarity index 82% rename from src/eegprep/functions/guifunc/visual_capture.py rename to tools/visual_parity/visual_capture.py index 35f04ddf..3635f334 100644 --- a/src/eegprep/functions/guifunc/visual_capture.py +++ b/tools/visual_parity/visual_capture.py @@ -7,7 +7,7 @@ import os import pathlib import sys -from typing import Any +from typing import Any, Callable import matplotlib.pyplot as plt import numpy as np @@ -1340,237 +1340,171 @@ def capture_pophelp_dialog(output: pathlib.Path, function_name: str) -> None: _grab_dialog(dialog, output, app) +CaptureHandler = Callable[[pathlib.Path], None] + + +def _capture_with(function: Callable[..., None], **kwargs: Any) -> CaptureHandler: + def handler(output: pathlib.Path) -> None: + function(output, **kwargs) + + return handler + + +def _capture_case_handlers() -> dict[str, CaptureHandler]: + handlers: dict[str, CaptureHandler] = { + "adjust_events_dialog": capture_adjust_events_dialog, + "main_window": capture_main_window, + "main_window_continuous": _capture_with(capture_main_window, state="continuous"), + "main_window_epoched": _capture_with(capture_main_window, state="epoched"), + "main_window_multiple": _capture_with(capture_main_window, state="multiple"), + "main_window_study": _capture_with(capture_main_window, state="study"), + "file_menu": _capture_with(capture_main_window, menu_label="File"), + "edit_menu": _capture_with(capture_main_window, menu_label="Edit"), + "tools_menu": _capture_with(capture_main_window, menu_label="Tools"), + "plot_menu": _capture_with(capture_main_window, menu_label="Plot"), + "study_menu": _capture_with(capture_main_window, menu_label="Study"), + "datasets_menu": _capture_with(capture_main_window, menu_label="Datasets"), + "help_menu": _capture_with(capture_main_window, menu_label="Help"), + "pop_comments_dialog": capture_pop_comments_dialog, + "pop_editset_dialog": capture_pop_editset_dialog, + "pop_editeventfield_dialog": capture_pop_editeventfield_dialog, + "pop_editeventvals_dialog": capture_pop_editeventvals_dialog, + "pop_selectevent_dialog": capture_pop_selectevent_dialog, + "pop_rmdat_dialog": capture_pop_rmdat_dialog, + "pop_chanedit_dialog": capture_pop_chanedit_dialog, + "pop_copyset_dialog": capture_pop_copyset_dialog, + "pop_mergeset_dialog": capture_pop_mergeset_dialog, + "pop_study_dialog": capture_pop_study_dialog, + "pop_studydesign_dialog": capture_pop_studydesign_dialog, + "pop_precomp_dialog": capture_pop_precomp_dialog, + "pop_preclust_dialog": capture_pop_preclust_dialog, + "pop_clust_dialog": capture_pop_clust_dialog, + "pop_chanplot_dialog": capture_pop_chanplot_dialog, + "pop_clustedit_dialog": capture_pop_clustedit_dialog, + "reref_dialog": capture_reref_dialog, + "reref_dialog_channel_ref": _capture_with(capture_reref_dialog, variant="channels"), + "reref_dialog_huber_ref": _capture_with(capture_reref_dialog, variant="huber"), + "reref_dialog_interp_removed": _capture_with(capture_reref_dialog, variant="interp_removed"), + "pop_interp_dialog": capture_pop_interp_dialog, + "pop_interp_removed_dialog": _capture_with(capture_pop_interp_dialog, variant="removed"), + "pop_interp_epoched_dialog": _capture_with(capture_pop_interp_dialog, variant="epoched"), + "pop_select_dialog": capture_pop_select_dialog, + "pop_resample_dialog": capture_pop_resample_dialog, + "pop_newset_dialog": capture_pop_newset_dialog, + "pop_rmbase_dialog": capture_pop_rmbase_dialog, + "pop_eegfilt_dialog": capture_pop_eegfilt_dialog, + "pop_eegfiltnew_dialog": capture_pop_eegfiltnew_dialog, + "pop_firws_dialog": capture_pop_firws_dialog, + "pop_firpm_dialog": capture_pop_firpm_dialog, + "pop_firma_dialog": capture_pop_firma_dialog, + "pop_kaiserbeta_dialog": capture_pop_kaiserbeta_dialog, + "pop_firwsord_dialog": capture_pop_firwsord_dialog, + "pop_firpmord_dialog": capture_pop_firpmord_dialog, + "pop_xfirws_dialog": capture_pop_xfirws_dialog, + "pop_epoch_dialog": capture_pop_epoch_dialog, + "pop_topoplot_erp_dialog": _capture_with(capture_pop_topoplot_dialog, variant="erp"), + "pop_topoplot_components_dialog": _capture_with(capture_pop_topoplot_dialog, variant="components"), + "pop_spectopo_channels_dialog": _capture_with(capture_pop_spectopo_dialog, variant="channels"), + "pop_spectopo_components_dialog": _capture_with(capture_pop_spectopo_dialog, variant="components"), + "pop_prop_channels_dialog": _capture_with(capture_pop_prop_dialog, variant="channels"), + "pop_prop_components_dialog": _capture_with(capture_pop_prop_dialog, variant="components"), + "pop_timtopo_dialog": capture_pop_timtopo_dialog, + "pop_plottopo_dialog": capture_pop_plottopo_dialog, + "pop_headplot_erp_dialog": _capture_with(capture_pop_headplot_dialog, variant="erp"), + "pop_headplot_components_dialog": _capture_with(capture_pop_headplot_dialog, variant="components"), + "coregister_dialog": capture_coregister_dialog, + "pop_plotdata_dialog": capture_pop_plotdata_dialog, + "pop_erpimage_channels_dialog": _capture_with(capture_pop_erpimage_dialog, variant="channels"), + "pop_erpimage_components_dialog": _capture_with(capture_pop_erpimage_dialog, variant="components"), + "pop_envtopo_dialog": capture_pop_envtopo_dialog, + "pop_comperp_channels_dialog": _capture_with(capture_pop_comperp_dialog, variant="channels"), + "pop_comperp_components_dialog": _capture_with(capture_pop_comperp_dialog, variant="components"), + "pop_newtimef_channels_dialog": _capture_with(capture_pop_newtimef_dialog, variant="channels"), + "pop_newtimef_components_dialog": _capture_with(capture_pop_newtimef_dialog, variant="components"), + "pop_newcrossf_channels_dialog": _capture_with(capture_pop_newcrossf_dialog, variant="channels"), + "pop_newcrossf_components_dialog": _capture_with(capture_pop_newcrossf_dialog, variant="components"), + "pop_signalstat_channels_dialog": _capture_with(capture_pop_signalstat_dialog, variant="channels"), + "pop_signalstat_components_dialog": _capture_with(capture_pop_signalstat_dialog, variant="components"), + "pop_eventstat_dialog": capture_pop_eventstat_dialog, + "pop_runica_dialog": capture_pop_runica_dialog, + "pop_runica_multiple_dialog": capture_pop_runica_multiple_dialog, + "pop_iclabel_dialog": capture_pop_iclabel_dialog, + "pop_icflag_dialog": capture_pop_icflag_dialog, + "iclabel_pop_prop_extended_dashboard": capture_pop_prop_extended_dashboard, + "pop_subcomp_dialog": capture_pop_subcomp_dialog, + "pop_clean_rawdata_dialog": capture_pop_clean_rawdata_dialog, + "pop_chansel_dialog": capture_pop_chansel_dialog, + "select_multiple_datasets_dialog": capture_select_multiple_datasets_dialog, + "pop_interp_dataset_index_dialog": capture_dataset_index_dialog, + "pop_reref_help_dialog": _capture_with(capture_pophelp_dialog, function_name="pop_reref"), + "pop_interp_help_dialog": _capture_with(capture_pophelp_dialog, function_name="pop_interp"), + } + handlers.update( + { + f"eegbrowser_{variant}": _capture_with(capture_eegbrowser, variant=variant) + for variant in ( + "continuous", + "continuous_marked", + "epoched", + "epoched_marked", + "events", + "grid_off", + "labels", + "component_activity", + "data2_overlay", + "spectral_overlay", + "pop_eegplot_reject_data", + "rejcont_continuous", + "rejection_epochs", + ) + } + ) + handlers.update( + { + case_id: _capture_with(capture_rejection_dialog, case_id=case_id) + for case_id in ( + "pop_autorej_dialog", + "pop_eegthresh_dialog", + "pop_jointprob_dialog", + "pop_rejchan_dialog", + "pop_rejcont_dialog", + "pop_rejkurt_dialog", + "pop_rejmenu_dialog", + "pop_rejspec_dialog", + "pop_rejtrend_dialog", + "pop_selectcomps_dialog", + "pop_viewprops_dialog", + ) + } + ) + handlers.update( + { + case_id: _capture_with(capture_dipfit_dialog, case_id=case_id) + for case_id in ( + "pop_dipfit_settings_dialog", + "pop_dipfit_headmodel_dialog", + "pop_dipfit_gridsearch_dialog", + "pop_dipfit_nonlinear_dialog", + "pop_dipplot_dialog", + "pop_multifit_dialog", + "pop_leadfield_dialog", + "pop_dipfit_loreta_dialog", + ) + } + ) + return handlers + + def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--case", required=True) parser.add_argument("--output", required=True, type=pathlib.Path) args = parser.parse_args(argv) - if args.case == "adjust_events_dialog": - capture_adjust_events_dialog(args.output) - elif args.case == "main_window": - capture_main_window(args.output) - elif args.case == "main_window_continuous": - capture_main_window(args.output, state="continuous") - elif args.case == "main_window_epoched": - capture_main_window(args.output, state="epoched") - elif args.case == "main_window_multiple": - capture_main_window(args.output, state="multiple") - elif args.case == "main_window_study": - capture_main_window(args.output, state="study") - elif args.case == "eegbrowser_continuous": - capture_eegbrowser(args.output, variant="continuous") - elif args.case == "eegbrowser_continuous_marked": - capture_eegbrowser(args.output, variant="continuous_marked") - elif args.case == "eegbrowser_epoched": - capture_eegbrowser(args.output, variant="epoched") - elif args.case == "eegbrowser_epoched_marked": - capture_eegbrowser(args.output, variant="epoched_marked") - elif args.case == "eegbrowser_events": - capture_eegbrowser(args.output, variant="events") - elif args.case == "eegbrowser_grid_off": - capture_eegbrowser(args.output, variant="grid_off") - elif args.case == "eegbrowser_labels": - capture_eegbrowser(args.output, variant="labels") - elif args.case == "eegbrowser_component_activity": - capture_eegbrowser(args.output, variant="component_activity") - elif args.case == "eegbrowser_data2_overlay": - capture_eegbrowser(args.output, variant="data2_overlay") - elif args.case == "eegbrowser_spectral_overlay": - capture_eegbrowser(args.output, variant="spectral_overlay") - elif args.case == "eegbrowser_pop_eegplot_reject_data": - capture_eegbrowser(args.output, variant="pop_eegplot_reject_data") - elif args.case == "eegbrowser_rejcont_continuous": - capture_eegbrowser(args.output, variant="rejcont_continuous") - elif args.case == "eegbrowser_rejection_epochs": - capture_eegbrowser(args.output, variant="rejection_epochs") - elif args.case == "file_menu": - capture_main_window(args.output, menu_label="File") - elif args.case == "edit_menu": - capture_main_window(args.output, menu_label="Edit") - elif args.case == "tools_menu": - capture_main_window(args.output, menu_label="Tools") - elif args.case == "plot_menu": - capture_main_window(args.output, menu_label="Plot") - elif args.case == "study_menu": - capture_main_window(args.output, menu_label="Study") - elif args.case == "datasets_menu": - capture_main_window(args.output, menu_label="Datasets") - elif args.case == "help_menu": - capture_main_window(args.output, menu_label="Help") - elif args.case == "pop_comments_dialog": - capture_pop_comments_dialog(args.output) - elif args.case == "pop_editset_dialog": - capture_pop_editset_dialog(args.output) - elif args.case == "pop_editeventfield_dialog": - capture_pop_editeventfield_dialog(args.output) - elif args.case == "pop_editeventvals_dialog": - capture_pop_editeventvals_dialog(args.output) - elif args.case == "pop_selectevent_dialog": - capture_pop_selectevent_dialog(args.output) - elif args.case == "pop_rmdat_dialog": - capture_pop_rmdat_dialog(args.output) - elif args.case == "pop_chanedit_dialog": - capture_pop_chanedit_dialog(args.output) - elif args.case == "pop_copyset_dialog": - capture_pop_copyset_dialog(args.output) - elif args.case == "pop_mergeset_dialog": - capture_pop_mergeset_dialog(args.output) - elif args.case == "pop_study_dialog": - capture_pop_study_dialog(args.output) - elif args.case == "pop_studydesign_dialog": - capture_pop_studydesign_dialog(args.output) - elif args.case == "pop_precomp_dialog": - capture_pop_precomp_dialog(args.output) - elif args.case == "pop_preclust_dialog": - capture_pop_preclust_dialog(args.output) - elif args.case == "pop_clust_dialog": - capture_pop_clust_dialog(args.output) - elif args.case == "pop_chanplot_dialog": - capture_pop_chanplot_dialog(args.output) - elif args.case == "pop_clustedit_dialog": - capture_pop_clustedit_dialog(args.output) - elif args.case == "reref_dialog": - capture_reref_dialog(args.output) - elif args.case == "reref_dialog_channel_ref": - capture_reref_dialog(args.output, variant="channels") - elif args.case == "reref_dialog_huber_ref": - capture_reref_dialog(args.output, variant="huber") - elif args.case == "reref_dialog_interp_removed": - capture_reref_dialog(args.output, variant="interp_removed") - elif args.case == "pop_interp_dialog": - capture_pop_interp_dialog(args.output) - elif args.case == "pop_interp_removed_dialog": - capture_pop_interp_dialog(args.output, variant="removed") - elif args.case == "pop_interp_epoched_dialog": - capture_pop_interp_dialog(args.output, variant="epoched") - elif args.case == "pop_select_dialog": - capture_pop_select_dialog(args.output) - elif args.case == "pop_resample_dialog": - capture_pop_resample_dialog(args.output) - elif args.case == "pop_newset_dialog": - capture_pop_newset_dialog(args.output) - elif args.case == "pop_rmbase_dialog": - capture_pop_rmbase_dialog(args.output) - elif args.case == "pop_eegfilt_dialog": - capture_pop_eegfilt_dialog(args.output) - elif args.case == "pop_eegfiltnew_dialog": - capture_pop_eegfiltnew_dialog(args.output) - elif args.case == "pop_firws_dialog": - capture_pop_firws_dialog(args.output) - elif args.case == "pop_firpm_dialog": - capture_pop_firpm_dialog(args.output) - elif args.case == "pop_firma_dialog": - capture_pop_firma_dialog(args.output) - elif args.case == "pop_kaiserbeta_dialog": - capture_pop_kaiserbeta_dialog(args.output) - elif args.case == "pop_firwsord_dialog": - capture_pop_firwsord_dialog(args.output) - elif args.case == "pop_firpmord_dialog": - capture_pop_firpmord_dialog(args.output) - elif args.case == "pop_xfirws_dialog": - capture_pop_xfirws_dialog(args.output) - elif args.case == "pop_epoch_dialog": - capture_pop_epoch_dialog(args.output) - elif args.case == "pop_topoplot_erp_dialog": - capture_pop_topoplot_dialog(args.output, variant="erp") - elif args.case == "pop_topoplot_components_dialog": - capture_pop_topoplot_dialog(args.output, variant="components") - elif args.case == "pop_spectopo_channels_dialog": - capture_pop_spectopo_dialog(args.output, variant="channels") - elif args.case == "pop_spectopo_components_dialog": - capture_pop_spectopo_dialog(args.output, variant="components") - elif args.case == "pop_prop_channels_dialog": - capture_pop_prop_dialog(args.output, variant="channels") - elif args.case == "pop_prop_components_dialog": - capture_pop_prop_dialog(args.output, variant="components") - elif args.case == "pop_timtopo_dialog": - capture_pop_timtopo_dialog(args.output) - elif args.case == "pop_plottopo_dialog": - capture_pop_plottopo_dialog(args.output) - elif args.case == "pop_headplot_erp_dialog": - capture_pop_headplot_dialog(args.output, variant="erp") - elif args.case == "pop_headplot_components_dialog": - capture_pop_headplot_dialog(args.output, variant="components") - elif args.case == "coregister_dialog": - capture_coregister_dialog(args.output) - elif args.case == "pop_plotdata_dialog": - capture_pop_plotdata_dialog(args.output) - elif args.case == "pop_erpimage_channels_dialog": - capture_pop_erpimage_dialog(args.output, variant="channels") - elif args.case == "pop_erpimage_components_dialog": - capture_pop_erpimage_dialog(args.output, variant="components") - elif args.case == "pop_envtopo_dialog": - capture_pop_envtopo_dialog(args.output) - elif args.case == "pop_comperp_channels_dialog": - capture_pop_comperp_dialog(args.output, variant="channels") - elif args.case == "pop_comperp_components_dialog": - capture_pop_comperp_dialog(args.output, variant="components") - elif args.case == "pop_newtimef_channels_dialog": - capture_pop_newtimef_dialog(args.output, variant="channels") - elif args.case == "pop_newtimef_components_dialog": - capture_pop_newtimef_dialog(args.output, variant="components") - elif args.case == "pop_newcrossf_channels_dialog": - capture_pop_newcrossf_dialog(args.output, variant="channels") - elif args.case == "pop_newcrossf_components_dialog": - capture_pop_newcrossf_dialog(args.output, variant="components") - elif args.case == "pop_signalstat_channels_dialog": - capture_pop_signalstat_dialog(args.output, variant="channels") - elif args.case == "pop_signalstat_components_dialog": - capture_pop_signalstat_dialog(args.output, variant="components") - elif args.case == "pop_eventstat_dialog": - capture_pop_eventstat_dialog(args.output) - elif args.case == "pop_runica_dialog": - capture_pop_runica_dialog(args.output) - elif args.case == "pop_runica_multiple_dialog": - capture_pop_runica_multiple_dialog(args.output) - elif args.case == "pop_iclabel_dialog": - capture_pop_iclabel_dialog(args.output) - elif args.case == "pop_icflag_dialog": - capture_pop_icflag_dialog(args.output) - elif args.case == "iclabel_pop_prop_extended_dashboard": - capture_pop_prop_extended_dashboard(args.output) - elif args.case == "pop_subcomp_dialog": - capture_pop_subcomp_dialog(args.output) - elif args.case in { - "pop_autorej_dialog", - "pop_eegthresh_dialog", - "pop_jointprob_dialog", - "pop_rejchan_dialog", - "pop_rejcont_dialog", - "pop_rejkurt_dialog", - "pop_rejmenu_dialog", - "pop_rejspec_dialog", - "pop_rejtrend_dialog", - "pop_selectcomps_dialog", - "pop_viewprops_dialog", - }: - capture_rejection_dialog(args.output, case_id=args.case) - elif args.case in { - "pop_dipfit_settings_dialog", - "pop_dipfit_headmodel_dialog", - "pop_dipfit_gridsearch_dialog", - "pop_dipfit_nonlinear_dialog", - "pop_dipplot_dialog", - "pop_multifit_dialog", - "pop_leadfield_dialog", - "pop_dipfit_loreta_dialog", - }: - capture_dipfit_dialog(args.output, case_id=args.case) - elif args.case == "pop_clean_rawdata_dialog": - capture_pop_clean_rawdata_dialog(args.output) - elif args.case == "pop_chansel_dialog": - capture_pop_chansel_dialog(args.output) - elif args.case == "select_multiple_datasets_dialog": - capture_select_multiple_datasets_dialog(args.output) - elif args.case == "pop_interp_dataset_index_dialog": - capture_dataset_index_dialog(args.output) - elif args.case == "pop_reref_help_dialog": - capture_pophelp_dialog(args.output, "pop_reref") - elif args.case == "pop_interp_help_dialog": - capture_pophelp_dialog(args.output, "pop_interp") - else: + handler = _capture_case_handlers().get(args.case) + if handler is None: parser.error(f"unsupported EEGPrep visual capture case: {args.case}") + handler(args.output) return 0 diff --git a/uv.lock b/uv.lock index 4a78a0a0..93fb5319 100644 --- a/uv.lock +++ b/uv.lock @@ -693,7 +693,6 @@ dependencies = [ [package.optional-dependencies] all = [ - { name = "eeglabio" }, { name = "ipython", version = "8.39.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "ipython", version = "9.13.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "myst-parser", version = "4.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -740,9 +739,6 @@ docs = [ { name = "sphinx-togglebutton" }, { name = "sphinxcontrib-spelling" }, ] -eeglabio = [ - { name = "eeglabio" }, -] gui = [ { name = "pyqtgraph" }, { name = "pyside6" }, @@ -753,6 +749,10 @@ torch = [ [package.dev-dependencies] dev = [ + { name = "ipython", version = "8.39.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "ipython", version = "9.13.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pyqtgraph" }, + { name = "pyside6" }, { name = "pytest" }, { name = "ruff" }, { name = "tomli", marker = "python_full_version < '3.11'" }, @@ -767,10 +767,8 @@ release = [ requires-dist = [ { name = "colorama", specifier = ">=0.4.6" }, { name = "eeglabio", specifier = ">=0.1.2" }, - { name = "eeglabio", marker = "extra == 'eeglabio'", specifier = ">=0.1.2" }, { name = "eegprep", extras = ["console"], marker = "extra == 'all'" }, { name = "eegprep", extras = ["docs"], marker = "extra == 'all'" }, - { name = "eegprep", extras = ["eeglabio"], marker = "extra == 'all'" }, { name = "eegprep", extras = ["gui"], marker = "extra == 'all'" }, { name = "eegprep", extras = ["gui"], marker = "extra == 'console'" }, { name = "eegprep", extras = ["torch"], marker = "extra == 'all'" }, @@ -804,10 +802,13 @@ requires-dist = [ { name = "threadpoolctl", specifier = ">=3.6.0" }, { name = "torch", marker = "extra == 'torch'", specifier = ">=2.0" }, ] -provides-extras = ["torch", "eeglabio", "gui", "console", "docs", "all"] +provides-extras = ["torch", "gui", "console", "docs", "all"] [package.metadata.requires-dev] dev = [ + { name = "ipython", specifier = ">=8.0" }, + { name = "pyqtgraph", specifier = ">=0.13.7" }, + { name = "pyside6", specifier = ">=6.6" }, { name = "pytest", specifier = ">=8.0" }, { name = "ruff", specifier = ">=0.15.14" }, { name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0" },