sae recipe skill: streaming extract + launch step + gotchas#1625
sae recipe skill: streaming extract + launch step + gotchas#1625polinabinder1 wants to merge 8 commits into
Conversation
Update the bionemo-sae-recipe skill: (1) replace the predict->.pt->parquet-shim flow with the STREAMING extractor (wrap predict_<model>, monkeypatch its writer to ActivationStore, no .pt) per evo2/scripts/extract.py; (2) add a launch step with the known-good training config; (3) bake in the issues we hit -- LR cosine must span the full epoch (don't truncate steps), --dead-count-global is a DDP-only no-op (and must be passed, not just named), shard-0 pre-bias is biased (use --presample-shards N>1), and the wandb env (unset WANDB_API_KEY + WANDB_ENTITY); (4) extend scope to eval via sae.eval/probing. Drops the obsolete .pt disk-pressure and shim-throughput caveats. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughThis PR introduces a comprehensive recipe document for building sparse autoencoders on biological foundation models. The document standardizes the end-to-end workflow from checkpoint setup through activation extraction, training, and evaluation, including model-specific implementation patterns, orchestration guidance, and operational troubleshooting. ChangesSparse Autoencoders SAE Recipe Documentation
Estimated code review effort🎯 1 (Trivial) | ⏱️ ~5 minutes Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
… + add a perf micro-benchmark The model-architecture gotchas (Hyena OOM, predict CLI input constraints, micro-batch tuning) were Evo2-specific in a general skill. Reframe each as a general principle with Evo2 as the *example*, and add a single-GPU micro-benchmark (micro-batch + seq-length sweeps measuring peak mem + throughput) so the perf claims are measured per-model instead of copied. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
…wire the opt-in flags 'Copy verbatim' was wrong and dangerous: copying an older train.py silently drops the opt-in training flags (--aggregate-loss/--dead-count-global/--mix-shards/ --presample-shards), turning a 'reproduce the winner' run into a baseline/losing run. Reword all three 'verbatim' mentions: start from a CURRENT recipe (codonfm/evo2) that wires the flags, edit only docstring + wandb default, and note the dedup-into-sae follow-up. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
The skill assumed 'a checkpoint is in hand'. Add a prerequisites step covering how to get the model (NGC for BioNeMo/Evo2, HF for esm2/codonfm), convert Megatron checkpoints to the MBridge directory that predict_evo2 requires (savanna/nemo2 -> mbridge, incl. the weights_only=False gotcha; HF models skip it), and pull/verify the data corpus (HF dir-name + file-count checks, decompress, chunk to context). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
…nly) (1) Add an 'Assumptions about the model' section — the skill assumes the model is already integrated into bionemo (predict_<model> CLI or HF AutoModel under recipes/<model>_*/), hidden states reachable at a layer, checkpoint loadable (MBridge for Megatron), a known token<->position mapping, float (not bf16) activations, chunkable sequence inputs. If any is false, that's upstream work. (2) Fix the train.py source: on main ONLY evo2/scripts/train.py wires the four opt-in flags (codonfm & esm2 are 0/4). 'Copy from codonfm/evo2' would have handed out the flag-less losing config. Point all three references to evo2's specifically, and reinforce that this drift is why the train-CLI should live in shared sae. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
The static-consistency check flagged sae.eval.probing as not yet on main (it's in the eval recipe PR). Note in the eval section that reconstruction/dead_latents/ loss_recovered are already in sae.eval but probing is the newest module and lands with that PR, so a reader on current main knows it's the dependency. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
The training-config section + launch command presented Evo2-7B/L26's exact hyperparameters (expansion 16, top-k 128, auxk 2048, mix-shards 10, presample 8) as if canonical, making the skill read Evo2-specific. Reframe: the *flags* are general (turn them on for any model), but the *values* are Evo2-7B examples to re-tune per model -- width/top-k/auxk scale with hidden_dim, and mix/presample only matter for a corpus-ordered cache (set to 1 if shards are pre-shuffled). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
…en't mandatory (1) The ~10x/~17x micro-batch figure was an unverified number inherited from an earlier extraction note (not re-measured); mark it as a hypothesis to measure, not a fact to quote. (2) The opt-in training flags fix specific Evo2 failure modes but are NOT universally required -- CodonFM trained a good SAE with none of them (0/4). Reframe from 'turn on for any model' to 'turn on if you hit the problem it fixes'; they're only non-negotiable for reproducing the Evo2 winner. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Polina Binder <pbinder@nvidia.com>
| 2. **You can pull hidden states at a chosen layer** — `[B, S, H]` via `--embedding-layer` (predict CLI) or `output_hidden_states` (HF). | ||
| 3. **The checkpoint loads in the available env** — an **MBridge directory** for Megatron models (convert in Step 0); `.safetensors`/HF otherwise. The Megatron path needs the NVIDIA pytorch container + TransformerEngine. | ||
| 4. **Known token↔position mapping** — to label/probe, each activation row must map back to a sequence position. Evo2 byte tokenizer = 1 char/token; CodonFM = 1 codon (3-mer)/token; ESM2 = 1 aa/token. Get this wrong and your labels are misaligned. | ||
| 5. **Activations are float (fp16/fp32), not bf16** — Arrow/NumPy can't store bf16; cast before `ActivationStore.append`. |
There was a problem hiding this comment.
Based on my work on the multimodal recipe, would be helpful to develop a separate skill for layer selection to guide which activations are most relevant --ideas we discussed including:
- designing linear probe + rank analysis for layers to determine the most informative layers
- doing pca style analysis on each layer (rank decomposition)
| - **The flags are *available*, not mandatory.** All opt-in in `sae` (defaults = older behavior). Each fixes a specific failure mode we hit on **Evo2** — severe dead latents (`--normalize-input`, `--aggregate-loss`), a corpus/kingdom-ordered cache (`--mix-shards`, `--presample-shards`), and DDP dead-counting (`--dead-count-global`). They mattered a lot *there*. **They are not universally required: CodonFM trained a good SAE with none of them** (its `train.py` wires 0/4). So turn each on only if you actually hit the problem it fixes — don't cargo-cult them. *(The one place they're non-negotiable: **reproducing the Evo2 winner** — which is why copying an older, flag-less `train.py` into an Evo2 recipe silently gives the losing config.)* | ||
| - **The values are model-specific.** The numbers in the launch command (`--expansion-factor 16`, `--top-k 128`, `--auxk 2048`, `--mix-shards 10`, `--presample-shards 8`) are what reproduced the best **Evo2-7B / layer-26** SAE (~21% dead, ~0.10 FVU). **Re-tune per model:** expansion/top-k/auxk scale with `hidden_dim` and the sparsity you want; `--mix-shards`/`--presample-shards` only matter for a **corpus-ordered** cache — set both to `1` if your shards are already shuffled. | ||
|
|
||
| | flag | why it matters | |
There was a problem hiding this comment.
based on my debugging for the multimodal sae recipe, we likely want to have the knowledge here tailored how you should change your normalization strategy for the two approaches
| ## Verifying the recipe works (fastest → most confident) | ||
|
|
||
| 1. **Mechanical** — pipeline runs end-to-end, `checkpoint_final.pt` exists. Smoke on 20–100 sequences (minutes). | ||
| 2. **Numerical** — `train.py` log shows loss ↓, FVU < 1, dead-% trending toward ~20% (not stuck at ~80%). If dead-% is stuck high, check normalize-input / presample / the LR horizon (gotcha 1). |
There was a problem hiding this comment.
you also want to start debugging if the variance explained is suspiciously high (i.e., > 95%) or some threshold as usually this indicates so sort of error with the norms of the inputs, normalization approahc, dead latents, etc.
|
@coderabbitai review |
✅ Action performedReview finished.
|
Revamps the
bionemo-sae-recipeskill to match what we actually built + learned:predict → .pt → pt_to_parquet shimflow with the streaming pattern (wrappredict_<model>, monkeypatch its per-batch writer to append tosae.ActivationStore, no.pt). Drops the now-obsolete.ptdisk-pressure and shim-throughput caveats.torchruninvocation with the known-good config (normalize-input, aggregate-loss, dead-count-global, mix-shards 10, presample-shards 8, full-epoch cosine, dp-size = #GPUs) + the wandb env.--dead-count-globalis a DDP-only no-op and must be passed not just named; shard-0 pre-bias is corpus-biased (use--presample-shards N>1);unset WANDB_API_KEY+ setWANDB_ENTITY.sae.eval/probing(dead% / FVU / loss-recovered + per-feature AUROC).Still general (esm2 / codonfm / evo2 templates); Evo2 specifics remain as caveats.
Summary by CodeRabbit
Release Notes