Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 47 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@

# EEGClassify

This repository contains `eegclassify`, a reproducible EEG classification package built from the original C247 project notebooks. It includes CNN, LSTM, CNN-LSTM, and GAN-augmented classifier experiments across TensorFlow, PyTorch, and JAX/Flax.
`eegclassify` is a standalone Python package for reproducible EEG task classification. It provides a consistent data workflow, reusable preprocessing utilities, and CNN, LSTM, CNN-LSTM, and GAN-augmented training paths across TensorFlow, PyTorch, and JAX/Flax.

The original project write-up is available as [Report.pdf](Report.pdf). The original notebook code is preserved under [`notebooks/legacy/`](notebooks/legacy/) for reference, while the reproducible package-oriented notebooks live directly under [`notebooks/`](notebooks/).
Use it when you want a clean package interface for running EEG classification experiments, comparing frameworks, or regenerating the processed BCI Competition IV-2a arrays from the public raw data.

The public data source is BCI Competition IV, data set 2a:
## What Is Included

- Dataset page: https://www.bbci.de/competition/iv/
- Raw GDF archive: https://www.bbci.de/competition/download/competition_iv/BCICIV_2a_gdf.zip
- True-label archive: https://www.bbci.de/competition/iv/results/ds2a/true_labels.zip
- Reproducible BCI Competition IV data set 2a download and conversion tools.
- Processed `.npy` data files tracked with Git LFS for quick local examples.
- Shared experiment settings for labels, splits, augmentation, seeds, and batching.
- TensorFlow/Keras, PyTorch, and JAX/Flax model implementations.
- Reusable GAN augmentation for classifier training data.
- Jupyter notebooks and CLI commands for smoke tests and full reproductions.

## Quick Start

