
[DRAFT] FEAT: ChargE3Net Fine-Tuning Pipeline on LeMatRho #8

Draft
speckhard wants to merge 7 commits into LeMaterial:main from speckhard:feat/charge3net

Conversation

@speckhard
Collaborator

@speckhard speckhard commented Mar 17, 2026

`feat/charge3net` — ChargE3Net Fine-Tuning Pipeline

Overview

Fine-tunes ChargE3Net (Koker et al., npj Computational Materials 2024) on LeMatRho charge density data. The approach is:

  1. Clone the official AIforGreatGood/charge3net repo as a sibling directory — this provides both the model code (E3DensityModel) and the pre-trained Materials Project checkpoint (models/charge3net_mp.pt, 23 MB, 1.9M params, trained for 245 epochs / 407k steps on MP charge densities).
  2. Load those pre-trained weights into our wrapper and fine-tune on 65,239 LeMatRho materials (10×10×10 compressed charge density grids stored in Parquet).
  3. Train on Jean Zay (1× A100, 20h SLURM jobs), resuming across jobs via latest.pt.

The training pipeline lives in charge3net_ft/ — a self-contained package within this repo.
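The load-then-fine-tune flow in steps 2–3 can be sketched in plain PyTorch. This is an illustrative sketch, not the actual `charge3net_ft` code: the real model consumes graph dicts rather than raw tensors, and the checkpoint key names here are assumptions.

```python
import torch


def load_pretrained(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    """Load pre-trained weights; tolerate both raw state_dicts and wrapped checkpoints."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Some checkpoints wrap the weights under a key such as "state_dict".
    state = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state, strict=False)  # strict=False: heads may differ
    return model


def fine_tune(model, loader, epochs=1, lr=1e-4):
    """Minimal fine-tuning loop; L1 on the density matches the reported train metric."""
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = torch.nn.L1Loss()
    for _ in range(epochs):
        for batch, target in loader:
            opt.zero_grad()
            loss = loss_fn(model(batch), target)
            loss.backward()
            opt.step()
    return model
```

The `strict=False` load is what makes re-using the MP checkpoint cheap: any layer whose shape matches is initialized from the pre-trained weights, and the rest starts fresh.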

Setup

# 1. Clone charge3net as a sibling of LeMat-Rho (required — provides model + checkpoint)
git clone https://github.com/AIforGreatGood/charge3net ../charge3net

# 2. Install dependencies
uv sync

# 3. Set data path (or pass --parquet-dir directly)
export LEMATRHO_DATA_DIR=/path/to/lematrho_full_10x10x10

# 4. Run a smoke test to verify everything works
python -m charge3net_ft.train --smoke-test

On Jean Zay, submit with sbatch submit_charge3net.sh — it auto-resumes from $SCRATCH/charge3net_checkpoints/latest.pt if present.
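The auto-resume check itself lives in the bash submission script; its logic amounts to the following Python equivalent (`resume_args` is illustrative, not part of the repo):

```python
from pathlib import Path


def resume_args(ckpt_dir: str) -> list:
    """Pass --resume-from only when a previous job left a latest.pt behind."""
    latest = Path(ckpt_dir) / "latest.pt"
    return ["--resume-from", str(latest)] if latest.exists() else []
```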

New Files

| File | Lines | Purpose |
| --- | --- | --- |
| `charge3net_ft/data.py` | 270 | Lazy-loading PyTorch `Dataset`. Reads LeMatRho Parquet chunks, converts to `ase.Atoms` + density grids, builds charge3net-compatible graph dicts via `KdTreeGraphConstructor`. Per-worker table cache so each chunk file is read from disk once per worker, not once per sample. |
| `charge3net_ft/model.py` | 115 | Thin wrapper around charge3net's `E3DensityModel`. Handles instantiation with MP checkpoint hyperparameters and loading from 3 checkpoint formats (legacy PyTorch Lightning, new charge3net, raw `state_dict`). |
| `charge3net_ft/train.py` | 310 | Full training script with CLI. Supports normal training, smoke test, single-batch overfit test, W&B logging (online/offline), and checkpoint resumption across SLURM jobs. Train/val/test three-way split with held-out test evaluation at the end. |
| `submit_charge3net.sh` | 47 | SLURM GPU submission script for Jean Zay (A100, 20 h wall time). Auto-resumes from `latest.pt` if present. |
| `tests/test_metrics.py` | | Parametrized tests for `compute_nmape`, `compute_rmse`, `compute_nrmse`, including padding-mask correctness. |
| `tests/test_data.py` | | Unit tests for `_parse_grid_json`, `_row_to_atoms_and_density`, and `_build_parquet_index` using synthetic Parquet files — no real data needed. |
| `tests/test_model.py` | | Tests all 3 checkpoint-loading branches with mock state dicts. |
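For reference, the three validation metrics follow standard definitions; a minimal NumPy sketch with an optional padding mask (function signatures and normalization choices here are assumptions about the repo's implementation, not its actual API):

```python
import numpy as np


def _as_mask(true, mask):
    """All-True mask when none is given; otherwise coerce to a boolean array."""
    return np.ones(true.shape, dtype=bool) if mask is None else np.asarray(mask, dtype=bool)


def compute_nmape(pred, true, mask=None):
    """Normalized MAE in percent: sum |pred - true| / sum |true| over unmasked voxels."""
    pred, true = np.asarray(pred, dtype=float), np.asarray(true, dtype=float)
    m = _as_mask(true, mask)
    return 100.0 * float(np.abs(pred - true)[m].sum() / np.abs(true)[m].sum())


def compute_rmse(pred, true, mask=None):
    """Root-mean-square error over unmasked voxels."""
    pred, true = np.asarray(pred, dtype=float), np.asarray(true, dtype=float)
    m = _as_mask(true, mask)
    return float(np.sqrt(np.mean((pred - true)[m] ** 2)))


def compute_nrmse(pred, true, mask=None):
    """RMSE normalized by the mean |true| density, in percent."""
    true = np.asarray(true, dtype=float)
    m = _as_mask(true, mask)
    return 100.0 * compute_rmse(pred, true, mask) / float(np.abs(true)[m].mean())
```

The padding mask matters because grids batched to a common size carry padded voxels that must not contribute to either the numerator or the denominator.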

Modified Files

| File | Change |
| --- | --- |
| `pyproject.toml` | Added `torch>=2.0`, `numpy>=1.24`, `charge3net` (git dep); loosened aggressive version pins |
| `.gitignore` | Added `*.pt` to prevent accidental checkpoint commits |

Training Results (30 epochs across 3 SLURM jobs)

| Epoch | train L1 (e/Å³) | val L1 (e/Å³) | val NMAPE (%) | val RMSE (e/Å³) | val NRMSE (%) |
| --- | --- | --- | --- | --- | --- |
| 1 | 5.27 | 3.94 | 14.15 | 8.99 | 31.89 |
| 5 | 2.85 | 3.03 | 10.66 | 6.48 | 22.45 |
| 10 | 2.48 | 2.66 | 9.40 | 5.62 | 19.55 |
| 15 | 2.31 | 2.56 | 9.00 | 5.45 | 18.83 |
| 20 | 2.23 | 2.44 | 8.55 | 5.17 | 17.91 |
| 25 | 2.13 | 2.40 | 8.43 | 5.08 | 17.55 |
| 30 | 2.05 | 2.32 | 8.10 | 4.84 | 16.76 |

Best checkpoint: best.pt at epoch 30 (8.10% val NMAPE). Model is still improving — no plateau observed. Note: these are val-set numbers (used for checkpoint selection). Test-set numbers will be reported after the next training run using the new three-way split.

Git Log

dae2b45 Add PYTHONUNBUFFERED=1 to SLURM script for real-time log output
f3ffde4 Add --resume-from for checkpoint resumption across SLURM jobs
252acac Fix wandb on air-gapped clusters: add offline mode and init timeout
99d391a Add W&B logging, RMSE/NRMSE metrics, and Jean Zay SLURM script
f72fa52 Training pipeline to train Charge3net on LeMat-Rho

- Add wandb integration with --wandb-project/--wandb-entity/--no-wandb flags
- Add compute_rmse() and compute_nrmse() validation metrics
- Log per-step train loss and per-epoch train/val metrics to W&B
- Load WANDB_API_KEY from .env via python-dotenv
- Add submit_charge3net.sh for Jean Zay A100 GPU jobs
- Add .gitignore (excludes .env, checkpoints, wandb, etc.)
- save_checkpoint now includes global_step
- load_checkpoint restores model, optimizer, scheduler, epoch, best_nmape, global_step
- SLURM script auto-detects latest.pt and passes --resume-from
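The checkpoint schema described in the last three bullets can be sketched as follows (the key names mirror the commit messages; treat the exact schema as an assumption):

```python
import torch


def save_checkpoint(path, model, optimizer, scheduler, epoch, best_nmape, global_step):
    """Persist everything needed to resume mid-run: weights, optimizer/scheduler state, counters."""
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
        "best_nmape": best_nmape,
        "global_step": global_step,
    }, path)


def load_checkpoint(path, model, optimizer, scheduler):
    """Restore state in place; return the counters the training loop needs."""
    # weights_only=False: the checkpoint holds plain Python scalars alongside tensors.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["epoch"], ckpt["best_nmape"], ckpt["global_step"]
```

Restoring optimizer and scheduler state (not just weights) is what keeps the loss curve continuous across the 20 h SLURM job boundaries.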
@speckhard speckhard requested review from Ramlaoui and mfranckel April 7, 2026 14:31
@speckhard speckhard marked this pull request as draft April 7, 2026 14:34
@speckhard speckhard changed the title FEAT: ChargE3Net Fine-Tuning Pipeline on LeMatRho [DRAFT] FEAT: ChargE3Net Fine-Tuning Pipeline on LeMatRho Apr 7, 2026
@speckhard
Collaborator Author

@mfranckel let me try to create a job script for Adastra so you can run this yourself.

@speckhard
Collaborator Author

Sorry, this still needs to be cleaned up before review; just wanted to make you aware of it.
