From 64a810f4e480ca5df5092bb167ef0eee291b9681 Mon Sep 17 00:00:00 2001 From: Mohammad Askari Date: Sat, 23 May 2026 12:29:30 -0700 Subject: [PATCH] Polish standalone docs and add notebook examples --- README.md | 72 +++--- docs/data.md | 24 +- docs/examples/00_data_postprocessing.ipynb | 84 +++++++ .../examples/01_tensorflow_reproduction.ipynb | 89 ++++++++ docs/examples/02_pytorch_reproduction.ipynb | 212 ++++++++++++++++++ docs/examples/03_jax_reproduction.ipynb | 89 ++++++++ docs/examples/index.md | 18 ++ docs/index.md | 27 ++- docs/reproduce.md | 25 ++- docs/results.md | 13 +- mkdocs.yml | 9 +- notebooks/00_data_postprocessing.ipynb | 95 +++++++- notebooks/01_tensorflow_reproduction.ipynb | 8 +- notebooks/02_pytorch_reproduction.ipynb | 24 +- notebooks/03_jax_reproduction.ipynb | 8 +- src/eegclassify/augmentation.py | 8 +- src/eegclassify/bci2a.py | 8 +- src/eegclassify/cli.py | 2 +- src/eegclassify/config.py | 2 +- src/eegclassify/data.py | 6 +- src/eegclassify/models/tensorflow.py | 4 +- src/eegclassify/preprocessing.py | 4 +- 22 files changed, 721 insertions(+), 110 deletions(-) create mode 100644 docs/examples/00_data_postprocessing.ipynb create mode 100644 docs/examples/01_tensorflow_reproduction.ipynb create mode 100644 docs/examples/02_pytorch_reproduction.ipynb create mode 100644 docs/examples/03_jax_reproduction.ipynb create mode 100644 docs/examples/index.md diff --git a/README.md b/README.md index 08913fa..f06fcdb 100644 --- a/README.md +++ b/README.md @@ -6,15 +6,18 @@ # EEGClassify -This repository contains `eegclassify`, a reproducible EEG classification package built from the original C247 project notebooks. It includes CNN, LSTM, CNN-LSTM, and GAN-augmented classifier experiments across TensorFlow, PyTorch, and JAX/Flax. +`eegclassify` is a standalone Python package for reproducible EEG task classification. It provides a consistent data workflow, reusable preprocessing utilities, and CNN, LSTM, CNN-LSTM, and GAN-augmented training paths across TensorFlow, PyTorch, and JAX/Flax. -The original project write-up is available as [Report.pdf](Report.pdf). The original notebook code is preserved under [`notebooks/legacy/`](notebooks/legacy/) for reference, while the reproducible package-oriented notebooks live directly under [`notebooks/`](notebooks/). +Use it when you want a clean package interface for running EEG classification experiments, comparing frameworks, or regenerating the processed BCI Competition IV-2a arrays from the public raw data. -The public data source is BCI Competition IV, data set 2a: +## What Is Included -- Dataset page: https://www.bbci.de/competition/iv/ -- Raw GDF archive: https://www.bbci.de/competition/download/competition_iv/BCICIV_2a_gdf.zip -- True-label archive: https://www.bbci.de/competition/iv/results/ds2a/true_labels.zip +- Reproducible BCI Competition IV data set 2a download and conversion tools. +- Processed `.npy` data files tracked with Git LFS for quick local examples. +- Shared experiment settings for labels, splits, augmentation, seeds, and batching. +- TensorFlow/Keras, PyTorch, and JAX/Flax model implementations. +- Reusable GAN augmentation for classifier training data. +- Jupyter notebooks and CLI commands for smoke tests and full reproductions. ## Quick Start @@ -25,33 +28,38 @@ python -m pip install -U pip python -m pip install -e ".[data,notebooks,docs,dev]" ``` -Download and prepare data: +## Data Preparation -```bash -python -m eegclassify.cli download-bci2a --raw-dir data/raw -python -m eegclassify.cli prepare-bci2a --raw-dir data/raw --output-dir data/processed -python -m eegclassify.cli compare-data --generated-dir data/processed --reference-dir data_temp -``` +The public source of truth is BCI Competition IV, data set 2a. + +- Dataset page: https://www.bbci.de/competition/iv/ +- Raw GDF archive: https://www.bbci.de/competition/download/competition_iv/BCICIV_2a_gdf.zip +- True-label archive: https://www.bbci.de/competition/iv/results/ds2a/true_labels.zip -Run tests and build documentation: +Download and regenerate the processed arrays: ```bash -pytest -mkdocs build --strict +python -m eegclassify.cli download-bci2a --raw-dir data/raw +python -m eegclassify.cli prepare-bci2a \ + --raw-dir data/raw \ + --output-dir data/processed_regenerated +python -m eegclassify.cli compare-data \ + --generated-dir data/processed_regenerated \ + --reference-dir data/processed ``` ## Notebook Examples -The main reproducible Jupyter notebook examples are: +The runnable notebooks live under [`notebooks/`](notebooks/) and are also rendered in the documentation site. - [Data postprocessing](notebooks/00_data_postprocessing.ipynb) - [TensorFlow reproduction](notebooks/01_tensorflow_reproduction.ipynb) - [PyTorch reproduction](notebooks/02_pytorch_reproduction.ipynb) - [JAX/Flax reproduction](notebooks/03_jax_reproduction.ipynb) -The original project notebooks are preserved in [notebooks/legacy](notebooks/legacy/). +Each notebook defaults to `FAST_DEV_RUN=True` so it can be opened and executed quickly before launching longer training runs. -Run a quick local training smoke test: +## CLI Training ```bash python -m eegclassify.cli train --framework tensorflow --model cnn --fast-dev-run @@ -59,13 +67,18 @@ python -m eegclassify.cli train --framework pytorch --model cnn --fast-dev-run python -m eegclassify.cli train --framework jax --model cnn --fast-dev-run ``` -Use the original GAN-CNN parity path with: +Use GAN augmentation with the canonical CNN path: ```bash -python -m eegclassify.cli train --framework tensorflow --model gan_cnn --use-gan-augmentation +python -m eegclassify.cli train \ + --framework tensorflow \ + --model gan_cnn \ + --use-gan-augmentation ``` -For NVIDIA GPU support on Linux, install: +## GPU Extras + +For NVIDIA GPU support on Linux, install the framework extras you need: ```bash python -m pip install -e ".[tensorflow,pytorch,jax-cuda13]" @@ -73,10 +86,19 @@ python -m pip install -e ".[tensorflow,pytorch,jax-cuda13]" The TensorFlow CLI configures the pip-installed CUDA library paths before importing TensorFlow. -The processed files in `data/processed/*.npy` are intended to be committed through Git LFS. They are excluded from PyPI builds. +## Documentation And Package Notes + +- Documentation: https://mhmdaskari.github.io/EEG_DL_Classification/ +- Processed files under `data/processed/*.npy` are tracked with Git LFS. +- Processed arrays are excluded from PyPI builds; package users can download or regenerate them with the CLI. -Documentation is configured for GitHub Pages at: +Run checks locally: + +```bash +pytest +mkdocs build --strict +``` -https://mhmdaskari.github.io/EEG_DL_Classification/ +This package was first initiated as part of a UCLA C247 course project; the original report is available as [Report.pdf](Report.pdf). -This project began as part of UCLA C247, Deep Learning and Neural Networks, Winter 2023. See [Report.pdf](Report.pdf) for the report associated with that work. +Archived notebook-only prototypes are kept under [`notebooks/legacy/`](notebooks/legacy/) for historical reference. diff --git a/docs/data.md b/docs/data.md index eb7a329..b6cf3c5 100644 --- a/docs/data.md +++ b/docs/data.md @@ -6,7 +6,7 @@ The public source of truth is BCI Competition IV, data set 2a. - Raw GDF archive: - True-label archive: -The original notebooks used six processed files: +The packaged processed dataset uses six files: - `X_train_valid.npy` - `y_train_valid.npy` @@ -16,7 +16,7 @@ The original notebooks used six processed files: - `person_test.npy` Post-processed `.npy` files under `data/processed/` are versioned with Git LFS so the -notebooks can run from a fresh clone. Raw archives, extracted GDF files, local caches, +examples can run from a fresh clone. Raw archives, extracted GDF files, local caches, and intermediate conversion folders stay out of Git. Before adding or updating processed arrays, install Git LFS and initialize it once: @@ -44,26 +44,30 @@ should download or regenerate the data with the commands below. python -m eegclassify.cli download-bci2a --raw-dir data/raw ``` -## Regenerate The Project Subset +## Regenerate The Processed Dataset ```bash python -m eegclassify.cli prepare-bci2a \ --raw-dir data/raw \ - --output-dir data/processed \ - --manifest data/processed/manifest.json + --output-dir data/processed_regenerated \ + --manifest data/processed_regenerated/manifest.json ``` -The default converter uses the training GDF files `A01T.gdf` through `A09T.gdf`, extracts the first 22 EEG channels, uses 1000 samples starting one sample after cue onset, and applies the built-in course split map inferred from the original `.npy` cache. A trial-block split is still available with `--split-map none`. +The default converter uses the training GDF files `A01T.gdf` through `A09T.gdf`, +extracts the first 22 EEG channels, uses 1000 samples starting one sample after cue +onset, and applies the package split map. A trial-block split is still available +with `--split-map none`. -## Compare With Local Cache +## Compare Processed Outputs ```bash python -m eegclassify.cli compare-data \ - --generated-dir data/processed \ - --reference-dir data_temp \ + --generated-dir data/processed_regenerated \ + --reference-dir data/processed \ --output artifacts/data_comparison.json ``` The comparison report includes shapes, dtypes, SHA256 hashes, label counts, subject counts, exact-match flags, and numeric differences. -The notebook `notebooks/00_data_postprocessing.ipynb` runs the same workflow interactively. +The [data postprocessing notebook](examples/00_data_postprocessing.ipynb) runs the +same workflow interactively. diff --git a/docs/examples/00_data_postprocessing.ipynb b/docs/examples/00_data_postprocessing.ipynb new file mode 100644 index 0000000..737ac8c --- /dev/null +++ b/docs/examples/00_data_postprocessing.ipynb @@ -0,0 +1,84 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": "# BCI IV 2a Data Postprocessing\n\nThis notebook makes the dataset lineage explicit. It downloads the official BCI Competition IV data set 2a archives, recreates the six package processed `.npy` files, writes a manifest, and optionally compares regenerated files with the packaged processed data.\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## What The Files Mean\n\nThe raw source is the public BCI IV-2a GDF release. The project subset stores train/validation candidate trials, held-out test trials, labels, and subject IDs in six NumPy files. The labels are still the original cue IDs (`769-772`) at this stage; conversion to class IDs happens during preprocessing.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "from pathlib import Path\nimport json\n\nfrom eegclassify.bci2a import BCI2AConversionConfig, convert_bci2a_training_subset, download_bci2a\nfrom eegclassify.data import compare_processed_dirs, summarize_processed_dir, write_json_report\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "RAW_DIR = Path(\"../data/raw\")\nPROCESSED_DIR = Path(\"../data/processed\")\nREFERENCE_DIR = Path(\"../data_temp\")\nREPORT_DIR = Path(\"../artifacts\")\n\nDOWNLOAD = True\nREGENERATE_FROM_RAW = True\nCOMPARE_WITH_REFERENCE = True\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Download Official Archives\n\nSet `DOWNLOAD = False` when the archives are already available under `data/raw/`. Raw files are reproducible from the official URLs and should not be committed to Git.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "if DOWNLOAD:\n paths = download_bci2a(RAW_DIR)\n print(json.dumps({key: str(value) for key, value in paths.items()}, indent=2))\nelse:\n print(\"Skipping download\")\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Convert Raw GDF Files To The Processed Dataset\n\nThe converter uses the training GDF files `A01T.gdf` through `A09T.gdf`, the first 22 EEG channels, 1000 samples per trial, and the built-in package split map.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "config = BCI2AConversionConfig(\n window_samples=1000,\n window_offset_samples=1,\n test_trial_start=200,\n test_trial_stop=250,\n)\n\nif REGENERATE_FROM_RAW:\n bundle = convert_bci2a_training_subset(RAW_DIR, PROCESSED_DIR, config, PROCESSED_DIR / \"manifest.json\")\n print(\"X_train_valid:\", bundle.X_train_valid.shape)\n print(\"X_test:\", bundle.X_test.shape)\nelse:\n print(\"Skipping raw-to-subset conversion\")\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Manifest And Comparison\n\nThe manifest records source URLs, shapes, dtypes, label counts, subject counts, file sizes, and SHA256 hashes. If `data_temp/` exists, the comparison report checks whether the regenerated files match the local cache.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "manifest = summarize_processed_dir(PROCESSED_DIR)\nwrite_json_report(manifest, PROCESSED_DIR / \"manifest.json\")\nprint(json.dumps({name: info.get(\"shape\") for name, info in manifest[\"files\"].items()}, indent=2))\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "if COMPARE_WITH_REFERENCE and REFERENCE_DIR.exists():\n report = compare_processed_dirs(PROCESSED_DIR, REFERENCE_DIR)\n write_json_report(report, REPORT_DIR / \"data_comparison.json\")\n print(\"all_files_match:\", report[\"all_files_match\"])\nelse:\n print(\"Skipping reference-cache comparison\")\n" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/01_tensorflow_reproduction.ipynb b/docs/examples/01_tensorflow_reproduction.ipynb new file mode 100644 index 0000000..ea63a82 --- /dev/null +++ b/docs/examples/01_tensorflow_reproduction.ipynb @@ -0,0 +1,89 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": "# TensorFlow EEG Classification Reproduction\n\nThis notebook runs the TensorFlow/Keras implementation. It defaults to a small smoke run so the workflow is easy to verify. For the full reproduction, set `FAST_DEV_RUN = False` and leave `DEV_SAMPLE_LIMIT = None`.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "from pathlib import Path\n\nfrom eegclassify.config import DataSplitConfig, ModelConfig, PreprocessingConfig, TrainingConfig\nfrom eegclassify.data import load_processed_arrays\nfrom eegclassify.models.tensorflow import TensorFlowClassifierFactory, TensorFlowGANAugmenter, TensorFlowTrainer\nfrom eegclassify.preprocessing import prepare_splits\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Configuration\n\n`MODEL_NAME = \"gan_cnn\"` is a convenience alias for `MODEL_NAME = \"cnn\"` plus `USE_GAN_AUGMENTATION = True`. GAN augmentation can also be used with `lstm` or `cnn_lstm`, and GAN+CNN is the canonical GAN-augmented reference target.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "DATA_DIR = Path(\"../data/processed\")\nif not DATA_DIR.exists():\n DATA_DIR = Path(\"../data_temp\")\n\nFAST_DEV_RUN = True\nDEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\nMODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\nUSE_GAN_AUGMENTATION = False\n\n# Full reproduction settings:\n# FAST_DEV_RUN = False\n# DEV_SAMPLE_LIMIT = None\n# MODEL_NAME = \"cnn\"\n# USE_GAN_AUGMENTATION = False\n\nif MODEL_NAME == \"gan_cnn\":\n USE_GAN_AUGMENTATION = True\n CLASSIFIER_NAME = \"cnn\"\nelse:\n CLASSIFIER_NAME = MODEL_NAME\n\npreprocess = PreprocessingConfig(seed=1)\nsplit = DataSplitConfig(validation_ratio=0.17, seed=1)\ntraining = TrainingConfig(\n fast_dev_run=FAST_DEV_RUN,\n use_gan_augmentation=USE_GAN_AUGMENTATION,\n gan_samples_per_class=10,\n seed=1,\n)\nmodel_config = ModelConfig(max_time_step=preprocess.max_time_step)\nREFERENCE_TARGETS = {\n \"cnn\": 0.7049,\n \"cnn_lstm\": 0.6095,\n \"lstm\": 0.3962,\n \"gan_cnn\": 0.6823,\n}\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Load And Prepare Data\n\nPreprocessing converts labels `769-772` to `0-3`, creates the validation split with seed `1`, applies the original max-pooling / averaging / subsampling augmentation, and reshapes arrays to `(trials, time, 1, channels)`.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "bundle = load_processed_arrays(DATA_DIR)\nprepared = prepare_splits(bundle, preprocess, split)\n\nx_train, y_train = prepared.x_train, prepared.y_train\nx_valid, y_valid = prepared.x_valid, prepared.y_valid\nx_test, y_test = prepared.x_test, prepared.y_test\n\nif DEV_SAMPLE_LIMIT is not None:\n x_train, y_train = x_train[:DEV_SAMPLE_LIMIT], y_train[:DEV_SAMPLE_LIMIT]\n x_valid, y_valid = x_valid[:DEV_SAMPLE_LIMIT], y_valid[:DEV_SAMPLE_LIMIT]\n x_test, y_test = x_test[:DEV_SAMPLE_LIMIT], y_test[:DEV_SAMPLE_LIMIT]\n\nprint(\"train:\", x_train.shape, y_train.shape)\nprint(\"valid:\", x_valid.shape, y_valid.shape)\nprint(\"test:\", x_test.shape, y_test.shape)\nprint(\"labels are one-hot with columns 0-3 for original cues 769-772\")\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Optional GAN Augmentation\n\nWhen enabled, a conditional GAN is trained on the training split only. Synthetic samples are appended to `x_train` and `y_train`; validation and test arrays stay untouched so evaluation remains comparable.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "gan_result = None\nif USE_GAN_AUGMENTATION:\n gan_result = TensorFlowGANAugmenter(training, model_config).augment_training_data(x_train, y_train)\n x_train, y_train = gan_result.x_train, gan_result.y_train\n print(\"synthetic samples:\", gan_result.x_synthetic.shape[0])\n print(\"augmented train:\", x_train.shape, y_train.shape)\nelse:\n print(\"GAN augmentation disabled\")\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Build, Train, And Evaluate\n\nThe classifier architecture is selected after the optional augmentation step. TensorFlow uses categorical cross-entropy, so both one-hot labels and GAN interpolation soft labels are supported.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "factory = TensorFlowClassifierFactory(model_config)\nmodel = factory.build(CLASSIFIER_NAME, learning_rate=training.learning_rate)\nmodel.summary()\n\ntrainer = TensorFlowTrainer(training)\nhistory = trainer.fit(model, x_train, y_train, x_valid, y_valid)\nmetrics = trainer.evaluate(model, x_test, y_test)\nmetrics\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Compare With Reference Targets\n\nFAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "target_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\ntarget = REFERENCE_TARGETS.get(target_key)\nprint(f\"experiment: {target_key}\")\nprint(f\"test accuracy: {metrics['accuracy']:.4f}\")\nif target is not None:\n delta = metrics[\"accuracy\"] - target\n print(f\"reference target: {target:.4f}\")\n print(f\"delta from target: {delta:+.4f}\")\nelse:\n print(\"No reference target exists for this GAN/classifier combination.\")\n" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/02_pytorch_reproduction.ipynb b/docs/examples/02_pytorch_reproduction.ipynb new file mode 100644 index 0000000..386b324 --- /dev/null +++ b/docs/examples/02_pytorch_reproduction.ipynb @@ -0,0 +1,212 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9dde7b78", + "metadata": {}, + "source": [ + "# PyTorch EEG Classification Reproduction\n", + "\n", + "This notebook runs the PyTorch implementation with the same preprocessing and experiment defaults as the TensorFlow notebook.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01ff6e80", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "from eegclassify.config import DataSplitConfig, ModelConfig, PreprocessingConfig, TrainingConfig\n", + "from eegclassify.data import load_processed_arrays\n", + "from eegclassify.models.pytorch import PyTorchGANAugmenter, PyTorchTrainer, build_classifier\n", + "from eegclassify.preprocessing import prepare_splits\n" + ] + }, + { + "cell_type": "markdown", + "id": "49b14bdf", + "metadata": {}, + "source": [ + "## Configuration\n", + "\n", + "Use `USE_GAN_AUGMENTATION = True` to add synthetic GAN samples before any classifier. The original parity comparison only applies to CNN with GAN augmentation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c98711d8", + "metadata": {}, + "outputs": [], + "source": [ + "DATA_DIR = Path(\"../data/processed\")\n", + "if not DATA_DIR.exists():\n", + " DATA_DIR = Path(\"../data_temp\")\n", + "\n", + "FAST_DEV_RUN = True\n", + "DEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\n", + "MODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\n", + "USE_GAN_AUGMENTATION = False\n", + "\n", + "# Full reproduction settings:\n", + "# FAST_DEV_RUN = False\n", + "# DEV_SAMPLE_LIMIT = None\n", + "# MODEL_NAME = \"cnn\"\n", + "# USE_GAN_AUGMENTATION = False\n", + "\n", + "if MODEL_NAME == \"gan_cnn\":\n", + " USE_GAN_AUGMENTATION = True\n", + " CLASSIFIER_NAME = \"cnn\"\n", + "else:\n", + " CLASSIFIER_NAME = MODEL_NAME\n", + "\n", + "preprocess = PreprocessingConfig(seed=1)\n", + "split = DataSplitConfig(validation_ratio=0.17, seed=1)\n", + "training = TrainingConfig(\n", + " fast_dev_run=FAST_DEV_RUN,\n", + " use_gan_augmentation=USE_GAN_AUGMENTATION,\n", + " gan_samples_per_class=10,\n", + " seed=1,\n", + ")\n", + "model_config = ModelConfig(max_time_step=preprocess.max_time_step)\n", + "REFERENCE_TARGETS = {\n", + " \"cnn\": 0.7049,\n", + " \"cnn_lstm\": 0.6095,\n", + " \"lstm\": 0.3962,\n", + " \"gan_cnn\": 0.6823,\n", + "}\n" + ] + }, + { + "cell_type": "markdown", + "id": "b51d64b9", + "metadata": {}, + "source": [ + "## Load And Prepare Data\n", + "\n", + "The package preprocessing applies the default label conversion, seeded validation split, temporal augmentation, and channels-last classifier layout.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5e7cfef", + "metadata": {}, + "outputs": [], + "source": [ + "bundle = load_processed_arrays(DATA_DIR)\n", + "prepared = prepare_splits(bundle, preprocess, split)\n", + "\n", + "x_train, y_train = prepared.x_train, prepared.y_train\n", + "x_valid, y_valid = prepared.x_valid, prepared.y_valid\n", + "x_test, y_test = prepared.x_test, prepared.y_test\n", + "\n", + "if DEV_SAMPLE_LIMIT is not None:\n", + " x_train, y_train = x_train[:DEV_SAMPLE_LIMIT], y_train[:DEV_SAMPLE_LIMIT]\n", + " x_valid, y_valid = x_valid[:DEV_SAMPLE_LIMIT], y_valid[:DEV_SAMPLE_LIMIT]\n", + " x_test, y_test = x_test[:DEV_SAMPLE_LIMIT], y_test[:DEV_SAMPLE_LIMIT]\n", + "\n", + "print(\"train:\", x_train.shape, y_train.shape)\n", + "print(\"valid:\", x_valid.shape, y_valid.shape)\n", + "print(\"test:\", x_test.shape, y_test.shape)\n", + "print(\"labels are one-hot with columns 0-3 for original cues 769-772\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "393de256", + "metadata": {}, + "source": [ + "## Optional GAN Augmentation\n", + "\n", + "PyTorch uses the same conditional GAN semantics as the TensorFlow implementation. Soft labels generated by interpolation are handled by a distribution cross-entropy loss in the trainer.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79242ecd", + "metadata": {}, + "outputs": [], + "source": [ + "gan_result = None\n", + "if USE_GAN_AUGMENTATION:\n", + " gan_result = PyTorchGANAugmenter(training, model_config).augment_training_data(x_train, y_train)\n", + " x_train, y_train = gan_result.x_train, gan_result.y_train\n", + " print(\"synthetic samples:\", gan_result.x_synthetic.shape[0])\n", + " print(\"augmented train:\", x_train.shape, y_train.shape)\n", + "else:\n", + " print(\"GAN augmentation disabled\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "0a4a014b", + "metadata": {}, + "source": [ + "## Build, Train, And Evaluate\n", + "\n", + "The trainer automatically uses CUDA when PyTorch can see a compatible GPU; otherwise it runs on CPU.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acefa0de", + "metadata": {}, + "outputs": [], + "source": [ + "model = build_classifier(CLASSIFIER_NAME, model_config)\n", + "trainer = PyTorchTrainer(training)\n", + "history = trainer.fit(model, x_train, y_train, x_valid, y_valid)\n", + "metrics = trainer.evaluate(model, x_test, y_test)\n", + "metrics\n" + ] + }, + { + "cell_type": "markdown", + "id": "20ae3614", + "metadata": {}, + "source": [ + "## Compare With Reference Targets\n", + "\n", + "FAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "953fe2f1", + "metadata": {}, + "outputs": [], + "source": [ + "target_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\n", + "target = REFERENCE_TARGETS.get(target_key)\n", + "print(f\"experiment: {target_key}\")\n", + "print(f\"test accuracy: {metrics['accuracy']:.4f}\")\n", + "if target is not None:\n", + " delta = metrics[\"accuracy\"] - target\n", + " print(f\"reference target: {target:.4f}\")\n", + " print(f\"delta from target: {delta:+.4f}\")\n", + "else:\n", + " print(\"No reference target exists for this GAN/classifier combination.\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/03_jax_reproduction.ipynb b/docs/examples/03_jax_reproduction.ipynb new file mode 100644 index 0000000..96c3429 --- /dev/null +++ b/docs/examples/03_jax_reproduction.ipynb @@ -0,0 +1,89 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": "# JAX/Flax EEG Classification Reproduction\n\nThis notebook runs the JAX/Flax implementation. It follows the same data preparation and configuration pattern as the TensorFlow and PyTorch examples.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "from pathlib import Path\n\nfrom eegclassify.config import DataSplitConfig, ModelConfig, PreprocessingConfig, TrainingConfig\nfrom eegclassify.data import load_processed_arrays\nfrom eegclassify.models.jax import JAXGANAugmenter, JAXTrainer, build_classifier\nfrom eegclassify.preprocessing import prepare_splits\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Configuration\n\nJAX compiles model operations lazily, so the first cell that trains or evaluates a model may take longer than later cells. GAN augmentation is still a training-data step and can be reused with any classifier.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "DATA_DIR = Path(\"../data/processed\")\nif not DATA_DIR.exists():\n DATA_DIR = Path(\"../data_temp\")\n\nFAST_DEV_RUN = True\nDEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\nMODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\nUSE_GAN_AUGMENTATION = False\n\n# Full reproduction settings:\n# FAST_DEV_RUN = False\n# DEV_SAMPLE_LIMIT = None\n# MODEL_NAME = \"cnn\"\n# USE_GAN_AUGMENTATION = False\n\nif MODEL_NAME == \"gan_cnn\":\n USE_GAN_AUGMENTATION = True\n CLASSIFIER_NAME = \"cnn\"\nelse:\n CLASSIFIER_NAME = MODEL_NAME\n\npreprocess = PreprocessingConfig(seed=1)\nsplit = DataSplitConfig(validation_ratio=0.17, seed=1)\ntraining = TrainingConfig(\n fast_dev_run=FAST_DEV_RUN,\n use_gan_augmentation=USE_GAN_AUGMENTATION,\n gan_samples_per_class=10,\n seed=1,\n)\nmodel_config = ModelConfig(max_time_step=preprocess.max_time_step)\nREFERENCE_TARGETS = {\n \"cnn\": 0.7049,\n \"cnn_lstm\": 0.6095,\n \"lstm\": 0.3962,\n \"gan_cnn\": 0.6823,\n}\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Load And Prepare Data\n\nThe prepared arrays use one-hot labels. GAN interpolation labels are soft distributions and are supported by the JAX trainer's soft-label cross-entropy.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "bundle = load_processed_arrays(DATA_DIR)\nprepared = prepare_splits(bundle, preprocess, split)\n\nx_train, y_train = prepared.x_train, prepared.y_train\nx_valid, y_valid = prepared.x_valid, prepared.y_valid\nx_test, y_test = prepared.x_test, prepared.y_test\n\nif DEV_SAMPLE_LIMIT is not None:\n x_train, y_train = x_train[:DEV_SAMPLE_LIMIT], y_train[:DEV_SAMPLE_LIMIT]\n x_valid, y_valid = x_valid[:DEV_SAMPLE_LIMIT], y_valid[:DEV_SAMPLE_LIMIT]\n x_test, y_test = x_test[:DEV_SAMPLE_LIMIT], y_test[:DEV_SAMPLE_LIMIT]\n\nprint(\"train:\", x_train.shape, y_train.shape)\nprint(\"valid:\", x_valid.shape, y_valid.shape)\nprint(\"test:\", x_test.shape, y_test.shape)\nprint(\"labels are one-hot with columns 0-3 for original cues 769-772\")\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Optional GAN Augmentation\n\nSynthetic samples are generated from the training split only. Validation and test data remain real data, keeping validation and test metrics tied to real held-out data.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "gan_result = None\nif USE_GAN_AUGMENTATION:\n gan_result = JAXGANAugmenter(training, model_config).augment_training_data(x_train, y_train)\n x_train, y_train = gan_result.x_train, gan_result.y_train\n print(\"synthetic samples:\", gan_result.x_synthetic.shape[0])\n print(\"augmented train:\", x_train.shape, y_train.shape)\nelse:\n print(\"GAN augmentation disabled\")\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Build, Train, And Evaluate\n\nThe JAX trainer keeps the final train state internally so evaluation uses the trained parameters and batch-normalization statistics.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "model = build_classifier(CLASSIFIER_NAME, model_config)\ntrainer = JAXTrainer(training)\nhistory = trainer.fit(model, x_train, y_train)\nmetrics = trainer.evaluate(model, x_test, y_test)\nmetrics\n" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## Compare With Reference Targets\n\nFAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.\n" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "target_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\ntarget = REFERENCE_TARGETS.get(target_key)\nprint(f\"experiment: {target_key}\")\nprint(f\"test accuracy: {metrics['accuracy']:.4f}\")\nif target is not None:\n delta = metrics[\"accuracy\"] - target\n print(f\"reference target: {target:.4f}\")\n print(f\"delta from target: {delta:+.4f}\")\nelse:\n print(\"No reference target exists for this GAN/classifier combination.\")\n" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/index.md b/docs/examples/index.md new file mode 100644 index 0000000..cea4698 --- /dev/null +++ b/docs/examples/index.md @@ -0,0 +1,18 @@ +# Notebook Examples + +These notebooks provide guided, runnable examples for the main `eegclassify` +workflows. They are rendered in the documentation and also kept as source notebooks +under the repository's `notebooks/` directory. + +Each training notebook starts with `FAST_DEV_RUN=True`. That setting keeps examples +short enough for a first pass while preserving the same configuration cells used for +longer runs. + +## Available Examples + +- [Data postprocessing](00_data_postprocessing.ipynb): download, convert, summarize, and compare processed EEG arrays. +- [TensorFlow reproduction](01_tensorflow_reproduction.ipynb): train and evaluate TensorFlow/Keras classifiers. +- [PyTorch reproduction](02_pytorch_reproduction.ipynb): train and evaluate PyTorch classifiers. +- [JAX/Flax reproduction](03_jax_reproduction.ipynb): train and evaluate JAX/Flax classifiers. + +For scripted runs, use the commands in [Reproduce Results](../reproduce.md). diff --git a/docs/index.md b/docs/index.md index 08636a6..c56dcf1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,12 +1,25 @@ # EEGClassify -`eegclassify` is a reproducible EEG classification package built from the original C247 project notebooks. It keeps the original CNN, LSTM, CNN-LSTM, and GAN-CNN experiments, while moving the reusable code into a Python package with TensorFlow, PyTorch, and JAX/Flax implementations. +`eegclassify` is a standalone Python package for reproducible EEG task classification. It combines BCI Competition IV-2a data preparation, experiment-ready preprocessing, and deep learning classifiers across TensorFlow, PyTorch, and JAX/Flax. -The project focuses on: +The package is designed for users who want to regenerate the processed dataset, run quick framework smoke tests, and scale the same configuration into full training runs. -- explicit BCI Competition IV data set 2a provenance -- repeatable raw-data conversion into the six `.npy` files used by the original notebooks -- object-oriented model, training, and preprocessing code -- notebooks that can run quick smoke checks or full 100-epoch reproductions +## What You Can Do -The original report remains available as [Report.pdf](https://github.com/mhmdaskari/EEG_DL_Classification/blob/main/Report.pdf), and the original notebook code is preserved under `notebooks/legacy/`. +- Download and convert public BCI Competition IV-2a raw data. +- Load the packaged processed `.npy` arrays for local examples. +- Train CNN, LSTM, and CNN-LSTM classifiers in three frameworks. +- Add conditional GAN-generated samples to the training split. +- Compare local runs against the package reference targets. +- Use notebooks for guided workflows or the CLI for repeatable runs. + +## Start Here + +1. Install the package with the extras you need in [Installation](installation.md). +2. Download or regenerate the EEG arrays in [Data](data.md). +3. Walk through the rendered [Notebook Examples](examples/index.md). +4. Run CLI or notebook experiments with [Reproduce Results](reproduce.md). +5. Check target metrics in [Results](results.md). +6. Build on the package through the [API Reference](api/index.md). + +The default examples use `FAST_DEV_RUN=True`, so they are meant to be approachable first and expandable when you are ready for longer runs. diff --git a/docs/reproduce.md b/docs/reproduce.md index 7fd8fc5..cd32b37 100644 --- a/docs/reproduce.md +++ b/docs/reproduce.md @@ -1,11 +1,13 @@ # Reproduce Results -The reproduction notebooks are parameterized: +The package notebooks and CLI share the same experiment settings, so you can move between an interactive workflow and a scripted run without changing the core configuration. -- `notebooks/00_data_postprocessing.ipynb` -- `notebooks/01_tensorflow_reproduction.ipynb` -- `notebooks/02_pytorch_reproduction.ipynb` -- `notebooks/03_jax_reproduction.ipynb` +The rendered notebook examples are: + +- [Data postprocessing](examples/00_data_postprocessing.ipynb) +- [TensorFlow reproduction](examples/01_tensorflow_reproduction.ipynb) +- [PyTorch reproduction](examples/02_pytorch_reproduction.ipynb) +- [JAX/Flax reproduction](examples/03_jax_reproduction.ipynb) Each training notebook defaults to: @@ -19,7 +21,7 @@ classifier reproduction. Set `USE_GAN_AUGMENTATION = True` to train a conditiona GAN on the training split, append generated samples, and then train the selected classifier on real plus synthetic training data. -The original preprocessing defaults are preserved: +The default experiment settings are: - labels `769-772` are converted to `0-3` - validation ratio is `0.17` @@ -29,13 +31,12 @@ The original preprocessing defaults are preserved: - GAN epochs are `20` - augmentation uses trim, max-pooling, averaging with Gaussian noise, and subsampling -GAN augmentation is available for CNN, LSTM, and CNN-LSTM. The original notebooks -only used GAN augmentation with CNN, so parity checks compare only the GAN+CNN path -against the legacy GAN-CNN result. +GAN augmentation is available for CNN, LSTM, and CNN-LSTM. The GAN+CNN row below is +the canonical GAN-augmented benchmark path for the package. -Original notebook baselines: +Reference accuracy targets: -| Experiment | Original test accuracy | +| Experiment | Reference test accuracy | | --- | ---: | | CNN | 70.49% | | CNN-LSTM | 60.95% | @@ -54,7 +55,7 @@ For GPU runs, install the GPU extras from the installation page and do not set `CUDA_VISIBLE_DEVICES=-1`. TensorFlow GPU training uses the CLI's CUDA path setup; PyTorch and JAX use their framework CUDA packages directly. -For the legacy GAN+CNN parity path: +For the GAN-augmented CNN path: ```bash python -m eegclassify.cli train \ diff --git a/docs/results.md b/docs/results.md index ff18476..4a23285 100644 --- a/docs/results.md +++ b/docs/results.md @@ -1,19 +1,18 @@ # Results -The original notebooks are the parity reference for the package implementation. -Cross-framework parity means matching experiment semantics and landing in the same -accuracy band, not bitwise-identical weights or histories. +The table below lists package reference targets for the default EEG classification +experiments. Cross-framework consistency means matching experiment semantics and +landing in the same accuracy band, not bitwise-identical weights or histories. -| Experiment | Original notebook accuracy | Parity target | +| Experiment | Reference accuracy | Acceptance band | | --- | ---: | --- | | CNN | 70.49% | within about ±5 percentage points | | CNN-LSTM | 60.95% | within about ±5 percentage points | | LSTM | 39.62% | within about ±5 percentage points | | GAN+CNN | 68.23% | within about ±5 percentage points | -GAN augmentation can be applied to any classifier, but only GAN+CNN is compared to -the original GAN notebook because the legacy project did not report GAN+LSTM or -GAN+CNN-LSTM baselines. +GAN augmentation can be applied to any classifier. The reference table includes +GAN+CNN as the canonical GAN-augmented benchmark path. Local runs write their detailed metrics to `artifacts/runs/`. Summarize those runs in this page before publishing a release. diff --git a/mkdocs.yml b/mkdocs.yml index ae00697..26ba07d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,6 +44,12 @@ nav: - Home: index.md - Installation: installation.md - Data: data.md + - Notebook Examples: + - Overview: examples/index.md + - Data Postprocessing: examples/00_data_postprocessing.ipynb + - TensorFlow: examples/01_tensorflow_reproduction.ipynb + - PyTorch: examples/02_pytorch_reproduction.ipynb + - JAX/Flax: examples/03_jax_reproduction.ipynb - Reproduce Results: reproduce.md - Results: results.md - API Reference: @@ -71,7 +77,8 @@ plugins: - search - autorefs - offline - - mkdocs-jupyter + - mkdocs-jupyter: + execute: false - git-revision-date-localized: enable_creation_date: true strict: false diff --git a/notebooks/00_data_postprocessing.ipynb b/notebooks/00_data_postprocessing.ipynb index de4d5d0..c6577d2 100644 --- a/notebooks/00_data_postprocessing.ipynb +++ b/notebooks/00_data_postprocessing.ipynb @@ -2,70 +2,143 @@ "cells": [ { "cell_type": "markdown", + "id": "863713a8", "metadata": {}, - "source": "# BCI IV 2a Data Postprocessing\n\nThis notebook makes the dataset lineage explicit. It downloads the official BCI Competition IV data set 2a archives, recreates the six `.npy` files used by the original notebooks, writes a manifest, and optionally compares the generated files with a local reference cache.\n" + "source": [ + "# BCI IV 2a Data Postprocessing\n", + "\n", + "This notebook makes the dataset lineage explicit. It downloads the official BCI Competition IV data set 2a archives, recreates the six package processed `.npy` files, writes a manifest, and optionally compares regenerated files with the packaged processed data.\n" + ] }, { "cell_type": "markdown", + "id": "42df9c59", "metadata": {}, - "source": "## What The Files Mean\n\nThe raw source is the public BCI IV-2a GDF release. The project subset stores train/validation candidate trials, held-out test trials, labels, and subject IDs in six NumPy files. The labels are still the original cue IDs (`769-772`) at this stage; conversion to class IDs happens during preprocessing.\n" + "source": [ + "## What The Files Mean\n", + "\n", + "The raw source is the public BCI IV-2a GDF release. The project subset stores train/validation candidate trials, held-out test trials, labels, and subject IDs in six NumPy files. The labels are still the original cue IDs (`769-772`) at this stage; conversion to class IDs happens during preprocessing.\n" + ] }, { "cell_type": "code", "execution_count": null, + "id": "6e912ab4", "metadata": {}, "outputs": [], - "source": "from pathlib import Path\nimport json\n\nfrom eegclassify.bci2a import BCI2AConversionConfig, convert_bci2a_training_subset, download_bci2a\nfrom eegclassify.data import compare_processed_dirs, summarize_processed_dir, write_json_report\n" + "source": [ + "from pathlib import Path\n", + "import json\n", + "\n", + "from eegclassify.bci2a import BCI2AConversionConfig, convert_bci2a_training_subset, download_bci2a\n", + "from eegclassify.data import compare_processed_dirs, summarize_processed_dir, write_json_report\n" + ] }, { "cell_type": "code", "execution_count": null, + "id": "bbb59e7a", "metadata": {}, "outputs": [], - "source": "RAW_DIR = Path(\"../data/raw\")\nPROCESSED_DIR = Path(\"../data/processed\")\nREFERENCE_DIR = Path(\"../data_temp\")\nREPORT_DIR = Path(\"../artifacts\")\n\nDOWNLOAD = True\nREGENERATE_FROM_RAW = True\nCOMPARE_WITH_REFERENCE = True\n" + "source": [ + "RAW_DIR = Path(\"../data/raw\")\n", + "PROCESSED_DIR = Path(\"../data/processed\")\n", + "REFERENCE_DIR = Path(\"../data_temp\")\n", + "REPORT_DIR = Path(\"../artifacts\")\n", + "\n", + "DOWNLOAD = True\n", + "REGENERATE_FROM_RAW = True\n", + "COMPARE_WITH_REFERENCE = True\n" + ] }, { "cell_type": "markdown", + "id": "dd0c3e5a", "metadata": {}, - "source": "## Download Official Archives\n\nSet `DOWNLOAD = False` when the archives are already available under `data/raw/`. Raw files are reproducible from the official URLs and should not be committed to Git.\n" + "source": [ + "## Download Official Archives\n", + "\n", + "Set `DOWNLOAD = False` when the archives are already available under `data/raw/`. Raw files are reproducible from the official URLs and should not be committed to Git.\n" + ] }, { "cell_type": "code", "execution_count": null, + "id": "fdaac33e", "metadata": {}, "outputs": [], - "source": "if DOWNLOAD:\n paths = download_bci2a(RAW_DIR)\n print(json.dumps({key: str(value) for key, value in paths.items()}, indent=2))\nelse:\n print(\"Skipping download\")\n" + "source": [ + "if DOWNLOAD:\n", + " paths = download_bci2a(RAW_DIR)\n", + " print(json.dumps({key: str(value) for key, value in paths.items()}, indent=2))\n", + "else:\n", + " print(\"Skipping download\")\n" + ] }, { "cell_type": "markdown", + "id": "5d6c1208", "metadata": {}, - "source": "## Convert Raw GDF Files To The Project Subset\n\nThe converter uses the training GDF files `A01T.gdf` through `A09T.gdf`, the first 22 EEG channels, 1000 samples per trial, and the built-in split map inferred from the original project cache.\n" + "source": [ + "## Convert Raw GDF Files To The Processed Dataset\n", + "\n", + "The converter uses the training GDF files `A01T.gdf` through `A09T.gdf`, the first 22 EEG channels, 1000 samples per trial, and the built-in package split map.\n" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "config = BCI2AConversionConfig(\n window_samples=1000,\n window_offset_samples=1,\n test_trial_start=200,\n test_trial_stop=250,\n)\n\nif REGENERATE_FROM_RAW:\n bundle = convert_bci2a_training_subset(RAW_DIR, PROCESSED_DIR, config, PROCESSED_DIR / \"manifest.json\")\n print(\"X_train_valid:\", bundle.X_train_valid.shape)\n print(\"X_test:\", bundle.X_test.shape)\nelse:\n print(\"Skipping raw-to-subset conversion\")\n" + "source": [ + "config = BCI2AConversionConfig(\n", + " window_samples=1000,\n", + " window_offset_samples=1,\n", + " test_trial_start=200,\n", + " test_trial_stop=250,\n", + ")\n", + "\n", + "if REGENERATE_FROM_RAW:\n", + " bundle = convert_bci2a_training_subset(RAW_DIR, PROCESSED_DIR, config, PROCESSED_DIR / \"manifest.json\")\n", + " print(\"X_train_valid:\", bundle.X_train_valid.shape)\n", + " print(\"X_test:\", bundle.X_test.shape)\n", + "else:\n", + " print(\"Skipping raw-to-subset conversion\")\n" + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## Manifest And Comparison\n\nThe manifest records source URLs, shapes, dtypes, label counts, subject counts, file sizes, and SHA256 hashes. If `data_temp/` exists, the comparison report checks whether the regenerated files match the local cache.\n" + "source": [ + "## Manifest And Comparison\n", + "\n", + "The manifest records source URLs, shapes, dtypes, label counts, subject counts, file sizes, and SHA256 hashes. If `data_temp/` exists, the comparison report checks whether the regenerated files match the local cache.\n" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "manifest = summarize_processed_dir(PROCESSED_DIR)\nwrite_json_report(manifest, PROCESSED_DIR / \"manifest.json\")\nprint(json.dumps({name: info.get(\"shape\") for name, info in manifest[\"files\"].items()}, indent=2))\n" + "source": [ + "manifest = summarize_processed_dir(PROCESSED_DIR)\n", + "write_json_report(manifest, PROCESSED_DIR / \"manifest.json\")\n", + "print(json.dumps({name: info.get(\"shape\") for name, info in manifest[\"files\"].items()}, indent=2))\n" + ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "if COMPARE_WITH_REFERENCE and REFERENCE_DIR.exists():\n report = compare_processed_dirs(PROCESSED_DIR, REFERENCE_DIR)\n write_json_report(report, REPORT_DIR / \"data_comparison.json\")\n print(\"all_files_match:\", report[\"all_files_match\"])\nelse:\n print(\"Skipping reference-cache comparison\")\n" + "source": [ + "if COMPARE_WITH_REFERENCE and REFERENCE_DIR.exists():\n", + " report = compare_processed_dirs(PROCESSED_DIR, REFERENCE_DIR)\n", + " write_json_report(report, REPORT_DIR / \"data_comparison.json\")\n", + " print(\"all_files_match:\", report[\"all_files_match\"])\n", + "else:\n", + " print(\"Skipping reference-cache comparison\")\n" + ] } ], "metadata": { diff --git a/notebooks/01_tensorflow_reproduction.ipynb b/notebooks/01_tensorflow_reproduction.ipynb index d4d166a..ea63a82 100644 --- a/notebooks/01_tensorflow_reproduction.ipynb +++ b/notebooks/01_tensorflow_reproduction.ipynb @@ -15,14 +15,14 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## Configuration\n\n`MODEL_NAME = \"gan_cnn\"` is a convenience alias for `MODEL_NAME = \"cnn\"` plus `USE_GAN_AUGMENTATION = True`. GAN augmentation can also be used with `lstm` or `cnn_lstm`, but only GAN+CNN has an original notebook baseline.\n" + "source": "## Configuration\n\n`MODEL_NAME = \"gan_cnn\"` is a convenience alias for `MODEL_NAME = \"cnn\"` plus `USE_GAN_AUGMENTATION = True`. GAN augmentation can also be used with `lstm` or `cnn_lstm`, and GAN+CNN is the canonical GAN-augmented reference target.\n" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "DATA_DIR = Path(\"../data/processed\")\nif not DATA_DIR.exists():\n DATA_DIR = Path(\"../data_temp\")\n\nFAST_DEV_RUN = True\nDEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\nMODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\nUSE_GAN_AUGMENTATION = False\n\n# Full reproduction settings:\n# FAST_DEV_RUN = False\n# DEV_SAMPLE_LIMIT = None\n# MODEL_NAME = \"cnn\"\n# USE_GAN_AUGMENTATION = False\n\nif MODEL_NAME == \"gan_cnn\":\n USE_GAN_AUGMENTATION = True\n CLASSIFIER_NAME = \"cnn\"\nelse:\n CLASSIFIER_NAME = MODEL_NAME\n\npreprocess = PreprocessingConfig(seed=1)\nsplit = DataSplitConfig(validation_ratio=0.17, seed=1)\ntraining = TrainingConfig(\n fast_dev_run=FAST_DEV_RUN,\n use_gan_augmentation=USE_GAN_AUGMENTATION,\n gan_samples_per_class=10,\n seed=1,\n)\nmodel_config = ModelConfig(max_time_step=preprocess.max_time_step)\nORIGINAL_BASELINES = {\n \"cnn\": 0.7049,\n \"cnn_lstm\": 0.6095,\n \"lstm\": 0.3962,\n \"gan_cnn\": 0.6823,\n}\n" + "source": "DATA_DIR = Path(\"../data/processed\")\nif not DATA_DIR.exists():\n DATA_DIR = Path(\"../data_temp\")\n\nFAST_DEV_RUN = True\nDEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\nMODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\nUSE_GAN_AUGMENTATION = False\n\n# Full reproduction settings:\n# FAST_DEV_RUN = False\n# DEV_SAMPLE_LIMIT = None\n# MODEL_NAME = \"cnn\"\n# USE_GAN_AUGMENTATION = False\n\nif MODEL_NAME == \"gan_cnn\":\n USE_GAN_AUGMENTATION = True\n CLASSIFIER_NAME = \"cnn\"\nelse:\n CLASSIFIER_NAME = MODEL_NAME\n\npreprocess = PreprocessingConfig(seed=1)\nsplit = DataSplitConfig(validation_ratio=0.17, seed=1)\ntraining = TrainingConfig(\n fast_dev_run=FAST_DEV_RUN,\n use_gan_augmentation=USE_GAN_AUGMENTATION,\n gan_samples_per_class=10,\n seed=1,\n)\nmodel_config = ModelConfig(max_time_step=preprocess.max_time_step)\nREFERENCE_TARGETS = {\n \"cnn\": 0.7049,\n \"cnn_lstm\": 0.6095,\n \"lstm\": 0.3962,\n \"gan_cnn\": 0.6823,\n}\n" }, { "cell_type": "markdown", @@ -63,14 +63,14 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## Compare With The Original Notebook\n\nFAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the baseline delta.\n" + "source": "## Compare With Reference Targets\n\nFAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.\n" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "baseline_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\nbaseline = ORIGINAL_BASELINES.get(baseline_key)\nprint(f\"experiment: {baseline_key}\")\nprint(f\"test accuracy: {metrics['accuracy']:.4f}\")\nif baseline is not None:\n delta = metrics[\"accuracy\"] - baseline\n print(f\"original notebook baseline: {baseline:.4f}\")\n print(f\"delta from baseline: {delta:+.4f}\")\nelse:\n print(\"No original notebook baseline exists for this GAN/classifier combination.\")\n" + "source": "target_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\ntarget = REFERENCE_TARGETS.get(target_key)\nprint(f\"experiment: {target_key}\")\nprint(f\"test accuracy: {metrics['accuracy']:.4f}\")\nif target is not None:\n delta = metrics[\"accuracy\"] - target\n print(f\"reference target: {target:.4f}\")\n print(f\"delta from target: {delta:+.4f}\")\nelse:\n print(\"No reference target exists for this GAN/classifier combination.\")\n" } ], "metadata": { diff --git a/notebooks/02_pytorch_reproduction.ipynb b/notebooks/02_pytorch_reproduction.ipynb index 002c932..386b324 100644 --- a/notebooks/02_pytorch_reproduction.ipynb +++ b/notebooks/02_pytorch_reproduction.ipynb @@ -72,7 +72,7 @@ " seed=1,\n", ")\n", "model_config = ModelConfig(max_time_step=preprocess.max_time_step)\n", - "ORIGINAL_BASELINES = {\n", + "REFERENCE_TARGETS = {\n", " \"cnn\": 0.7049,\n", " \"cnn_lstm\": 0.6095,\n", " \"lstm\": 0.3962,\n", @@ -87,7 +87,7 @@ "source": [ "## Load And Prepare Data\n", "\n", - "The package preprocessing mirrors the original notebooks: label conversion, seeded validation split, temporal augmentation, and channels-last classifier layout.\n" + "The package preprocessing applies the default label conversion, seeded validation split, temporal augmentation, and channels-last classifier layout.\n" ] }, { @@ -171,9 +171,9 @@ "id": "20ae3614", "metadata": {}, "source": [ - "## Compare With The Original Notebook\n", + "## Compare With Reference Targets\n", "\n", - "FAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the baseline delta.\n" + "FAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.\n" ] }, { @@ -183,16 +183,16 @@ "metadata": {}, "outputs": [], "source": [ - "baseline_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\n", - "baseline = ORIGINAL_BASELINES.get(baseline_key)\n", - "print(f\"experiment: {baseline_key}\")\n", + "target_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\n", + "target = REFERENCE_TARGETS.get(target_key)\n", + "print(f\"experiment: {target_key}\")\n", "print(f\"test accuracy: {metrics['accuracy']:.4f}\")\n", - "if baseline is not None:\n", - " delta = metrics[\"accuracy\"] - baseline\n", - " print(f\"original notebook baseline: {baseline:.4f}\")\n", - " print(f\"delta from baseline: {delta:+.4f}\")\n", + "if target is not None:\n", + " delta = metrics[\"accuracy\"] - target\n", + " print(f\"reference target: {target:.4f}\")\n", + " print(f\"delta from target: {delta:+.4f}\")\n", "else:\n", - " print(\"No original notebook baseline exists for this GAN/classifier combination.\")\n" + " print(\"No reference target exists for this GAN/classifier combination.\")\n" ] } ], diff --git a/notebooks/03_jax_reproduction.ipynb b/notebooks/03_jax_reproduction.ipynb index 3e0813b..96c3429 100644 --- a/notebooks/03_jax_reproduction.ipynb +++ b/notebooks/03_jax_reproduction.ipynb @@ -22,7 +22,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "DATA_DIR = Path(\"../data/processed\")\nif not DATA_DIR.exists():\n DATA_DIR = Path(\"../data_temp\")\n\nFAST_DEV_RUN = True\nDEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\nMODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\nUSE_GAN_AUGMENTATION = False\n\n# Full reproduction settings:\n# FAST_DEV_RUN = False\n# DEV_SAMPLE_LIMIT = None\n# MODEL_NAME = \"cnn\"\n# USE_GAN_AUGMENTATION = False\n\nif MODEL_NAME == \"gan_cnn\":\n USE_GAN_AUGMENTATION = True\n CLASSIFIER_NAME = \"cnn\"\nelse:\n CLASSIFIER_NAME = MODEL_NAME\n\npreprocess = PreprocessingConfig(seed=1)\nsplit = DataSplitConfig(validation_ratio=0.17, seed=1)\ntraining = TrainingConfig(\n fast_dev_run=FAST_DEV_RUN,\n use_gan_augmentation=USE_GAN_AUGMENTATION,\n gan_samples_per_class=10,\n seed=1,\n)\nmodel_config = ModelConfig(max_time_step=preprocess.max_time_step)\nORIGINAL_BASELINES = {\n \"cnn\": 0.7049,\n \"cnn_lstm\": 0.6095,\n \"lstm\": 0.3962,\n \"gan_cnn\": 0.6823,\n}\n" + "source": "DATA_DIR = Path(\"../data/processed\")\nif not DATA_DIR.exists():\n DATA_DIR = Path(\"../data_temp\")\n\nFAST_DEV_RUN = True\nDEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\nMODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\nUSE_GAN_AUGMENTATION = False\n\n# Full reproduction settings:\n# FAST_DEV_RUN = False\n# DEV_SAMPLE_LIMIT = None\n# MODEL_NAME = \"cnn\"\n# USE_GAN_AUGMENTATION = False\n\nif MODEL_NAME == \"gan_cnn\":\n USE_GAN_AUGMENTATION = True\n CLASSIFIER_NAME = \"cnn\"\nelse:\n CLASSIFIER_NAME = MODEL_NAME\n\npreprocess = PreprocessingConfig(seed=1)\nsplit = DataSplitConfig(validation_ratio=0.17, seed=1)\ntraining = TrainingConfig(\n fast_dev_run=FAST_DEV_RUN,\n use_gan_augmentation=USE_GAN_AUGMENTATION,\n gan_samples_per_class=10,\n seed=1,\n)\nmodel_config = ModelConfig(max_time_step=preprocess.max_time_step)\nREFERENCE_TARGETS = {\n \"cnn\": 0.7049,\n \"cnn_lstm\": 0.6095,\n \"lstm\": 0.3962,\n \"gan_cnn\": 0.6823,\n}\n" }, { "cell_type": "markdown", @@ -39,7 +39,7 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## Optional GAN Augmentation\n\nSynthetic samples are generated from the training split only. Validation and test data remain real data, preserving the same evaluation protocol as the original notebooks.\n" + "source": "## Optional GAN Augmentation\n\nSynthetic samples are generated from the training split only. Validation and test data remain real data, keeping validation and test metrics tied to real held-out data.\n" }, { "cell_type": "code", @@ -63,14 +63,14 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## Compare With The Original Notebook\n\nFAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the baseline delta.\n" + "source": "## Compare With Reference Targets\n\nFAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.\n" }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": "baseline_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\nbaseline = ORIGINAL_BASELINES.get(baseline_key)\nprint(f\"experiment: {baseline_key}\")\nprint(f\"test accuracy: {metrics['accuracy']:.4f}\")\nif baseline is not None:\n delta = metrics[\"accuracy\"] - baseline\n print(f\"original notebook baseline: {baseline:.4f}\")\n print(f\"delta from baseline: {delta:+.4f}\")\nelse:\n print(\"No original notebook baseline exists for this GAN/classifier combination.\")\n" + "source": "target_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\ntarget = REFERENCE_TARGETS.get(target_key)\nprint(f\"experiment: {target_key}\")\nprint(f\"test accuracy: {metrics['accuracy']:.4f}\")\nif target is not None:\n delta = metrics[\"accuracy\"] - target\n print(f\"reference target: {target:.4f}\")\n print(f\"delta from target: {delta:+.4f}\")\nelse:\n print(\"No reference target exists for this GAN/classifier combination.\")\n" } ], "metadata": { diff --git a/src/eegclassify/augmentation.py b/src/eegclassify/augmentation.py index 8e461ee..8af3562 100644 --- a/src/eegclassify/augmentation.py +++ b/src/eegclassify/augmentation.py @@ -32,11 +32,11 @@ class GANAugmentationResult: def interpolation_labels(n_classes: int = 4, samples_per_class: int = 10) -> np.ndarray: - """Create legacy GAN interpolation labels for all ordered class pairs. + """Create GAN interpolation labels for all ordered class pairs. - The original GAN-CNN notebook generated `samples_per_class` points between every - ordered pair of class labels, including same-class pairs. The generated labels are - soft labels and can be used directly with categorical cross-entropy style losses. + The routine generates `samples_per_class` points between every ordered pair of + class labels, including same-class pairs. The generated labels are soft labels and + can be used directly with categorical cross-entropy style losses. Args: n_classes: Number of class labels. diff --git a/src/eegclassify/bci2a.py b/src/eegclassify/bci2a.py index b0ccd4d..76d7986 100644 --- a/src/eegclassify/bci2a.py +++ b/src/eegclassify/bci2a.py @@ -27,7 +27,7 @@ @dataclass(frozen=True) class BCI2AConversionConfig: - """Parameters for recreating the course project `.npy` subset. + """Parameters for recreating the package processed `.npy` dataset. Attributes: window_samples: Number of time samples extracted per trial. @@ -36,9 +36,9 @@ class BCI2AConversionConfig: test_trial_stop: Stop trial ordinal for block-split test mode. eeg_channels: Number of EEG channels to extract. scale_to_microvolts: Whether to convert MNE's volt values to microvolts. - quantize_microvolts: Whether to round onto the original GDF microvolt grid. + quantize_microvolts: Whether to round onto the GDF microvolt grid. reject_artifacts: Whether to remove trials containing GDF artifact marker `1023`. - split_map: Built-in/course split, custom split-map JSON path, or `None` for block split. + split_map: Built-in package split, custom split-map JSON path, or `None` for block split. """ window_samples: int = DEFAULT_WINDOW_SAMPLES @@ -104,7 +104,7 @@ def copy_local_cache_to_processed( """Copy an existing local `.npy` cache into the processed data directory. Args: - cache_dir: Directory containing the six original project `.npy` files. + cache_dir: Directory containing the six processed `.npy` files. processed_dir: Destination directory. manifest_path: Optional manifest path written after copying. """ diff --git a/src/eegclassify/cli.py b/src/eegclassify/cli.py index c5d797a..2f1fb8e 100644 --- a/src/eegclassify/cli.py +++ b/src/eegclassify/cli.py @@ -111,7 +111,7 @@ def prepare_bci2a_main(argv: list[str] | None = None) -> None: parser.add_argument( "--split-map", default="builtin", - help="Use `builtin` for the course split map, a JSON path, or `none` for trial-block splitting.", + help="Use `builtin` for the package split map, a JSON path, or `none` for trial-block splitting.", ) parser.add_argument("--reject-artifacts", action="store_true") args = parser.parse_args(argv) diff --git a/src/eegclassify/config.py b/src/eegclassify/config.py index cc79eae..bd4f12f 100644 --- a/src/eegclassify/config.py +++ b/src/eegclassify/config.py @@ -76,7 +76,7 @@ class TrainingConfig: gan_learning_rate: Conditional GAN Adam learning rate. use_gan_augmentation: Whether to append synthetic GAN samples to the training split. gan_samples_per_class: Number of interpolation samples generated for each ordered - class pair in the legacy GAN augmentation routine. + class pair in the GAN augmentation routine. seed: Framework and data random seed. fast_dev_run: If `True`, notebooks/trainers run a one-epoch smoke check. """ diff --git a/src/eegclassify/data.py b/src/eegclassify/data.py index 25c2ee6..7f46d97 100644 --- a/src/eegclassify/data.py +++ b/src/eegclassify/data.py @@ -22,7 +22,7 @@ @dataclass class DatasetBundle: - """Container for the six processed arrays used by the original notebooks. + """Container for the six package processed arrays. Attributes: X_train_valid: Training/validation EEG trials with shape `(n, 22, 1000)`. @@ -77,7 +77,7 @@ def find_processed_data_dir(candidates: tuple[Path, ...] = LOCAL_CACHE_CANDIDATE def load_processed_arrays(data_dir: Path | str | None = None, mmap_mode: str | None = None) -> DatasetBundle: - """Load processed project arrays. + """Load processed EEG arrays. Args: data_dir: Directory containing the six `.npy` files. If omitted, known local cache @@ -101,7 +101,7 @@ def load_processed_arrays(data_dir: Path | str | None = None, mmap_mode: str | N def save_processed_arrays(bundle: DatasetBundle, output_dir: Path | str) -> None: - """Save a processed data bundle in the original six-file layout. + """Save a processed data bundle in the canonical six-file layout. Args: bundle: Arrays to save. diff --git a/src/eegclassify/models/tensorflow.py b/src/eegclassify/models/tensorflow.py index 5360520..366e6e4 100644 --- a/src/eegclassify/models/tensorflow.py +++ b/src/eegclassify/models/tensorflow.py @@ -33,7 +33,7 @@ def _require_tensorflow() -> None: class TensorFlowClassifierFactory: - """Build Keras classifiers matching the original notebook architectures. + """Build Keras classifiers for the package model families. Args: model_config: Optional architecture configuration. @@ -224,7 +224,7 @@ def build_discriminator(config: ModelConfig | None = None) -> keras.Model: class TensorFlowConditionalGAN(_TensorFlowGANBase): - """Conditional GAN augmentation model from the original notebook. + """Conditional GAN model for synthetic EEG training samples. Args: discriminator: Optional discriminator model. diff --git a/src/eegclassify/preprocessing.py b/src/eegclassify/preprocessing.py index 57198f6..00293a4 100644 --- a/src/eegclassify/preprocessing.py +++ b/src/eegclassify/preprocessing.py @@ -84,7 +84,7 @@ def augment_trials( config: PreprocessingConfig, rng: np.random.Generator | None = None, ) -> tuple[np.ndarray, np.ndarray]: - """Apply the original notebook augmentation routine. + """Apply the package temporal augmentation routine. The routine trims each trial, creates a max-pooled copy, creates an averaged copy with Gaussian noise, and appends each temporal subsampling offset. @@ -231,7 +231,7 @@ def prepare_splits( def prepare_gan_training_arrays(x_train: np.ndarray, y_train: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]: - """Normalize classifier arrays into the GAN format used by the original notebook. + """Normalize classifier arrays into the GAN sequence format. Args: x_train: Classifier inputs shaped `(n, time, 1, channels)`.