Expand All @@ -25,58 +28,77 @@ python -m pip install -U pip
python -m pip install -e ".[data,notebooks,docs,dev]"
```

Download and prepare data:
## Data Preparation

```bash
python -m eegclassify.cli download-bci2a --raw-dir data/raw
python -m eegclassify.cli prepare-bci2a --raw-dir data/raw --output-dir data/processed
python -m eegclassify.cli compare-data --generated-dir data/processed --reference-dir data_temp
```
The public source of truth is BCI Competition IV, data set 2a.

- Dataset page: https://www.bbci.de/competition/iv/
- Raw GDF archive: https://www.bbci.de/competition/download/competition_iv/BCICIV_2a_gdf.zip
- True-label archive: https://www.bbci.de/competition/iv/results/ds2a/true_labels.zip

Run tests and build documentation:
Download and regenerate the processed arrays:

```bash
pytest
mkdocs build --strict
python -m eegclassify.cli download-bci2a --raw-dir data/raw
python -m eegclassify.cli prepare-bci2a \
--raw-dir data/raw \
--output-dir data/processed_regenerated
python -m eegclassify.cli compare-data \
--generated-dir data/processed_regenerated \
--reference-dir data/processed
```

## Notebook Examples

The main reproducible Jupyter notebook examples are:
The runnable notebooks live under [`notebooks/`](notebooks/) and are also rendered in the documentation site.

- [Data postprocessing](notebooks/00_data_postprocessing.ipynb)
- [TensorFlow reproduction](notebooks/01_tensorflow_reproduction.ipynb)
- [PyTorch reproduction](notebooks/02_pytorch_reproduction.ipynb)
- [JAX/Flax reproduction](notebooks/03_jax_reproduction.ipynb)

The original project notebooks are preserved in [notebooks/legacy](notebooks/legacy/).
Each notebook defaults to `FAST_DEV_RUN=True` so it can be opened and executed quickly before launching longer training runs.

Run a quick local training smoke test:
## CLI Training

```bash
python -m eegclassify.cli train --framework tensorflow --model cnn --fast-dev-run
python -m eegclassify.cli train --framework pytorch --model cnn --fast-dev-run
python -m eegclassify.cli train --framework jax --model cnn --fast-dev-run
```

Use the original GAN-CNN parity path with:
Use GAN augmentation with the canonical CNN path:

```bash
python -m eegclassify.cli train --framework tensorflow --model gan_cnn --use-gan-augmentation
python -m eegclassify.cli train \
--framework tensorflow \
--model gan_cnn \
--use-gan-augmentation
```

For NVIDIA GPU support on Linux, install:
## GPU Extras

For NVIDIA GPU support on Linux, install the framework extras you need:

```bash
python -m pip install -e ".[tensorflow,pytorch,jax-cuda13]"
```

The TensorFlow CLI configures the pip-installed CUDA library paths before importing TensorFlow.

The processed files in `data/processed/*.npy` are intended to be committed through Git LFS. They are excluded from PyPI builds.
## Documentation And Package Notes

- Documentation: https://mhmdaskari.github.io/EEG_DL_Classification/
- Processed files under `data/processed/*.npy` are tracked with Git LFS.
- Processed arrays are excluded from PyPI builds; package users can download or regenerate them with the CLI.

Documentation is configured for GitHub Pages at:
Run checks locally:

```bash
pytest
mkdocs build --strict
```

https://mhmdaskari.github.io/EEG_DL_Classification/
This package was first initiated as part of a UCLA C247 course project; the original report is available as [Report.pdf](Report.pdf).

This project began as part of UCLA C247, Deep Learning and Neural Networks, Winter 2023. See [Report.pdf](Report.pdf) for the report associated with that work.
Archived notebook-only prototypes are kept under [`notebooks/legacy/`](notebooks/legacy/) for historical reference.
24 changes: 14 additions & 10 deletions docs/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The public source of truth is BCI Competition IV, data set 2a.
- Raw GDF archive: <https://www.bbci.de/competition/download/competition_iv/BCICIV_2a_gdf.zip>
- True-label archive: <https://www.bbci.de/competition/iv/results/ds2a/true_labels.zip>

The original notebooks used six processed files:
The packaged processed dataset uses six files:

- `X_train_valid.npy`
- `y_train_valid.npy`
Expand All @@ -16,7 +16,7 @@ The original notebooks used six processed files:
- `person_test.npy`

Post-processed `.npy` files under `data/processed/` are versioned with Git LFS so the
notebooks can run from a fresh clone. Raw archives, extracted GDF files, local caches,
examples can run from a fresh clone. Raw archives, extracted GDF files, local caches,
and intermediate conversion folders stay out of Git.

Before adding or updating processed arrays, install Git LFS and initialize it once:
Expand Down Expand Up @@ -44,26 +44,30 @@ should download or regenerate the data with the commands below.
python -m eegclassify.cli download-bci2a --raw-dir data/raw
```

## Regenerate The Project Subset
## Regenerate The Processed Dataset

```bash
python -m eegclassify.cli prepare-bci2a \
--raw-dir data/raw \
--output-dir data/processed \
--manifest data/processed/manifest.json
--output-dir data/processed_regenerated \
--manifest data/processed_regenerated/manifest.json
```

The default converter uses the training GDF files `A01T.gdf` through `A09T.gdf`, extracts the first 22 EEG channels, uses 1000 samples starting one sample after cue onset, and applies the built-in course split map inferred from the original `.npy` cache. A trial-block split is still available with `--split-map none`.
The default converter uses the training GDF files `A01T.gdf` through `A09T.gdf`,
extracts the first 22 EEG channels, uses 1000 samples starting one sample after cue
onset, and applies the package split map. A trial-block split is still available
with `--split-map none`.

## Compare With Local Cache
## Compare Processed Outputs

```bash
python -m eegclassify.cli compare-data \
--generated-dir data/processed \
--reference-dir data_temp \
--generated-dir data/processed_regenerated \
--reference-dir data/processed \
--output artifacts/data_comparison.json
```

The comparison report includes shapes, dtypes, SHA256 hashes, label counts, subject counts, exact-match flags, and numeric differences.

The notebook `notebooks/00_data_postprocessing.ipynb` runs the same workflow interactively.
The [data postprocessing notebook](examples/00_data_postprocessing.ipynb) runs the
same workflow interactively.
84 changes: 84 additions & 0 deletions docs/examples/00_data_postprocessing.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": "# BCI IV 2a Data Postprocessing\n\nThis notebook makes the dataset lineage explicit. It downloads the official BCI Competition IV data set 2a archives, recreates the six package processed `.npy` files, writes a manifest, and optionally compares regenerated files with the packaged processed data.\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## What The Files Mean\n\nThe raw source is the public BCI IV-2a GDF release. The project subset stores train/validation candidate trials, held-out test trials, labels, and subject IDs in six NumPy files. The labels are still the original cue IDs (`769-772`) at this stage; conversion to class IDs happens during preprocessing.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "from pathlib import Path\nimport json\n\nfrom eegclassify.bci2a import BCI2AConversionConfig, convert_bci2a_training_subset, download_bci2a\nfrom eegclassify.data import compare_processed_dirs, summarize_processed_dir, write_json_report\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "RAW_DIR = Path(\"../data/raw\")\nPROCESSED_DIR = Path(\"../data/processed\")\nREFERENCE_DIR = Path(\"../data_temp\")\nREPORT_DIR = Path(\"../artifacts\")\n\nDOWNLOAD = True\nREGENERATE_FROM_RAW = True\nCOMPARE_WITH_REFERENCE = True\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Download Official Archives\n\nSet `DOWNLOAD = False` when the archives are already available under `data/raw/`. Raw files are reproducible from the official URLs and should not be committed to Git.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "if DOWNLOAD:\n paths = download_bci2a(RAW_DIR)\n print(json.dumps({key: str(value) for key, value in paths.items()}, indent=2))\nelse:\n print(\"Skipping download\")\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Convert Raw GDF Files To The Processed Dataset\n\nThe converter uses the training GDF files `A01T.gdf` through `A09T.gdf`, the first 22 EEG channels, 1000 samples per trial, and the built-in package split map.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "config = BCI2AConversionConfig(\n window_samples=1000,\n window_offset_samples=1,\n test_trial_start=200,\n test_trial_stop=250,\n)\n\nif REGENERATE_FROM_RAW:\n bundle = convert_bci2a_training_subset(RAW_DIR, PROCESSED_DIR, config, PROCESSED_DIR / \"manifest.json\")\n print(\"X_train_valid:\", bundle.X_train_valid.shape)\n print(\"X_test:\", bundle.X_test.shape)\nelse:\n print(\"Skipping raw-to-subset conversion\")\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Manifest And Comparison\n\nThe manifest records source URLs, shapes, dtypes, label counts, subject counts, file sizes, and SHA256 hashes. If `data_temp/` exists, the comparison report checks whether the regenerated files match the local cache.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "manifest = summarize_processed_dir(PROCESSED_DIR)\nwrite_json_report(manifest, PROCESSED_DIR / \"manifest.json\")\nprint(json.dumps({name: info.get(\"shape\") for name, info in manifest[\"files\"].items()}, indent=2))\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "if COMPARE_WITH_REFERENCE and REFERENCE_DIR.exists():\n report = compare_processed_dirs(PROCESSED_DIR, REFERENCE_DIR)\n write_json_report(report, REPORT_DIR / \"data_comparison.json\")\n print(\"all_files_match:\", report[\"all_files_match\"])\nelse:\n print(\"Skipping reference-cache comparison\")\n"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
89 changes: 89 additions & 0 deletions docs/examples/01_tensorflow_reproduction.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": "# TensorFlow EEG Classification Reproduction\n\nThis notebook runs the TensorFlow/Keras implementation. It defaults to a small smoke run so the workflow is easy to verify. For the full reproduction, set `FAST_DEV_RUN = False` and leave `DEV_SAMPLE_LIMIT = None`.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "from pathlib import Path\n\nfrom eegclassify.config import DataSplitConfig, ModelConfig, PreprocessingConfig, TrainingConfig\nfrom eegclassify.data import load_processed_arrays\nfrom eegclassify.models.tensorflow import TensorFlowClassifierFactory, TensorFlowGANAugmenter, TensorFlowTrainer\nfrom eegclassify.preprocessing import prepare_splits\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Configuration\n\n`MODEL_NAME = \"gan_cnn\"` is a convenience alias for `MODEL_NAME = \"cnn\"` plus `USE_GAN_AUGMENTATION = True`. GAN augmentation can also be used with `lstm` or `cnn_lstm`, and GAN+CNN is the canonical GAN-augmented reference target.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "DATA_DIR = Path(\"../data/processed\")\nif not DATA_DIR.exists():\n DATA_DIR = Path(\"../data_temp\")\n\nFAST_DEV_RUN = True\nDEV_SAMPLE_LIMIT = 128 if FAST_DEV_RUN else None\nMODEL_NAME = \"cnn\" # choices: cnn, lstm, cnn_lstm, gan_cnn\nUSE_GAN_AUGMENTATION = False\n\n# Full reproduction settings:\n# FAST_DEV_RUN = False\n# DEV_SAMPLE_LIMIT = None\n# MODEL_NAME = \"cnn\"\n# USE_GAN_AUGMENTATION = False\n\nif MODEL_NAME == \"gan_cnn\":\n USE_GAN_AUGMENTATION = True\n CLASSIFIER_NAME = \"cnn\"\nelse:\n CLASSIFIER_NAME = MODEL_NAME\n\npreprocess = PreprocessingConfig(seed=1)\nsplit = DataSplitConfig(validation_ratio=0.17, seed=1)\ntraining = TrainingConfig(\n fast_dev_run=FAST_DEV_RUN,\n use_gan_augmentation=USE_GAN_AUGMENTATION,\n gan_samples_per_class=10,\n seed=1,\n)\nmodel_config = ModelConfig(max_time_step=preprocess.max_time_step)\nREFERENCE_TARGETS = {\n \"cnn\": 0.7049,\n \"cnn_lstm\": 0.6095,\n \"lstm\": 0.3962,\n \"gan_cnn\": 0.6823,\n}\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Load And Prepare Data\n\nPreprocessing converts labels `769-772` to `0-3`, creates the validation split with seed `1`, applies the original max-pooling / averaging / subsampling augmentation, and reshapes arrays to `(trials, time, 1, channels)`.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "bundle = load_processed_arrays(DATA_DIR)\nprepared = prepare_splits(bundle, preprocess, split)\n\nx_train, y_train = prepared.x_train, prepared.y_train\nx_valid, y_valid = prepared.x_valid, prepared.y_valid\nx_test, y_test = prepared.x_test, prepared.y_test\n\nif DEV_SAMPLE_LIMIT is not None:\n x_train, y_train = x_train[:DEV_SAMPLE_LIMIT], y_train[:DEV_SAMPLE_LIMIT]\n x_valid, y_valid = x_valid[:DEV_SAMPLE_LIMIT], y_valid[:DEV_SAMPLE_LIMIT]\n x_test, y_test = x_test[:DEV_SAMPLE_LIMIT], y_test[:DEV_SAMPLE_LIMIT]\n\nprint(\"train:\", x_train.shape, y_train.shape)\nprint(\"valid:\", x_valid.shape, y_valid.shape)\nprint(\"test:\", x_test.shape, y_test.shape)\nprint(\"labels are one-hot with columns 0-3 for original cues 769-772\")\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Optional GAN Augmentation\n\nWhen enabled, a conditional GAN is trained on the training split only. Synthetic samples are appended to `x_train` and `y_train`; validation and test arrays stay untouched so evaluation remains comparable.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "gan_result = None\nif USE_GAN_AUGMENTATION:\n gan_result = TensorFlowGANAugmenter(training, model_config).augment_training_data(x_train, y_train)\n x_train, y_train = gan_result.x_train, gan_result.y_train\n print(\"synthetic samples:\", gan_result.x_synthetic.shape[0])\n print(\"augmented train:\", x_train.shape, y_train.shape)\nelse:\n print(\"GAN augmentation disabled\")\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Build, Train, And Evaluate\n\nThe classifier architecture is selected after the optional augmentation step. TensorFlow uses categorical cross-entropy, so both one-hot labels and GAN interpolation soft labels are supported.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "factory = TensorFlowClassifierFactory(model_config)\nmodel = factory.build(CLASSIFIER_NAME, learning_rate=training.learning_rate)\nmodel.summary()\n\ntrainer = TensorFlowTrainer(training)\nhistory = trainer.fit(model, x_train, y_train, x_valid, y_valid)\nmetrics = trainer.evaluate(model, x_test, y_test)\nmetrics\n"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Compare With Reference Targets\n\nFAST_DEV_RUN metrics are smoke-test numbers. Use the full settings before interpreting the target delta.\n"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "target_key = \"gan_cnn\" if USE_GAN_AUGMENTATION and CLASSIFIER_NAME == \"cnn\" else CLASSIFIER_NAME\ntarget = REFERENCE_TARGETS.get(target_key)\nprint(f\"experiment: {target_key}\")\nprint(f\"test accuracy: {metrics['accuracy']:.4f}\")\nif target is not None:\n delta = metrics[\"accuracy\"] - target\n print(f\"reference target: {target:.4f}\")\n print(f\"delta from target: {delta:+.4f}\")\nelse:\n print(\"No reference target exists for this GAN/classifier combination.\")\n"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading