sae recipe skill: streaming extract + launch step + gotchas#1625

Open
polinabinder1 wants to merge 8 commits into main from pbinder/sae-recipe-skill

Conversation

@polinabinder1

@polinabinder1 polinabinder1 commented Jun 10, 2026

Collaborator

Draft. Supersedes the fork-based #1578.

Revamps the bionemo-sae-recipe skill to match what we actually built + learned:

  • Streaming extractor — replaces the old predict → .pt → pt_to_parquet shim flow with the streaming pattern (wrap predict_<model>, monkeypatch its per-batch writer to append to sae.ActivationStore, no .pt). Drops the now-obsolete .pt disk-pressure and shim-throughput caveats.
  • Launch step — a concrete torchrun invocation with the known-good config (normalize-input, aggregate-loss, dead-count-global, mix-shards 10, presample-shards 8, full-epoch cosine, dp-size = #GPUs) + the wandb env.
  • Hard-won gotchas — LR cosine must span the full epoch (capping steps collapses LR early); --dead-count-global is a DDP-only no-op and must be passed not just named; shard-0 pre-bias is corpus-biased (use --presample-shards N>1); unset WANDB_API_KEY + set WANDB_ENTITY.
  • Eval scope — extends past train to sae.eval / probing (dead% / FVU / loss-recovered + per-feature AUROC).
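The LR-horizon gotcha can be illustrated with a minimal sketch. It assumes a plain cosine decay that stays at its floor once the horizon is passed; the recipe's actual scheduler may differ, and `cosine_lr`, the step counts, and the peak LR below are hypothetical values chosen only for illustration.

```python
import math


def cosine_lr(step: int, horizon: int, lr_max: float, lr_min: float = 0.0) -> float:
    """Cosine decay from lr_max to lr_min over `horizon` steps, flat at lr_min afterwards."""
    progress = min(step, horizon) / horizon
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


steps_per_epoch = 10_000  # hypothetical: steps actually run in one epoch
capped_horizon = 2_000    # hypothetical: scheduler horizon truncated by a step cap
lr_max = 3e-4             # hypothetical peak learning rate

for step in (0, 1_000, 2_000, 5_000):
    full = cosine_lr(step, steps_per_epoch, lr_max)
    capped = cosine_lr(step, capped_horizon, lr_max)
    print(f"step {step:>5}: full-epoch horizon {full:.2e} | capped horizon {capped:.2e}")
```

With the capped horizon the learning rate reaches its floor at step 2,000 while most of the epoch is still ahead, which is the early collapse described above; with the full-epoch horizon the rate at the same step is still close to its peak.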

Still general (esm2 / codonfm / evo2 templates); Evo2 specifics remain as caveats.

Summary by CodeRabbit

Release Notes

  • Documentation
    • Added comprehensive recipe documentation for sparse autoencoders on biological foundation models, including end-to-end workflow procedures, activation extraction and training pipelines, evaluation methodologies, operational guidance, troubleshooting steps, and model-specific tuning recommendations.

Update the bionemo-sae-recipe skill: (1) replace the predict->.pt->parquet-shim
flow with the STREAMING extractor (wrap predict_<model>, monkeypatch its writer
to ActivationStore, no .pt) per evo2/scripts/extract.py; (2) add a launch step
with the known-good training config; (3) bake in the issues we hit -- LR cosine
must span the full epoch (don't truncate steps), --dead-count-global is a DDP-only
no-op (and must be passed, not just named), shard-0 pre-bias is biased (use
--presample-shards N>1), and the wandb env (unset WANDB_API_KEY + WANDB_ENTITY);
(4) extend scope to eval via sae.eval/probing. Drops the obsolete .pt disk-pressure
and shim-throughput caveats.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 10, 2026


Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 10, 2026

Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: bf0879a6-dce7-41dd-bf69-c806f6fafb17

📥 Commits

Reviewing files that changed from the base of the PR and between e407165 and 2e1ce6f.

📒 Files selected for processing (1)
  • bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md

📝 Walkthrough

Walkthrough

This PR introduces a comprehensive recipe document for building sparse autoencoders on biological foundation models. The document standardizes the end-to-end workflow from checkpoint setup through activation extraction, training, and evaluation, including model-specific implementation patterns, orchestration guidance, and operational troubleshooting.

Changes

Sparse Autoencoders SAE Recipe Documentation

| Layer / File(s) | Summary |
|---|---|
| Recipe overview and universal contract<br>`bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md` | Introduces the universal SAE pipeline contract: model-specific extractors stream activations into a shared ActivationStore (parquet shards + metadata.json), feeding a near-universal training step (train.py) and universal evaluation flow (reconstruction, dead-latent detection, loss recovery, and probing). |
| Data preparation and streaming extraction<br>`bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md` | Documents Step 0 setup including checkpoint conversion (Megatron MBridge for Evo2-style models versus skipping for HuggingFace-native models), corpus preparation, input chunking, smoke-test runs, and the streaming extractor implementation that monkeypatches model writers with file-based DDP synchronization via polling (avoiding dist.barrier()). |
| Training launch and flag configuration<br>`bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md` | Provides torchrun chaining guidance, WANDB environment setup, representative Evo2-style launch flags with four opt-in parameters, cache-guard idempotency using metadata.json existence sentinel, and detailed tuning strategy for model-specific flag values (layer selection, learning rates, batch sizes). |
| Operational guidance, gotchas, and validation<br>`bionemo-recipes/interpretability/sparse_autoencoders/recipes/bionemo-sae-recipe.md` | Enumerates operational gotchas (training dynamics, wandb hygiene, container assumptions, checkpoint-loading quirks like weights_only=True, extraction memory constraints), evaluation workflow including CPU-only probing, fast verification steps (end-to-end runs, numeric trends, checkpoint shape checks), and references to mirror recipes for esm2, codonfm, and evo2 models. |
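The two file-based mechanisms named in the walkthrough (the `metadata.json` existence sentinel used as a cache guard, and polling in place of `dist.barrier()`) can be sketched roughly as follows. This is not the recipe's code: the helper names, the timeout values, and the metadata fields are hypothetical, and writing `metadata.json` last via an atomic rename is an assumption about how the sentinel stays trustworthy.

```python
import json
import os
import tempfile
import time
from pathlib import Path


def cache_is_complete(cache_dir: Path) -> bool:
    """Cache guard: metadata.json existing means extraction already finished."""
    return (cache_dir / "metadata.json").is_file()


def write_metadata_last(cache_dir: Path, metadata: dict) -> None:
    """Write the sentinel after all shards, via a temp file + atomic rename."""
    tmp = cache_dir / "metadata.json.tmp"
    tmp.write_text(json.dumps(metadata))
    os.replace(tmp, cache_dir / "metadata.json")


def wait_for_file(path: Path, timeout_s: float = 3600.0, poll_s: float = 1.0) -> None:
    """Block until `path` exists by polling the filesystem (no collective call)."""
    deadline = time.monotonic() + timeout_s
    while not path.exists():
        if time.monotonic() >= deadline:
            raise TimeoutError(f"timed out waiting for {path}")
        time.sleep(poll_s)


with tempfile.TemporaryDirectory() as tmp_dir:
    cache = Path(tmp_dir)
    if not cache_is_complete(cache):  # skip extraction when the cache is already there
        # ... shards would be written here ...
        write_metadata_last(cache, {"hidden_dim": 8, "n_shards": 0})  # hypothetical fields
    wait_for_file(cache / "metadata.json", timeout_s=5.0, poll_s=0.1)
    print("cache complete:", cache_is_complete(cache))
```

In a multi-rank run, the rank that finalizes the cache would call `write_metadata_last` and the other ranks would call `wait_for_file` on the same path.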

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~5 minutes

Suggested reviewers

  • jstjohn
  • pstjohn
  • trvachov

Poem

🐰 A recipe so grand, with activations to sift,
SAEs in your sparse autoencoders, a computational gift!
From extraction through training, to eval with care,
The pipeline now documented, for all who would dare! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Description check | ❓ Inconclusive | The PR description is comprehensive and covers the main changes (streaming extractor, launch step, gotchas, eval scope), but does not follow the required template structure with Description, Usage, Type of changes, and CI Pipeline Configuration sections. | Restructure the description to follow the template format: add a formal Description section, provide Usage examples, explicitly mark the change type (Documentation update), and configure CI pipeline labels as needed. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the three main changes in the PR: streaming extract, launch step, and gotchas documentation. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


polinabinder1 and others added 5 commits June 10, 2026 22:25
… + add a perf micro-benchmark

The model-architecture gotchas (Hyena OOM, predict CLI input constraints,
micro-batch tuning) were Evo2-specific in a general skill. Reframe each as a
general principle with Evo2 as the *example*, and add a single-GPU micro-benchmark
(micro-batch + seq-length sweeps measuring peak mem + throughput) so the perf
claims are measured per-model instead of copied.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
…wire the opt-in flags

'Copy verbatim' was wrong and dangerous: copying an older train.py silently drops
the opt-in training flags (--aggregate-loss/--dead-count-global/--mix-shards/
--presample-shards), turning a 'reproduce the winner' run into a baseline/losing
run. Reword all three 'verbatim' mentions: start from a CURRENT recipe (codonfm/evo2)
that wires the flags, edit only docstring + wandb default, and note the dedup-into-sae
follow-up.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
The skill assumed 'a checkpoint is in hand'. Add a prerequisites step covering
how to get the model (NGC for BioNeMo/Evo2, HF for esm2/codonfm), convert Megatron
checkpoints to the MBridge directory that predict_evo2 requires (savanna/nemo2 ->
mbridge, incl. the weights_only=False gotcha; HF models skip it), and pull/verify
the data corpus (HF dir-name + file-count checks, decompress, chunk to context).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
…nly)

(1) Add an 'Assumptions about the model' section — the skill assumes the model is
already integrated into bionemo (predict_<model> CLI or HF AutoModel under
recipes/<model>_*/), hidden states reachable at a layer, checkpoint loadable
(MBridge for Megatron), a known token<->position mapping, float (not bf16)
activations, chunkable sequence inputs. If any is false, that's upstream work.

(2) Fix the train.py source: on main ONLY evo2/scripts/train.py wires the four
opt-in flags (codonfm & esm2 are 0/4). 'Copy from codonfm/evo2' would have handed
out the flag-less losing config. Point all three references to evo2's specifically,
and reinforce that this drift is why the train-CLI should live in shared sae.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
The static-consistency check flagged sae.eval.probing as not yet on main (it's in
the eval recipe PR). Note in the eval section that reconstruction/dead_latents/
loss_recovered are already in sae.eval but probing is the newest module and lands
with that PR, so a reader on current main knows it's the dependency.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
@polinabinder1 polinabinder1 marked this pull request as ready for review June 11, 2026 18:13
polinabinder1 and others added 2 commits June 11, 2026 18:17
The training-config section + launch command presented Evo2-7B/L26's exact
hyperparameters (expansion 16, top-k 128, auxk 2048, mix-shards 10, presample 8)
as if canonical, making the skill read Evo2-specific. Reframe: the *flags* are
general (turn them on for any model), but the *values* are Evo2-7B examples to
re-tune per model -- width/top-k/auxk scale with hidden_dim, and mix/presample
only matter for a corpus-ordered cache (set to 1 if shards are pre-shuffled).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
…en't mandatory

(1) The ~10x/~17x micro-batch figure was an unverified number inherited from an
earlier extraction note (not re-measured); mark it as a hypothesis to measure,
not a fact to quote. (2) The opt-in training flags fix specific Evo2 failure modes
but are NOT universally required -- CodonFM trained a good SAE with none of them
(0/4). Reframe from 'turn on for any model' to 'turn on if you hit the problem it
fixes'; they're only non-negotiable for reproducing the Evo2 winner.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Polina Binder <pbinder@nvidia.com>
@polinabinder1 polinabinder1 marked this pull request as draft June 11, 2026 21:18
2. **You can pull hidden states at a chosen layer** — `[B, S, H]` via `--embedding-layer` (predict CLI) or `output_hidden_states` (HF).
3. **The checkpoint loads in the available env** — an **MBridge directory** for Megatron models (convert in Step 0); `.safetensors`/HF otherwise. The Megatron path needs the NVIDIA pytorch container + TransformerEngine.
4. **Known token↔position mapping** — to label/probe, each activation row must map back to a sequence position. Evo2 byte tokenizer = 1 char/token; CodonFM = 1 codon (3-mer)/token; ESM2 = 1 aa/token. Get this wrong and your labels are misaligned.
5. **Activations are float (fp16/fp32), not bf16** — Arrow/NumPy can't store bf16; cast before `ActivationStore.append`.
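As a rough illustration of assumption 4, the sketch below maps an activation row (token index) back to a character span of the input sequence using the per-model token widths quoted above. The helper name and the `n_prefix_tokens` parameter are hypothetical; whether special tokens (for example a leading CLS/BOS token) end up in the stored rows depends on the extractor, so that offset is an assumption to verify per model.

```python
CHARS_PER_TOKEN = {"evo2": 1, "codonfm": 3, "esm2": 1}  # widths as stated in the list above


def token_to_span(model: str, token_index: int, n_prefix_tokens: int = 0) -> tuple[int, int]:
    """Return the half-open [start, end) character span covered by one token.

    n_prefix_tokens (hypothetical) is the number of special tokens stored
    before the first content token in each sequence's activation rows.
    """
    content_index = token_index - n_prefix_tokens
    if content_index < 0:
        raise ValueError("token_index points at a prefix (special) token")
    width = CHARS_PER_TOKEN[model]
    start = content_index * width
    return start, start + width


cds = "ATGGCCAAATGA"
start, end = token_to_span("codonfm", 1)
print(cds[start:end])  # second codon: GCC

protein = "MKTAY"
start, end = token_to_span("esm2", 1, n_prefix_tokens=1)
print(protein[start:end])  # first residue when one prefix token is stored: M
```

Checking a handful of rows this way against known annotations is a cheap guard against the misaligned-label failure the list warns about.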

Collaborator

Based on my work on the multimodal recipe, it would be helpful to develop a separate skill for layer selection to guide which activations are most relevant. Ideas we discussed include:

  • designing a linear probe + rank analysis across layers to determine the most informative ones
  • doing PCA-style analysis on each layer (rank decomposition)

- **The flags are *available*, not mandatory.** All opt-in in `sae` (defaults = older behavior). Each fixes a specific failure mode we hit on **Evo2** — severe dead latents (`--normalize-input`, `--aggregate-loss`), a corpus/kingdom-ordered cache (`--mix-shards`, `--presample-shards`), and DDP dead-counting (`--dead-count-global`). They mattered a lot *there*. **They are not universally required: CodonFM trained a good SAE with none of them** (its `train.py` wires 0/4). So turn each on only if you actually hit the problem it fixes — don't cargo-cult them. *(The one place they're non-negotiable: **reproducing the Evo2 winner** — which is why copying an older, flag-less `train.py` into an Evo2 recipe silently gives the losing config.)*
- **The values are model-specific.** The numbers in the launch command (`--expansion-factor 16`, `--top-k 128`, `--auxk 2048`, `--mix-shards 10`, `--presample-shards 8`) are what reproduced the best **Evo2-7B / layer-26** SAE (~21% dead, ~0.10 FVU). **Re-tune per model:** expansion/top-k/auxk scale with `hidden_dim` and the sparsity you want; `--mix-shards`/`--presample-shards` only matter for a **corpus-ordered** cache — set both to `1` if your shards are already shuffled.
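The "turn each flag on only if you hit the problem it fixes" guidance could be encoded as a small helper like the one below. It is a sketch, not part of `train.py`: the function and its keyword arguments are hypothetical, the default numbers are the Evo2-7B / layer-26 example values quoted above, and it assumes `--normalize-input`, `--aggregate-loss`, and `--dead-count-global` are boolean switches.

```python
def build_train_flags(
    *,
    expansion_factor: int = 16,          # Evo2-7B example value; re-tune per model
    top_k: int = 128,                    # Evo2-7B example value; re-tune per model
    auxk: int = 2048,                    # Evo2-7B example value; re-tune per model
    severe_dead_latents: bool = False,   # symptom seen on Evo2
    corpus_ordered_cache: bool = False,  # shards written in corpus/kingdom order
    ddp: bool = False,                   # multi-GPU data-parallel run
) -> list[str]:
    """Assemble opt-in training flags from the failure modes actually observed."""
    flags = [
        "--expansion-factor", str(expansion_factor),
        "--top-k", str(top_k),
        "--auxk", str(auxk),
    ]
    if severe_dead_latents:
        flags += ["--normalize-input", "--aggregate-loss"]
    # Shard mixing and presampling only matter when the cache is corpus-ordered.
    mix, presample = (10, 8) if corpus_ordered_cache else (1, 1)
    flags += ["--mix-shards", str(mix), "--presample-shards", str(presample)]
    if ddp:
        flags.append("--dead-count-global")
    return flags


# Evo2-like situation: every failure mode present.
print(" ".join(build_train_flags(severe_dead_latents=True, corpus_ordered_cache=True, ddp=True)))
# Pre-shuffled cache, no dead-latent problem: mix/presample fall back to 1.
print(" ".join(build_train_flags()))
```

Keeping the decision in one place makes it harder to silently drop a flag when a `train.py` is copied between recipes.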

| flag | why it matters |

Collaborator

Based on my debugging for the multimodal SAE recipe, we likely want the knowledge here tailored to how you should change your normalization strategy for the two approaches.

## Verifying the recipe works (fastest → most confident)

1. **Mechanical** — pipeline runs end-to-end, `checkpoint_final.pt` exists. Smoke on 20–100 sequences (minutes).
2. **Numerical** — `train.py` log shows loss ↓, FVU < 1, dead-% trending toward ~20% (not stuck at ~80%). If dead-% is stuck high, check normalize-input / presample / the LR horizon (gotcha 1).
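The numerical check in step 2 can be turned into a quick script over logged metrics. The sketch below assumes you have already parsed per-checkpoint lists of loss, FVU, and dead-latent fraction out of the training log; the function name and the `stuck_dead_frac` threshold are hypothetical choices, not values from the recipe.

```python
def check_training_health(
    losses: list[float],
    fvus: list[float],
    dead_fracs: list[float],
    stuck_dead_frac: float = 0.5,  # hypothetical cut-off between "~20%" and "~80%"
) -> list[str]:
    """Return human-readable warnings; an empty list means the trends look sane."""
    warnings = []
    if losses[-1] >= losses[0]:
        warnings.append("loss is not decreasing")
    if fvus[-1] >= 1.0:
        warnings.append("FVU >= 1: reconstruction is no better than predicting the mean")
    if dead_fracs[-1] >= stuck_dead_frac:
        warnings.append(
            "dead latents stuck high: check normalize-input, presample, and the LR horizon"
        )
    return warnings


healthy = check_training_health([2.0, 1.0, 0.5], [0.9, 0.3, 0.12], [0.8, 0.4, 0.22])
stuck = check_training_health([2.0, 1.9, 1.8], [0.9, 0.8, 0.7], [0.8, 0.8, 0.79])
print(healthy)
print(stuck)
```

Running this after the smoke test gives an early signal before committing to a full extraction and training run.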

Collaborator

You also want to start debugging if the variance explained is suspiciously high (e.g., > 95%, or some similar threshold), as this usually indicates some sort of error with the norms of the inputs, the normalization approach, dead latents, etc.

@polinabinder1

Collaborator Author

@coderabbitai review

@coderabbitai

coderabbitai Bot commented Jun 12, 2026

Contributor
✅ Action performed

Review finished.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@polinabinder1 polinabinder1 marked this pull request as ready for review June 12, 2026 05:32

2 participants