From 52d4cd029710ebb8d8d2260f6d1490c965453012 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 22:20:17 +0000 Subject: [PATCH 1/8] sae recipe skill: streaming extract + launch step + hard-won gotchas Update the bionemo-sae-recipe skill: (1) replace the predict->.pt->parquet-shim flow with the STREAMING extractor (wrap predict_, monkeypatch its writer to ActivationStore, no .pt) per evo2/scripts/extract.py; (2) add a launch step with the known-good training config; (3) bake in the issues we hit -- LR cosine must span the full epoch (don't truncate steps), --dead-count-global is a DDP-only no-op (and must be passed, not just named), shard-0 pre-bias is biased (use --presample-shards N>1), and the wandb env (unset WANDB_API_KEY + WANDB_ENTITY); (4) extend scope to eval via sae.eval/probing. Drops the obsolete .pt disk-pressure and shim-throughput caveats. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/bionemo-sae-recipe.md | 162 ++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md new file mode 100644 index 0000000000..4626a7af88 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -0,0 +1,162 @@ +--- +name: bionemo-sae-recipe +description: Build a new sparse-autoencoder recipe under bionemo-recipes/interpretability/sparse_autoencoders/recipes/ for a biological foundation model (e.g. evo2, nemotron, geneformer) — extract activations, train an SAE, and evaluate it. Trigger when the user asks to "add SAE for ", "build a new SAE recipe", or "run an SAE on ". +--- + +# Build a new SAE recipe in bionemo-framework + +## The pattern + +Every SAE recipe in `bionemo-recipes/interpretability/sparse_autoencoders/recipes/` decomposes into the same stages, separated by a universal contract: + +``` +extractor (model-specific) → ActivationStore parquet shards → train.py (universal) → eval (sae.eval, universal) + ↑ contract +``` + +- **Extractor** runs the model forward and **streams** layer-L activations *directly* into an `ActivationStore` — no intermediate `.pt` files. The clean pattern (see `evo2/scripts/extract.py`): reuse the model's existing `predict_` CLI but **monkeypatch its per-batch writer** with one that appends `hidden[mask]` to `sae.activation_store.ActivationStore`. Model-specific (~150 lines). +- **ActivationStore** (`sae/src/sae/activation_store.py`) is the universal on-disk format: a directory of `shard_{NNNNN}.parquet` + `metadata.json` (`{model_name, layer, hidden_dim, n_samples, n_shards, shard_size, n_sequences}`). +- **train.py** loads via `sae.activation_store.load_activations(cache_dir)` and trains a TopK/ReLU SAE. **Identical across recipes — copy verbatim** (only the docstring differs); uses `--model-path` solely for a one-line cache-validation warning. +- **eval** (`sae.eval`, universal): `reconstruction` (variance explained), `dead_latents` (%), `loss_recovered` (CE fidelity), and `probing` (per-feature AUROC / linear probes / domain-F1 over a labeled `ActivationBuffer`). Probing scoring is **CPU-only** — it reads saved buffers, no model. + +## When this applies + +Bringing up an SAE on a new biological foundation model — Evo2, ESM2, CodonFM, Nemotron, Geneformer, etc. A checkpoint (HF or local) is in hand. Scope is the full **extract → train → eval** pipeline. Per-model you write a thin **extractor** (and, for interpretability, **labelers**); everything downstream is shared. + +## Steps + +### 1. Reconnaissance (read, don't write) + +- Templates: `recipes/esm2/` (HF `AutoModel` path), `recipes/codonfm/` (custom checkpoint), `recipes/evo2/` (streaming reuse of a `predict_` CLI). Pick the closest. +- Find the model's inference path in `bionemo-recipes/recipes/_*/`. If it has a `predict_` CLI, reuse it (streaming); else write `extract.py` modeled on `esm2/`. +- Identify hidden_dim, layer count, **trained context length** (critical — see gotchas). + +### 2. Build the upstream env (if needed) + +Recipes under `bionemo-recipes/recipes/_*/` have `.ci_build.sh` that makes a `--system-site-packages` `.venv` — **assumes the NVIDIA pytorch container** with TransformerEngine preinstalled. Verify first: + +```bash +ls /usr/local/lib/python*/dist-packages/transformer_engine 2>/dev/null && echo "OK to build" +``` + +### 3. Scaffold the recipe dir + +``` +recipes// +├── README.md +├── pyproject.toml # deps: sae, torch, numpy, pyarrow ; [tool.uv.sources] sae = { workspace = true } +└── scripts/ + ├── .sh # orchestrator: chunk → stream-extract → train + ├── extract.py # STREAMING: wraps predict_, writes ActivationStore directly (NO .pt) + └── train.py # COPY VERBATIM from any recipe (e.g. codonfm/scripts/train.py) +``` + +### 4. The streaming extractor + +Reuse the upstream forward; swap only the writer: + +```python +from bionemo..run import predict as predict_mod +predict_mod._write_predictions_batch = _store_writer # appends hidden[pad_mask] to ActivationStore +sys.argv = [sys.argv[0], *forwarded_predict_flags] +predict_mod.main() +``` + +No `.pt`, ~half the disk, no separate conversion pass. Under DDP each rank writes its own tmp store; rank 0 merges at the end via a **file-based wait** (poll for sibling `metadata.json`) — **not** `dist.barrier()`, because `predict.main()` tears down the process group before the finalize hook runs. + +### 5. Launch the training + +The orchestrator (`.sh`) chains chunk → extract → train. Launch with `torchrun`, `--dp-size` = #GPUs. **Always smoke first** (20–100 sequences → confirm loss drops), then the full run. + +```bash +unset WANDB_API_KEY # a leaked key in the shared env overrides ~/.netrc — you'd log as someone else +export WANDB_ENTITY= # accounts with no default entity fail wandb.init otherwise + +torchrun --nproc_per_node 8 scripts/train.py \ + --cache-dir --model-path --layer L \ + --model-type topk --expansion-factor 16 --top-k 128 --normalize-input \ + --auxk 2048 --auxk-coef 0.03125 --dead-tokens-threshold 10000000 \ + --init-pre-bias --presample-shards 8 --mix-shards 10 \ + --aggregate-loss --dead-count-global \ + --n-epochs 1 --batch-size 1024 \ + --lr 1e-4 --lr-schedule cosine --lr-min 1e-5 --warmup-steps 1000 --max-grad-norm 1.0 \ + --dp-size 8 --wandb --wandb-project +``` + +For a **sweep**, run one config at a time on a fixed GPU group (sequential), not many in parallel — parallel runs contend on the same parquet cache I/O. Give each `torchrun` a distinct `--master-port`. + +### 6. Cache guards in the orchestrator + +Each long step needs an idempotency check on a sentinel the step itself produces: + +```bash +[[ -f "${PARQUET_DIR}/metadata.json" ]] || torchrun ... scripts/extract.py ... # finalize() writes metadata.json last +``` + +**Caveat:** guards check existence, not provenance — `rm -rf` the output dir when the input FASTA / model / layer changes. + +## Known-good training config (and why) + +These defaults reproduced the best Evo2-7B / layer-26 SAE (~21% dead, ~0.10 FVU). All are **opt-in** in the `sae` package (defaults reproduce older behavior): + +| flag | why it matters | +| ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | +| `--normalize-input` | the single biggest dead-latent lever (∼80% → ∼20% dead) | +| `--aggregate-loss` | batch-level FVU/AuxK ratio instead of per-token (per-token starves rare high-variance tokens → their latents die) | +| `--dead-count-global` | counts dead-latent inactivity in **total** tokens (×world_size); the per-rank default fires AuxK revival `world_size`× too late under DDP | +| `--mix-shards 10` | shuffles + blends shards; corpus/kingdom-ordered caches otherwise give a visible FVU cliff | +| `--presample-shards 8` | geometric-median pre-bias over 8 shards, not shard-0 alone — a single-shard sample is corpus-order-biased and **measurably worsens dead latents** | + +## Known gotchas (these cost real debug time) + +### Training dynamics (learned the hard way) + +1. **Never truncate steps without fixing the LR horizon.** `--n-epochs 1` lets cosine decay over the *whole* epoch. Capping steps (e.g. a `--max-steps`/short `--lr-decay-steps`) shrinks the cosine horizon, so LR collapses to `lr_min` early and training is much worse — looks like a model/code regression but it's the schedule. +2. **`--dead-count-global` is a no-op at dp-size 1** (world_size=1). It only does anything under DDP. And it must actually be **passed** — encoding "dcg=true" only in a run *name* while omitting the flag silently runs the per-rank default (a real sweep bug). +3. **Pre-bias from shard-0 only is biased.** On a corpus-ordered cache (e.g. all-prok-then-all-euk), a single-shard geometric-median init mis-centers `pre_bias` toward one kingdom and worsens dead latents. Use `--presample-shards N>1`. + +### wandb + +4. **`unset WANDB_API_KEY` before launching** — a leaked key in the shared env overrides `~/.netrc`, so your runs log under someone else's account. Then set `WANDB_ENTITY` if your account has no default entity (else `wandb.init` fails / lands in the wrong entity). + +### Container / env + +5. `.ci_build.sh` assumes system-site-packages TransformerEngine — verify before building (step 2). +6. `huggingface-cli` is deprecated → use `hf` (same args). HF README dir names are unreliable (OpenGenome2's `jsonl/` is really `json/`) — verify the tree and that the downloaded file count is nonzero. + +### Checkpoint loading + +7. **`weights_only=True` (torch 2.6 default) breaks legacy checkpoints with numpy arrays** — buried in stderr, exit 0, empty output dir. `UnpicklingError: Unsupported global: numpy.core.multiarray._reconstruct`. Patch the upstream `torch.load(...)` to `weights_only=False` if the source is trusted. (For Evo2, the recipe assumes an MBridge checkpoint — conversion from savanna/nemo2 is a prerequisite, not recipe code.) + +### Model architecture + +08. **Hyena (evo2) fftconv OOMs on long sequences even at micro-batch=1** (FFT intermediates scale super-linearly). **Chunk inputs to the trained context** before extraction: evo2 1B → 8192 bp; 7B → context-extended (check release); 40B → 1M. Don't rely on the inference tool to truncate. +09. `predict_evo2` takes **uncompressed FASTA only** (`<(zcat ...)` fails); but if your chunker already reads `.gz` → writes plain `.fasta`, no separate gunzip is needed. +10. `--micro-batch-size 1` is often far from optimal once chunks are uniform/short — memory drops ~10×, raise it (chunking alone gave ~17× on Evo2 1B). + +### Output format + +11. `predict_evo2 --embedding-layer N` yields `{hidden_embeddings:[B,S,H], pad_mask:[B,S], seq_idx:[B], tokens:[B,S], batch_idx:int}`. `pad_mask` is a **loss mask** (1=valid), not an HF attention mask. The streaming `_store_writer` appends `hidden_embeddings[pad_mask.bool()]`. + +## Evaluating the SAE + +After training, run `sae.eval` on a **held-out** cache (same distribution, disjoint instances): + +- `reconstruction` → variance explained; `dead_latents` → dead %; `loss_recovered` → CE fidelity (substitute the SAE recon at the layer-L hook). +- For interpretability, build a labeled `ActivationBuffer` (per-token feature codes + concept labels + optional dense-residual twin) and run `sae.eval.probing` — per-feature AUROC, winner's-curse-corrected best-single, SAE-vs-dense probes, domain-F1. Labelers are **per-domain** (DNA / protein / codon); the scoring is shared. + +## Verifying the recipe works (fastest → most confident) + +1. **Mechanical** — pipeline runs end-to-end, `checkpoint_final.pt` exists. Smoke on 20–100 sequences (minutes). +2. **Numerical** — `train.py` log shows loss ↓, FVU < 1, dead-% trending toward ~20% (not stuck at ~80%). If dead-% is stuck high, check normalize-input / presample / the LR horizon (gotcha 1). +3. **Shape sanity** — `torch.load(checkpoint_final.pt)`: encoder `[hidden_dim → expansion·hidden_dim]`, decoder the transpose. + +## Reference recipes + +| Recipe | Extract path | Mirror it when | +| ---------- | -------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------- | +| `esm2/` | `extract.py` → HF `AutoModel.from_pretrained` + `output_hidden_states` | new model is HF-native with a clean `AutoModel` | +| `codonfm/` | `extract.py` → custom inference class | new model has its own checkpoint + forward code | +| `evo2/` | **streaming** `extract.py` — wraps `predict_evo2`, monkeypatches its writer to an `ActivationStore` (no `.pt`) | upstream already has a `predict_` CLI; reuse it and stream | + +All share the same verbatim `train.py` and the `ActivationStore` parquet contract. From 091a778b72ad266d5d55d8d85bcba035f73ce5da Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 22:25:59 +0000 Subject: [PATCH 2/8] sae recipe skill: reframe Evo2-specific gotchas as general principles + add a perf micro-benchmark The model-architecture gotchas (Hyena OOM, predict CLI input constraints, micro-batch tuning) were Evo2-specific in a general skill. Reframe each as a general principle with Evo2 as the *example*, and add a single-GPU micro-benchmark (micro-batch + seq-length sweeps measuring peak mem + throughput) so the perf claims are measured per-model instead of copied. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/bionemo-sae-recipe.md | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md index 4626a7af88..7f566f2c26 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -128,11 +128,18 @@ These defaults reproduced the best Evo2-7B / layer-26 SAE (~21% dead, ~0.10 FVU) 7. **`weights_only=True` (torch 2.6 default) breaks legacy checkpoints with numpy arrays** — buried in stderr, exit 0, empty output dir. `UnpicklingError: Unsupported global: numpy.core.multiarray._reconstruct`. Patch the upstream `torch.load(...)` to `weights_only=False` if the source is trusted. (For Evo2, the recipe assumes an MBridge checkpoint — conversion from savanna/nemo2 is a prerequisite, not recipe code.) -### Model architecture +### Model architecture / extraction (general principle → Evo2 example) -08. **Hyena (evo2) fftconv OOMs on long sequences even at micro-batch=1** (FFT intermediates scale super-linearly). **Chunk inputs to the trained context** before extraction: evo2 1B → 8192 bp; 7B → context-extended (check release); 40B → 1M. Don't rely on the inference tool to truncate. -09. `predict_evo2` takes **uncompressed FASTA only** (`<(zcat ...)` fails); but if your chunker already reads `.gz` → writes plain `.fasta`, no separate gunzip is needed. -10. `--micro-batch-size 1` is often far from optimal once chunks are uniform/short — memory drops ~10×, raise it (chunking alone gave ~17× on Evo2 1B). +These are **general principles**; the numbers are Evo2 examples — **measure them for your model** (see "Verify the perf claims" below), don't copy the constants. + +08. **Long sequences can blow up memory super-linearly on conv/FFT architectures → chunk inputs to the model's trained context before extraction.** *Evo2 example:* Hyena's fftconv OOMs even at micro-batch=1 (intermediates scale super-linearly); chunk to 1B → 8192 bp, 7B → context-extended (check release), 40B → 1M. Don't rely on the inference tool to truncate. +09. **Check your predict CLI's input constraints (compression/format).** *Evo2 example:* `predict_evo2` takes uncompressed FASTA only (`<(zcat ...)` fails); but if your chunker already reads `.gz` → writes plain `.fasta`, no separate gunzip is needed. +10. **micro-batch=1 is rarely optimal — once inputs are short/uniform, raise it.** *Evo2 example:* chunking dropped memory ~10× and gave ~17× per-batch speedup on Evo2 1B, so `--micro-batch-size` could be raised well past 1. + +**Verify the perf claims (don't trust the constants):** a few-minute single-GPU micro-benchmark — + +- **micro-batch sweep:** fix a chunked FASTA, run the extractor at `--micro-batch-size ∈ {1,4,8,16,32}`, log peak GPU mem (`torch.cuda.max_memory_allocated`) + throughput (tokens/s over fixed N). Find the largest mbs that fits + the throughput curve. +- **seq-length sweep** (for #8): mbs=1, L ∈ {1k,8k,16k,32k}, log peak mem → see the blowup / OOM point for *your* architecture. ### Output format From 9701eb6932eaacc39cb70fd37bfce663c60b7b89 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 22:35:51 +0000 Subject: [PATCH 3/8] =?UTF-8?q?sae=20recipe=20skill:=20train.py=20is=20nea?= =?UTF-8?q?r-verbatim,=20NOT=20a=20blind=20copy=20=E2=80=94=20must=20wire?= =?UTF-8?q?=20the=20opt-in=20flags?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 'Copy verbatim' was wrong and dangerous: copying an older train.py silently drops the opt-in training flags (--aggregate-loss/--dead-count-global/--mix-shards/ --presample-shards), turning a 'reproduce the winner' run into a baseline/losing run. Reword all three 'verbatim' mentions: start from a CURRENT recipe (codonfm/evo2) that wires the flags, edit only docstring + wandb default, and note the dedup-into-sae follow-up. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/bionemo-sae-recipe.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md index 7f566f2c26..52c8ab023f 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -16,7 +16,7 @@ extractor (model-specific) → ActivationStore parquet shards → train.py (univ - **Extractor** runs the model forward and **streams** layer-L activations *directly* into an `ActivationStore` — no intermediate `.pt` files. The clean pattern (see `evo2/scripts/extract.py`): reuse the model's existing `predict_` CLI but **monkeypatch its per-batch writer** with one that appends `hidden[mask]` to `sae.activation_store.ActivationStore`. Model-specific (~150 lines). - **ActivationStore** (`sae/src/sae/activation_store.py`) is the universal on-disk format: a directory of `shard_{NNNNN}.parquet` + `metadata.json` (`{model_name, layer, hidden_dim, n_samples, n_shards, shard_size, n_sequences}`). -- **train.py** loads via `sae.activation_store.load_activations(cache_dir)` and trains a TopK/ReLU SAE. **Identical across recipes — copy verbatim** (only the docstring differs); uses `--model-path` solely for a one-line cache-validation warning. +- **train.py** loads via `sae.activation_store.load_activations(cache_dir)` and trains a TopK/ReLU SAE — **near-identical across recipes, but not a blind verbatim copy.** It must wire the opt-in training flags (`--aggregate-loss` / `--dead-count-global` / `--mix-shards` / `--presample-shards`). Start from a **current** recipe's `train.py` (codonfm/evo2 — they already wire them), then change only the docstring + `--wandb-project` default. **Copying an older train.py silently drops those flags → the losing config** (this is exactly how a "reproduce the winner" run quietly turns into a baseline run). Uses `--model-path` only for a cache-validation warning. (The copy-paste is a known smell; the intended end-state is a single shared train-CLI in `sae`.) - **eval** (`sae.eval`, universal): `reconstruction` (variance explained), `dead_latents` (%), `loss_recovered` (CE fidelity), and `probing` (per-feature AUROC / linear probes / domain-F1 over a labeled `ActivationBuffer`). Probing scoring is **CPU-only** — it reads saved buffers, no model. ## When this applies @@ -48,7 +48,7 @@ recipes// └── scripts/ ├── .sh # orchestrator: chunk → stream-extract → train ├── extract.py # STREAMING: wraps predict_, writes ActivationStore directly (NO .pt) - └── train.py # COPY VERBATIM from any recipe (e.g. codonfm/scripts/train.py) + └── train.py # near-verbatim from a CURRENT recipe (codonfm/evo2): MUST wire the opt-in flags; edit only docstring + wandb default ``` ### 4. The streaming extractor @@ -166,4 +166,4 @@ After training, run `sae.eval` on a **held-out** cache (same distribution, disjo | `codonfm/` | `extract.py` → custom inference class | new model has its own checkpoint + forward code | | `evo2/` | **streaming** `extract.py` — wraps `predict_evo2`, monkeypatches its writer to an `ActivationStore` (no `.pt`) | upstream already has a `predict_` CLI; reuse it and stream | -All share the same verbatim `train.py` and the `ActivationStore` parquet contract. +All share a near-identical `train.py` (current copies wire the opt-in flags) and the `ActivationStore` parquet contract — folding the duplicated train-CLI into a shared `sae` entrypoint is a planned follow-up. From 3c72b01bca4a3ed963365c386897960d2e0bc445 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 22:37:26 +0000 Subject: [PATCH 4/8] =?UTF-8?q?sae=20recipe=20skill:=20add=20Step=200=20?= =?UTF-8?q?=E2=80=94=20get=20model=20+=20data=20+=20MBridge=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The skill assumed 'a checkpoint is in hand'. Add a prerequisites step covering how to get the model (NGC for BioNeMo/Evo2, HF for esm2/codonfm), convert Megatron checkpoints to the MBridge directory that predict_evo2 requires (savanna/nemo2 -> mbridge, incl. the weights_only=False gotcha; HF models skip it), and pull/verify the data corpus (HF dir-name + file-count checks, decompress, chunk to context). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/bionemo-sae-recipe.md | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md index 52c8ab023f..1d6faa4bb1 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -21,7 +21,38 @@ extractor (model-specific) → ActivationStore parquet shards → train.py (univ ## When this applies -Bringing up an SAE on a new biological foundation model — Evo2, ESM2, CodonFM, Nemotron, Geneformer, etc. A checkpoint (HF or local) is in hand. Scope is the full **extract → train → eval** pipeline. Per-model you write a thin **extractor** (and, for interpretability, **labelers**); everything downstream is shared. +Bringing up an SAE on a new biological foundation model — Evo2, ESM2, CodonFM, Nemotron, Geneformer, etc. Scope is the full **extract → train → eval** pipeline. Per-model you write a thin **extractor** (and, for interpretability, **labelers**); everything downstream is shared. + +## Step 0 — get the model + data (and, for Megatron models, convert to MBridge) + +Before any extraction you need a **checkpoint in the format the model's `predict`/forward expects** and a **sequence corpus**. This is upstream of the recipe — don't bake it into `extract.py`. + +**Model checkpoint:** + +```bash +# BioNeMo / Evo2 etc. live on NGC: +ngc registry model download-version "nvidia/clara/:" --dest ./checkpoints +# HF-native models (ESM2, CodonFM/Encodon) on HuggingFace (use `hf`, not the deprecated huggingface-cli): +hf download --local-dir ./checkpoints/ +``` + +**Convert to MBridge (Megatron models — e.g. Evo2):** `predict_evo2`/Megatron loads an **MBridge checkpoint *directory*** (has `latest_checkpointed_iteration.txt` + sharded weights), **not** a raw HF/savanna file. Convert first; the result is the `--ckpt-dir` you hand the extractor: + +```bash +evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path --mbridge-ckpt-dir \ + --model-size --tokenizer-path +# (or the nemo2 -> mbridge path if you have a nemo2 checkpoint) +``` + +- **Gotcha:** savanna conversion hits the torch-2.6 `weights_only=True` default → patch the converter's `torch.load(...)` to `weights_only=False` (trusted source); the failure is silent (exit 0, empty dir). See gotcha 7. +- **HF models (esm2/codonfm) skip MBridge entirely** — they load directly from the `.safetensors`/checkpoint. + +**Data corpus:** + +- Pull the sequence set (Evo2 → OpenGenome2; protein → UniRef/etc.). **Verify the download** — HF README dir names are unreliable (OpenGenome2's `jsonl/` is really `json/`); check the tree + a nonzero file count (`curl -s "https://huggingface.co/api/datasets//tree/main" | python3 -m json.tool`). +- Decompress `.gz` if the predict CLI needs plain FASTA, and **chunk to the trained context** (gotcha 8) before extraction. +- Grab a small subset (a few thousand sequences) first to smoke-test the whole pipeline. ## Steps From 180741bea4a5cfc7d04781f23cb4d02982bbde15 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 23:02:32 +0000 Subject: [PATCH 5/8] sae recipe skill: add model assumptions + fix train.py source (evo2 only) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (1) Add an 'Assumptions about the model' section — the skill assumes the model is already integrated into bionemo (predict_ CLI or HF AutoModel under recipes/_*/), hidden states reachable at a layer, checkpoint loadable (MBridge for Megatron), a known token<->position mapping, float (not bf16) activations, chunkable sequence inputs. If any is false, that's upstream work. (2) Fix the train.py source: on main ONLY evo2/scripts/train.py wires the four opt-in flags (codonfm & esm2 are 0/4). 'Copy from codonfm/evo2' would have handed out the flag-less losing config. Point all three references to evo2's specifically, and reinforce that this drift is why the train-CLI should live in shared sae. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../recipes/bionemo-sae-recipe.md | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md index 1d6faa4bb1..f8fe16a7f8 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -16,13 +16,24 @@ extractor (model-specific) → ActivationStore parquet shards → train.py (univ - **Extractor** runs the model forward and **streams** layer-L activations *directly* into an `ActivationStore` — no intermediate `.pt` files. The clean pattern (see `evo2/scripts/extract.py`): reuse the model's existing `predict_` CLI but **monkeypatch its per-batch writer** with one that appends `hidden[mask]` to `sae.activation_store.ActivationStore`. Model-specific (~150 lines). - **ActivationStore** (`sae/src/sae/activation_store.py`) is the universal on-disk format: a directory of `shard_{NNNNN}.parquet` + `metadata.json` (`{model_name, layer, hidden_dim, n_samples, n_shards, shard_size, n_sequences}`). -- **train.py** loads via `sae.activation_store.load_activations(cache_dir)` and trains a TopK/ReLU SAE — **near-identical across recipes, but not a blind verbatim copy.** It must wire the opt-in training flags (`--aggregate-loss` / `--dead-count-global` / `--mix-shards` / `--presample-shards`). Start from a **current** recipe's `train.py` (codonfm/evo2 — they already wire them), then change only the docstring + `--wandb-project` default. **Copying an older train.py silently drops those flags → the losing config** (this is exactly how a "reproduce the winner" run quietly turns into a baseline run). Uses `--model-path` only for a cache-validation warning. (The copy-paste is a known smell; the intended end-state is a single shared train-CLI in `sae`.) +- **train.py** loads via `sae.activation_store.load_activations(cache_dir)` and trains a TopK/ReLU SAE — **near-identical across recipes, but not a blind verbatim copy.** It must wire the opt-in training flags (`--aggregate-loss` / `--dead-count-global` / `--mix-shards` / `--presample-shards`). **On `main`, only `evo2/scripts/train.py` wires all four** — `codonfm` and `esm2` are `0/4` (precisely why this belongs in shared `sae`). So copy **evo2's**, then change only the docstring + `--wandb-project` default. **Copying an older train.py silently drops those flags → the losing config** (this is exactly how a "reproduce the winner" run quietly turns into a baseline run). Uses `--model-path` only for a cache-validation warning. (The copy-paste is a known smell; the intended end-state is a single shared train-CLI in `sae`.) - **eval** (`sae.eval`, universal): `reconstruction` (variance explained), `dead_latents` (%), `loss_recovered` (CE fidelity), and `probing` (per-feature AUROC / linear probes / domain-F1 over a labeled `ActivationBuffer`). Probing scoring is **CPU-only** — it reads saved buffers, no model. ## When this applies Bringing up an SAE on a new biological foundation model — Evo2, ESM2, CodonFM, Nemotron, Geneformer, etc. Scope is the full **extract → train → eval** pipeline. Per-model you write a thin **extractor** (and, for interpretability, **labelers**); everything downstream is shared. +## Assumptions about the model (check these first) + +This skill assumes the model is **already integrated into bionemo** — i.e. there's a setup in `bionemo-recipes/recipes/` to build on. If one of these is false, that's upstream work, not part of the SAE recipe: + +1. **The model has an inference path in the repo** — a `predict_` CLI under `bionemo-recipes/recipes/_*/` (Megatron-style, like Evo2), or it's HF-native (`AutoModel.from_pretrained`, like ESM2). If neither exists, you must add the forward pass first. +2. **You can pull hidden states at a chosen layer** — `[B, S, H]` via `--embedding-layer` (predict CLI) or `output_hidden_states` (HF). +3. **The checkpoint loads in the available env** — an **MBridge directory** for Megatron models (convert in Step 0); `.safetensors`/HF otherwise. The Megatron path needs the NVIDIA pytorch container + TransformerEngine. +4. **Known token↔position mapping** — to label/probe, each activation row must map back to a sequence position. Evo2 byte tokenizer = 1 char/token; CodonFM = 1 codon (3-mer)/token; ESM2 = 1 aa/token. Get this wrong and your labels are misaligned. +5. **Activations are float (fp16/fp32), not bf16** — Arrow/NumPy can't store bf16; cast before `ActivationStore.append`. +6. **Inputs are sequences you can chunk/feed** (FASTA/CSV), and you know the model's **trained context length**. + ## Step 0 — get the model + data (and, for Megatron models, convert to MBridge) Before any extraction you need a **checkpoint in the format the model's `predict`/forward expects** and a **sequence corpus**. This is upstream of the recipe — don't bake it into `extract.py`. @@ -79,7 +90,7 @@ recipes// └── scripts/ ├── .sh # orchestrator: chunk → stream-extract → train ├── extract.py # STREAMING: wraps predict_, writes ActivationStore directly (NO .pt) - └── train.py # near-verbatim from a CURRENT recipe (codonfm/evo2): MUST wire the opt-in flags; edit only docstring + wandb default + └── train.py # near-verbatim from evo2/scripts/train.py (the ONLY recipe on main wiring all 4 opt-in flags); edit only docstring + wandb default ``` ### 4. The streaming extractor @@ -197,4 +208,4 @@ After training, run `sae.eval` on a **held-out** cache (same distribution, disjo | `codonfm/` | `extract.py` → custom inference class | new model has its own checkpoint + forward code | | `evo2/` | **streaming** `extract.py` — wraps `predict_evo2`, monkeypatches its writer to an `ActivationStore` (no `.pt`) | upstream already has a `predict_` CLI; reuse it and stream | -All share a near-identical `train.py` (current copies wire the opt-in flags) and the `ActivationStore` parquet contract — folding the duplicated train-CLI into a shared `sae` entrypoint is a planned follow-up. +All share a near-identical `train.py` and the `ActivationStore` parquet contract — **but only evo2's currently wires the opt-in flags** (codonfm/esm2 lag), so copy evo2's. Folding the duplicated train-CLI into one shared `sae` entrypoint (so no recipe can drift) is the planned fix. From 1e3b85bf91816a484670755eda9523924d8ed388 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Wed, 10 Jun 2026 23:04:14 +0000 Subject: [PATCH 6/8] sae recipe skill: caveat that probing lands with the eval PR The static-consistency check flagged sae.eval.probing as not yet on main (it's in the eval recipe PR). Note in the eval section that reconstruction/dead_latents/ loss_recovered are already in sae.eval but probing is the newest module and lands with that PR, so a reader on current main knows it's the dependency. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/bionemo-sae-recipe.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md index f8fe16a7f8..55153925ee 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -192,7 +192,7 @@ These are **general principles**; the numbers are Evo2 examples — **measure th After training, run `sae.eval` on a **held-out** cache (same distribution, disjoint instances): - `reconstruction` → variance explained; `dead_latents` → dead %; `loss_recovered` → CE fidelity (substitute the SAE recon at the layer-L hook). -- For interpretability, build a labeled `ActivationBuffer` (per-token feature codes + concept labels + optional dense-residual twin) and run `sae.eval.probing` — per-feature AUROC, winner's-curse-corrected best-single, SAE-vs-dense probes, domain-F1. Labelers are **per-domain** (DNA / protein / codon); the scoring is shared. +- For interpretability, build a labeled `ActivationBuffer` (per-token feature codes + concept labels + optional dense-residual twin) and run `sae.eval.probing` — per-feature AUROC, winner's-curse-corrected best-single, SAE-vs-dense probes, domain-F1. Labelers are **per-domain** (DNA / protein / codon); the scoring is shared. **Note:** `reconstruction` / `dead_latents` / `loss_recovered` are already in `sae.eval`; **`probing` is the newest module and lands with the eval recipe PR** — if it's not in your tree yet, that PR is the dependency. ## Verifying the recipe works (fastest → most confident) From f2878e434a4451b116b445abc0af900751496f49 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 18:17:03 +0000 Subject: [PATCH 7/8] sae recipe skill: separate general flags from Evo2-specific values The training-config section + launch command presented Evo2-7B/L26's exact hyperparameters (expansion 16, top-k 128, auxk 2048, mix-shards 10, presample 8) as if canonical, making the skill read Evo2-specific. Reframe: the *flags* are general (turn them on for any model), but the *values* are Evo2-7B examples to re-tune per model -- width/top-k/auxk scale with hidden_dim, and mix/presample only matter for a corpus-ordered cache (set to 1 if shards are pre-shuffled). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/bionemo-sae-recipe.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md index 55153925ee..882687b316 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -108,7 +108,7 @@ No `.pt`, ~half the disk, no separate conversion pass. Under DDP each rank write ### 5. Launch the training -The orchestrator (`.sh`) chains chunk → extract → train. Launch with `torchrun`, `--dp-size` = #GPUs. **Always smoke first** (20–100 sequences → confirm loss drops), then the full run. +The orchestrator (`.sh`) chains chunk → extract → train. Launch with `torchrun`, `--dp-size` = #GPUs. **Always smoke first** (20–100 sequences → confirm loss drops), then the full run. The flags below are general; the numeric values are the **Evo2-7B/L26** example — re-tune per model (see "Training config" below). ```bash unset WANDB_API_KEY # a leaked key in the shared env overrides ~/.netrc — you'd log as someone else @@ -137,9 +137,12 @@ Each long step needs an idempotency check on a sentinel the step itself produces **Caveat:** guards check existence, not provenance — `rm -rf` the output dir when the input FASTA / model / layer changes. -## Known-good training config (and why) +## Training config — which knobs to turn on (general) vs. their values (per-model) -These defaults reproduced the best Evo2-7B / layer-26 SAE (~21% dead, ~0.10 FVU). All are **opt-in** in the `sae` package (defaults reproduce older behavior): +Separate two things: + +- **The flags below are general** — turn them on for *any* model. They're all opt-in in the `sae` package (the defaults reproduce older, worse behavior), so a recipe that doesn't pass them silently trains the losing config. +- **The values are not.** The numbers in the launch command (`--expansion-factor 16`, `--top-k 128`, `--auxk 2048`, `--mix-shards 10`, `--presample-shards 8`) are what reproduced the best **Evo2-7B / layer-26** SAE (~21% dead, ~0.10 FVU). **Re-tune per model:** expansion/top-k/auxk scale with `hidden_dim` and the sparsity you want; `--mix-shards`/`--presample-shards` only matter for a **corpus-ordered** cache — set both to `1` if your shards are already shuffled. | flag | why it matters | | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | From 2e1ce6f3944c1cd457799da510e85a6e9ebc29c2 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 11 Jun 2026 18:20:20 +0000 Subject: [PATCH 8/8] sae recipe skill: flag the unverified perf number + note the flags aren't mandatory (1) The ~10x/~17x micro-batch figure was an unverified number inherited from an earlier extraction note (not re-measured); mark it as a hypothesis to measure, not a fact to quote. (2) The opt-in training flags fix specific Evo2 failure modes but are NOT universally required -- CodonFM trained a good SAE with none of them (0/4). Reframe from 'turn on for any model' to 'turn on if you hit the problem it fixes'; they're only non-negotiable for reproducing the Evo2 winner. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Polina Binder --- .../sparse_autoencoders/recipes/bionemo-sae-recipe.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md index 882687b316..1de2b4b142 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md @@ -141,8 +141,8 @@ Each long step needs an idempotency check on a sentinel the step itself produces Separate two things: -- **The flags below are general** — turn them on for *any* model. They're all opt-in in the `sae` package (the defaults reproduce older, worse behavior), so a recipe that doesn't pass them silently trains the losing config. -- **The values are not.** The numbers in the launch command (`--expansion-factor 16`, `--top-k 128`, `--auxk 2048`, `--mix-shards 10`, `--presample-shards 8`) are what reproduced the best **Evo2-7B / layer-26** SAE (~21% dead, ~0.10 FVU). **Re-tune per model:** expansion/top-k/auxk scale with `hidden_dim` and the sparsity you want; `--mix-shards`/`--presample-shards` only matter for a **corpus-ordered** cache — set both to `1` if your shards are already shuffled. +- **The flags are *available*, not mandatory.** All opt-in in `sae` (defaults = older behavior). Each fixes a specific failure mode we hit on **Evo2** — severe dead latents (`--normalize-input`, `--aggregate-loss`), a corpus/kingdom-ordered cache (`--mix-shards`, `--presample-shards`), and DDP dead-counting (`--dead-count-global`). They mattered a lot *there*. **They are not universally required: CodonFM trained a good SAE with none of them** (its `train.py` wires 0/4). So turn each on only if you actually hit the problem it fixes — don't cargo-cult them. *(The one place they're non-negotiable: **reproducing the Evo2 winner** — which is why copying an older, flag-less `train.py` into an Evo2 recipe silently gives the losing config.)* +- **The values are model-specific.** The numbers in the launch command (`--expansion-factor 16`, `--top-k 128`, `--auxk 2048`, `--mix-shards 10`, `--presample-shards 8`) are what reproduced the best **Evo2-7B / layer-26** SAE (~21% dead, ~0.10 FVU). **Re-tune per model:** expansion/top-k/auxk scale with `hidden_dim` and the sparsity you want; `--mix-shards`/`--presample-shards` only matter for a **corpus-ordered** cache — set both to `1` if your shards are already shuffled. | flag | why it matters | | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- | @@ -179,7 +179,7 @@ These are **general principles**; the numbers are Evo2 examples — **measure th 08. **Long sequences can blow up memory super-linearly on conv/FFT architectures → chunk inputs to the model's trained context before extraction.** *Evo2 example:* Hyena's fftconv OOMs even at micro-batch=1 (intermediates scale super-linearly); chunk to 1B → 8192 bp, 7B → context-extended (check release), 40B → 1M. Don't rely on the inference tool to truncate. 09. **Check your predict CLI's input constraints (compression/format).** *Evo2 example:* `predict_evo2` takes uncompressed FASTA only (`<(zcat ...)` fails); but if your chunker already reads `.gz` → writes plain `.fasta`, no separate gunzip is needed. -10. **micro-batch=1 is rarely optimal — once inputs are short/uniform, raise it.** *Evo2 example:* chunking dropped memory ~10× and gave ~17× per-batch speedup on Evo2 1B, so `--micro-batch-size` could be raised well past 1. +10. **micro-batch=1 is rarely optimal — once inputs are short/uniform, raise it.** The specific figure (chunking dropping memory ~10× and a ~17× per-batch speedup on Evo2 1B) is an **unverified number inherited from an earlier extraction note** — we did *not* re-measure it. Treat it as a hypothesis and **measure your own** (below), don't quote it. **Verify the perf claims (don't trust the constants):** a few-minute single-GPU micro-benchmark —