Tool: evaluate layer-wise numerical-error propagation #525
jlamypoirier wants to merge 31 commits into
A new `tools/evaluate_precision.py` (`RunnableConfig`) drives a fp32 reference run plus one one-iteration trainer run per named variant from a Fast-LLM training YAML, then extracts per-layer forward activations and input gradients from the saved tensor logs and reports per-tensor RMS and max diffs (absolute and scaled). Variants are flat dicts of dotted-path overrides, the same syntax as Fast-LLM CLI key=value args, so they can sweep arbitrary configuration knobs (dtype, attention implementation, optimizer dtype, etc.) — not just compute_dtype. Also moves `compare_tensor_logs.py` into the `fast_llm` package so it is importable from `tools/` (the test tree isn't on sys.path for script entry points), and factors a `_compute_diff` helper out of `CompareConfig.compare_tensors` so the tool can extract numbers for every tensor rather than only those that breach a tolerance. Existing test callers are unaffected. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
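To make the reported numbers concrete, here is a minimal sketch of per-tensor RMS and max diffs, both absolute and scaled by the reference RMS. It works on plain Python lists rather than tensors, and the function and field names (`compute_diff`, `rms_abs`, `rms_rel`, `max_abs`, `max_rel`, `ref_scale`) are illustrative assumptions, not necessarily what `_compute_diff` returns.

```python
import math


def compute_diff(reference, variant):
    """Compare two flattened tensors; the returned field names are illustrative."""
    if not reference or len(reference) != len(variant):
        raise ValueError("tensors must be non-empty and have the same length")
    n = len(reference)
    diffs = [v - r for r, v in zip(reference, variant)]
    # Reference magnitude, used to turn absolute errors into scaled ones.
    ref_scale = math.sqrt(sum(r * r for r in reference) / n)
    rms_abs = math.sqrt(sum(d * d for d in diffs) / n)
    max_abs = max(abs(d) for d in diffs)
    # An all-zero reference has no meaningful scale, so report NaN there.
    denom = ref_scale if ref_scale > 0 else math.nan
    return {
        "ref_scale": ref_scale,
        "rms_abs": rms_abs,
        "rms_rel": rms_abs / denom,
        "max_abs": max_abs,
        "max_rel": max_abs / denom,
    }


report = compute_diff([3.0, 4.0, 0.0, 0.0], [3.0, 4.5, 0.0, 0.0])
```

Scaling by the reference RMS rather than element-wise keeps near-zero reference entries from dominating the relative numbers.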
The tool now takes a single YAML containing `pretrained:` (the checkpoint that defines the model architecture + weights), `variants:`, `output_dir:` and a few optional knobs (`model_type`, `num_samples`, `micro_batch_size`, `sequence_length`). The training/optimizer/data sections of the underlying training config are hardcoded — they have no bearing on the propagation measurement (1 iteration, no checkpoint save, random tokens, dummy learning rate, optimization dtype forced to float32 alongside compute dtype). A variant can still override any of the hardcoded fields via the dotted-path mechanism if needed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool's input mirrors the trainer config's top-level shape: both `model:` (FastLLMModelConfig dict) and `pretrained:` are user-facing, and either or both may be set. Pretrained-from-HF is one config choice among many — a user can also specify the architecture inline, or load from HF and override individual fields. The forced fp32 dtypes and tool-required debug levels are now applied as overrides on top of whatever the user supplies, instead of being hardcoded into the model section. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The tool now inherits from `PretrainedGPTModelConfig` so `model` and `pretrained` are typed `FastLLMModelConfig` / `CheckpointLoadConfig` fields rather than loose dicts — validated, autocompleted, and introspectable like any other Fast-LLM config block. Per-variant trainer configs are built with `TrainerConfig.get_subclass(...) .from_dict(base, *updates)` instead of mutating a dict and round-tripping through YAML. Updates use tuple-keyed dotted paths so forced-fp32, variant overrides, and tool-required debug-logging overrides compose cleanly in the right precedence. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
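The precedence described here can be sketched with plain nested dicts. This is only an illustration of "later update wins" composition under tuple keys, not the actual `from_dict` implementation; `parse_dotted`, `apply_updates`, and every config path below are hypothetical.

```python
import copy


def parse_dotted(overrides):
    """Turn {"a.b.c": value} into {("a", "b", "c"): value}."""
    return {tuple(key.split(".")): value for key, value in overrides.items()}


def apply_updates(base, *updates):
    """Return a copy of `base` with each update applied in order; later ones win."""
    config = copy.deepcopy(base)
    for update in updates:
        for path, value in update.items():
            node = config
            for key in path[:-1]:
                node = node.setdefault(key, {})
            node[path[-1]] = value
    return config


# Hypothetical paths, chosen only to show the ordering.
base = {"model": {"distributed": {"compute_dtype": "bfloat16"}}}
forced_fp32 = {("model", "distributed", "compute_dtype"): "float32"}
variant = parse_dotted({"model.distributed.compute_dtype": "bfloat16", "debug.level": 0})
tool_overrides = {("debug", "level"): 1}

reference_config = apply_updates(base, forced_fp32, tool_overrides)
variant_config = apply_updates(base, forced_fp32, variant, tool_overrides)
```

In this sketch the variant can undo the forced fp32 dtype, while its attempt to lower the debug level is overridden by the tool's own update, which is applied last.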
6431307 to 4c444d8
`transformers.PretrainedConfig.to_dict()` serializes a growing set of generic defaults (generation knobs, family markers, encoder-decoder flags). The Fast-LLM allowlist covered only a subset, so loading any modern HF Llama checkpoint via `pretrained.format: llama` tripped the coverage walker on keys like `torchscript`, `is_decoder`, `is_llama_config`, `rope_interleaved`, and the full set of generation defaults. Fill in the missing entries, grouped by category. None of them are architecture knobs that Fast-LLM consumes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop step / shape / max_rel columns, shorten the tensor name to the description after the colon, reorder to Tensor / Kind / Relative / Absolute / Max / Scale, format Relative as percent and the rest with `.3g`. The JSON report keeps every field. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop the separate Kind column and append `(fw)` / `(bw)` to the shortened tensor name. Switch numeric formatting to fixed precision: Relative shows `.2f` percent, Absolute / Max / Scale show `.2e` scientific. Every column now lines up on a consistent digit count. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Scientific notation was overkill for values that mostly land between 0.01 and a few hundred. `.3f` is more readable while keeping the per-column digit count consistent. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Fast-LLM's `Run.__init__` picks the next free `runs/<n>` subdirectory based on what already exists, but `_artifact_path` reads `runs/0` unconditionally. Without this wipe, re-running the tool against the same `output_dir` reads stale artifacts from the first invocation and silently reports old numbers — even though the trainer correctly ran with the new config. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
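The failure mode is easy to reproduce with a toy model of the directory logic. The sketch below assumes "next free" means the lowest unused index and that the wipe removes the whole `runs/` tree; `next_run_dir` and `fresh_run_dir` are hypothetical helpers, not Fast-LLM functions.

```python
import shutil
import tempfile
from pathlib import Path


def next_run_dir(output_dir):
    """Create and return the lowest-numbered unused `runs/<n>` directory."""
    runs = Path(output_dir) / "runs"
    runs.mkdir(parents=True, exist_ok=True)
    n = 0
    while (runs / str(n)).exists():
        n += 1
    run_dir = runs / str(n)
    run_dir.mkdir()
    return run_dir


def fresh_run_dir(output_dir):
    """Wipe earlier runs first, so the new run always lands in `runs/0`."""
    shutil.rmtree(Path(output_dir) / "runs", ignore_errors=True)
    return next_run_dir(output_dir)


with tempfile.TemporaryDirectory() as tmp:
    first = next_run_dir(tmp).name        # first invocation writes runs/0
    second = next_run_dir(tmp).name       # a rerun lands elsewhere, runs/0 is now stale
    after_wipe = fresh_run_dir(tmp).name  # wiping makes runs/0 the current run again
```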
Add a `data_path` field to the tool. When set, the tool lazily generates a tokenized memmap dataset with random advantages and old_logprobs at the given path (via the test helper `tests/utils/dataset._get_test_dataset`) and uses it as the training input. Required for policy-gradient losses like GSPO/GRPO that consume those fields. Without it, the tool falls back to the random token generator as before. Console table now formats numeric columns with `.4g` so 1e-7-scale GSPO gradients aren't rounded to zero while normal CE-magnitude values still read as fixed-point numbers. Rename `download_santacoder_tokenizer` to `download_test_tokenizer` — it actually downloads the GPT-2 tokenizer. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After the per-tensor tables, emit a short summary block per variant showing first/last/max/median for forward and backward separately. Aggregates over the intermediate layers per metric column (max and median are computed per-column, so each row is a per-metric envelope of the intermediate band rather than the metrics of any single layer). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Single compact table with one row per variant and columns for fw/bw first/last/max/median Relative %. Max/median are over intermediate layers (excluding first/last) when there is at least one intermediate row. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Rename `max`/`median` columns to `mid max`/`mid med` and add a header note (`mid = excluding first/last`) so it's clear the aggregation excludes the boundary layers. Also fix a column-collision bug where labels at exactly the cell width touched without separator. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each variant now occupies two rows in the summary (fw on the first, bw on the second), with the metric columns shared. Reads more naturally and keeps the table half as wide. Percent precision goes from .2f to .3f so single-digit-percent differences between variants are visible. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Top header line groups columns under `fw` / `bw`; the second line lists the per-pass aggregations. Aggregations are ordered chronologically along the pass — first → mid med → mid max → last — so reading left to right traces the propagation. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
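As a rough sketch of the per-pass aggregation, assuming the values are the Relative % entries for one pass listed in the order the pass visits the layers (`summarize_pass` is a hypothetical name, and the fallback for fewer than three rows is an assumption):

```python
from statistics import median


def summarize_pass(values):
    """Aggregate per-layer relative errors, listed in the order the pass visits layers."""
    if not values:
        raise ValueError("need at least one logged tensor")
    # The intermediate band excludes the boundary layers; with fewer than
    # three rows there is no band, so fall back to all rows (an assumption).
    mid = values[1:-1] or values
    return {
        "first": values[0],
        "mid med": median(mid),
        "mid max": max(mid),
        "last": values[-1],
    }


forward = summarize_pass([0.1, 0.5, 0.3, 0.9, 0.2])
```

The dict keys are inserted in the chronological column order, so iterating over the result reproduces the left-to-right layout of the table.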
Adds an `fp32_lm_head` field on `LanguageModelHeadConfig`. When `True`, the LM head linear's input and weight are upcast to FP32 before the matmul, matching vLLM's `bf16_last_layer_fp32` quantization. This lets the trainer compute log-probabilities at the same numerical precision as the actor's sampling, so the importance-sampling ratio starts near 1.0 instead of being artificially inflated by a trainer/actor precision mismatch. The detached FP32 weight has `requires_grad=False`, which makes `output_parallel_linear_backward` skip the weight-grad path. The FSDP gradient contract is restored by computing `grad_weight = grad.t() @ saved_input` explicitly and accumulating into the original BF16 param's `grad_buffer` via `accumulate_gradient`. Off by default — disabled path is byte-identical to before. Cherry-picked from #526 to unblock the precision-evaluation tool's GSPO smoke test, which compares fp32_lm_head=true vs false. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
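The explicit weight gradient is the standard identity for a linear layer `logits = hidden @ weight.T`. The sketch below checks it in plain Python on tiny matrices; it shows only the shape bookkeeping, not the dtype casts, `accumulate_gradient`, or anything FSDP-related, and all helper names are hypothetical.

```python
def transpose(matrix):
    return [list(column) for column in zip(*matrix)]


def matmul(a, b):
    b_columns = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in b_columns] for row in a]


def head_forward(hidden, weight):
    """logits = hidden @ weight.T with hidden: (tokens, hidden), weight: (vocab, hidden)."""
    return matmul(hidden, transpose(weight))


def head_weight_grad(grad_logits, saved_input):
    """grad_weight = grad_logits.T @ saved_input, same shape as `weight`."""
    return matmul(transpose(grad_logits), saved_input)


hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]       # 3 tokens, hidden size 2
weight = [[0.5, -1.0], [2.0, 0.25]]                 # vocab size 2, hidden size 2
grad_logits = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # dL/dlogits, 3 tokens x vocab
grad_weight = head_weight_grad(grad_logits, hidden)
```

Because the result has the same shape as `weight`, it can be accumulated into the original parameter's gradient buffer regardless of the dtype used for the matmul itself.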
Instead of generic `first` / `last` headers in the summary, use the actual layer name pulled from the matching tensor's `Global <layer> <kind>:` prefix. For the SmolLM2 smoke run that surfaces as `embeddings` / `head` on fw and `head` / `decoder.0` on bw — directly showing which layer the boundary values come from rather than making the reader guess. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…den_states Previously the only way to get a non-layer-output tensor (e.g. the LM head's logits) into `tensor_logs` was to crank `model_debug_level`, which logs every single `_debug`-emitted tensor (~700 per step for a 30-layer model). Add a `MultiStageConfig.debug_hidden_states_log: list[str]` field — regex patterns that get appended to each model input's `output_hidden_states` set. Matching tensors are still populated into `kwargs[hidden_states]` (existing contract for the HF inference wrapper); now they're also written to `tensor_logs` so the precision tool can compare them across variants. `_debug` already had the `output_hidden_state`-matched branch but only used it to populate `kwargs[hidden_states]`. Extending it to also call `log_distributed_tensor` at a fixed verbosity (13, matching the test convention so samples are recorded) is a small gating change. Plumbed through `GPTModel.get_preprocessing_config` → `LanguageModelBatchPreprocessingConfig.output_hidden_states` → `LanguageModelBatch.get_model_inputs`, which compiles the patterns and unions them into each `LanguageModelInput.output_hidden_states`. The precision tool now sets `[r"head\.logits"]` and surfaces logits as a dedicated `logits` column on the fw side of the summary table. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
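A small sketch of how regex patterns can gate which tensors are written out. It assumes patterns are matched at the start of the name (`re.match`), which is a guess at the semantics; apart from `head.logits` and `head.logits.grad`, the tensor names are made up.

```python
import re


def compile_patterns(patterns):
    return [re.compile(pattern) for pattern in patterns]


def select_logged(tensor_names, patterns):
    """Keep names matched at their start by any pattern (prefix match is an assumption)."""
    compiled = compile_patterns(patterns)
    return [name for name in tensor_names if any(p.match(name) for p in compiled)]


names = ["embeddings.output", "decoder.0.mlp", "head.logits", "head.logits.grad"]
selected = select_logged(names, [r"head\.logits"])
```

Under prefix matching, the single pattern `r"head\.logits"` selects both the forward logits and a gradient tensor whose name extends it.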
The head's `logits` tensor has `requires_grad=False` (output of a custom-autograd Function), so the existing `_debug(logits, ...)` could only capture the forward value. Add a second `_debug(grad, "logits.grad", ...)` call right after the loss returns the explicit `dL/d_logits` so the gradient is captured at the same fidelity. With the precision tool's `output_hidden_states` pattern `r"head\.logits"`, both `head.logits` and `head.logits.grad` end up in tensor_logs. Tool summary surfaces both via dedicated `logits` columns — placed at end-of-fw and start-of-bw chronologically. For GSPO the bw-logits column reveals that the dL/dlogits computation itself is extremely precise (~0.001% relative error), and the apparent backward noise actually enters through the head matmul further downstream. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…alues `.3f%` was rounding the bw-logits values down to 0.001%-0.000%, hiding real signal. Switch to `.4g%` so values across 5 orders of magnitude (0.0001% to ~20%) all render with meaningful precision; large values keep 4 significant figures, tiny ones spell out their leading non-zero digits or fall back to scientific. Bw column order is now first / logits / mid med / mid max / last so `logits` sits right after `head` (the first bw row) — semantically the gradient at logits is what the head's backward consumes before producing the gradient at its input. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Keep the prior `.3f%` default in the summary so most columns still show `0.000%` / `12.672%` style values, but compute a per-column decimal count based on the smallest non-zero value in that column — bumping up just enough that every cell carries at least two significant figures. Decimal count is uniform within a column. For the GSPO run, only the bw-logits column hits the threshold and gets bumped from 3 to 5 decimals, surfacing values like `0.00095%` that previously rounded to `0.001%` or worse. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
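One way to derive such a per-column decimal count, as a sketch: take the decimal exponent of the smallest non-zero magnitude and add enough places to show two significant figures, never going below the default of 3. The helper names are hypothetical and the real implementation may round differently.

```python
import math


def column_decimals(values, default=3, sig_figs=2):
    """Decimals needed so the smallest non-zero value shows `sig_figs` figures."""
    nonzero = [abs(v) for v in values if v != 0]
    if not nonzero:
        return default
    exponent = math.floor(math.log10(min(nonzero)))
    return max(default, sig_figs - 1 - exponent)


def format_column(values, default=3, sig_figs=2):
    """Format every cell of a column with one uniform decimal count."""
    decimals = column_decimals(values, default, sig_figs)
    return [f"{v:.{decimals}f}%" for v in values]


wide = format_column([0.0, 12.672])
tiny = format_column([0.00095, 0.0123, 0.0])
```

A column whose values are all around 1% or larger keeps the default three decimals, while a column containing `0.00095` is widened to five.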
Cell width drops from `max_label + 1` to `max_label`, inter-cell sep from two spaces to one, group sep from four spaces to three. About 18 chars narrower on the GSPO smoke run with no loss of alignment or readability. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lets `pretrained.path: org/model-id` resolve via huggingface_hub.snapshot_download when not a local directory, matching transformers' from_pretrained behavior. Local paths pass through unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two ready-to-run configs for tools/evaluate_precision: smol.yaml sweeps precision-stability features (full_precision_gradients, full_precision_residual, fp32_lm_head) on SmolLM2-135M; smol_gspo.yaml repeats the sweep with the GSPO policy-gradient loss enabled. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A single forward+backward pass with micro_batch_size=1 has no gradient accumulation, so toggling full_precision_gradients produces bit-identical results to the bf16 baseline. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sample precision-evaluation runs
Output of
Enables debug_all_param_gradients so every parameter's reduced gradient is captured in tensor_logs alongside the existing layer activations and input gradients. New rows are tagged with kind 'grad' and appear in the per-variant table but stay out of the fw/bw summary table. Also makes the per-variant table's Tensor column width fit the longest name (parameter gradients can be 40+ chars) and bumps the Relative column to adaptive precision (capped at 5 decimals) so legitimately tiny values stay legible. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Group rows in the per-variant tables by display group with blank lines between fw, bw, and grad. The reduce_gradients hook emits parameter gradients chronologically interleaved with the backward pass, which made the previous table hard to scan. Display grouping is independent of `kind` so the summary aggregation is unaffected — head.logits.grad just moves to the bw block visually. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each pass gets its own self-contained Variant x columns table with labels picked from the actual first/last logged tensor. Weight gradients get a head/mid med/mid max/embeddings layout mirroring the bw structure; the grad table makes large norm_1 outliers (>200% relative) immediately visible at a glance. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the chronological first/last columns in the grad table with named lookups (lm_head / embeddings) and split the intermediate aggregation by category: linear weights, norm weights, biases. The bias columns appear only when biases exist. lm_head shows n/a when the LM head weight is tied to the embedding (e.g. SmolLM2), since the combined gradient is recorded under the embedding parameter. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add `sample_level_overrides: dict[str, int]` (regex pattern -> level) to `TensorLogsConfig`. `log_tensor` raises the effective level for any tensor whose logged name matches a pattern, so callers can collect more samples for specific tensors without changing the default. Useful for sparsely-non-zero tensors like embedding-weight gradients, where the default uniform stride misses every non-zero row. evaluate_precision: switch `num_samples` to actually drive the level (was only cropping the text log), bump default to 8192, default sequence length to 2048 in the example yamls, and add a 1M-sample override for `Global gradient: embeddings.*` to make embedding-grad errors measurable on small batches. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
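The override lookup can be pictured as follows. This is a sketch that treats levels as opaque integers, assumes patterns are matched with `re.match`, and only ever raises the level; `effective_level` and the tensor names are hypothetical.

```python
import re


def effective_level(name, default_level, overrides):
    """Return the highest level among the default and all matching overrides."""
    level = default_level
    for pattern, override in overrides.items():
        if re.match(pattern, name):
            level = max(level, override)
    return level


overrides = {r"Global gradient: embeddings.*": 20}
embedding_level = effective_level("Global gradient: embeddings.word_embeddings_weight", 13, overrides)
other_level = effective_level("Global gradient: decoder.0.norm_1.weight", 13, overrides)
```

Using `max` means an override lower than the default has no effect, which matches the "raises the effective level" wording.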
Within-engine precision: chosen-token log π
Added per-position
Bug fix bundled in: Fast-LLM's
Smol (random labels)
Smol_GSPO (RL data)
Reference scale is ~11–12 nats, so 1% relative ≈ 0.1 nats absolute. Per-token bias sits at 0.005–0.014 nats.
Findings
1. Within-engine bf16 vs fp32 is small across the board. All RMS values < 1%, all bias values < 0.05% (≈0.005 nats), all correlations ≥0.9995, all slopes ≈1.000. There's no systematic distortion, just per-token decorrelated noise plus negligible mean shift. The intrinsic precision of a single bf16 engine compared to its fp32 equivalent is not on a scale that would cause RL collapse on its own.
2.
3.
4.
Caveats and future work
This is a single-engine, single-step, small-model measurement (SmolLM2-135M, 1 fwd+bwd). The literature reports vLLM-vs-trainer log-prob mismatches of 2–24 nats per sequence and RL training collapse without precision fixes; our largest within-engine bias is ~0.005 nats. The gap is real, but many things differ between the two settings, any combination of which could be responsible:
The most direct follow-up, and the one closest to what the literature actually measures, is a vLLM-vs-trainer chosen-logprob comparison at matched scale. That would isolate the cross-engine factor specifically; combining it with the within-engine measurements here would let us decompose the literature's 2–24 nat mismatch into engine-mismatch vs other factors. Out of scope for this PR.
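The bias, correlation, and slope figures quoted above can be read as a linear fit of the variant values against the reference values. The sketch below assumes an ordinary least-squares fit with the bias defined as the difference of means; the tool's exact definitions may differ, and `linear_decomposition` is a hypothetical name.

```python
import math


def linear_decomposition(reference, variant):
    """Split variant-vs-reference error into shift, scale, and leftover noise."""
    n = len(reference)
    if n < 2 or n != len(variant):
        raise ValueError("need two equally sized samples with at least two points")
    mean_r = sum(reference) / n
    mean_v = sum(variant) / n
    var_r = sum((r - mean_r) ** 2 for r in reference) / n
    var_v = sum((v - mean_v) ** 2 for v in variant) / n
    if var_r == 0 or var_v == 0:
        raise ValueError("constant inputs have no defined slope or correlation")
    cov = sum((r - mean_r) * (v - mean_v) for r, v in zip(reference, variant)) / n
    slope = cov / var_r
    intercept = mean_v - slope * mean_r
    residuals = [v - (slope * r + intercept) for r, v in zip(reference, variant)]
    return {
        "bias": mean_v - mean_r,                      # systematic shift
        "slope": slope,                               # systematic scale
        "correlation": cov / math.sqrt(var_r * var_v),
        "residual_rms": math.sqrt(sum(e * e for e in residuals) / n),  # per-position noise
    }


fit = linear_decomposition([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])
```

A slope near 1 with a small bias and a small residual says the variant tracks the reference up to uncorrelated noise, which is the pattern reported for bf16 here.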
…riants
- New `chosen_logprob` LM loss: logs `log_softmax(logits)[label]` per position with no gradient contribution. Tool auto-adds it and surfaces a dedicated summary with bias, correlation, slope, and residual-after-linear-fit.
- `_compute_diff` reports bias_abs/rel, correlation, slope, residual_rms_abs/rel; the linear decomposition separates systematic shift/scale from per-position noise.
- Per-variant auto-calibrated power-of-2 gradient scale: each variant runs a calibration pass at scale=1 to measure max unscaled gradient, then the real run picks the largest power-of-2 scale that fits within fp16 range (with a small safety factor for fused-kernel partial sums). `_compare` unscales per variant.
- Tool: backend-override mechanism (`_torch_backend.*`) and `_torch_matmul_precision` variant keys for diagnostic variants. New variants: `bf16_in_fp32_out` (probes whether `fp32_lm_head`'s gain is from output dtype vs matmul precision), `bf16_reduced_reduction` (probes the split-K reduction path), and a full fp16 sweep mirroring the bf16 variants.
- Fix: `data.micro_batch_size` in Fast-LLM is the per-sample sequence length, not the batch dim. Tool was passing 1 → every prior run was on 1-token inputs.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
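The power-of-2 scale selection can be sketched as below. The fp16 maximum of 65504 is a property of the format; the safety factor of 4 and the helper names are assumptions made for illustration only.

```python
import math

FP16_MAX = 65504.0  # largest finite float16 value


def pick_gradient_scale(max_unscaled_grad, safety_factor=4.0, dtype_max=FP16_MAX):
    """Largest power of two s with s * max_grad * safety_factor <= dtype_max."""
    if max_unscaled_grad <= 0:
        return 1.0  # nothing measured, leave gradients unscaled
    limit = dtype_max / (safety_factor * max_unscaled_grad)
    return 2.0 ** math.floor(math.log2(limit))


def unscale_gradients(gradients, scale):
    """Undo the scaling before comparing against the reference run."""
    return [g / scale for g in gradients]


small_grad_scale = pick_gradient_scale(1e-3)
large_grad_scale = pick_gradient_scale(100.0)
```

Restricting the scale to powers of two means scaling and unscaling only change the exponent, so in the absence of overflow or underflow they add no rounding error of their own.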
FP16 vs BF16 within-engine: ~8× precision improvement, no bias
Added an FP16 sweep (
Chosen-token log π: smol (random labels)
Chosen-token log π: smol_GSPO (RL data)
Gradients: smol_GSPO (RL data)
Findings
1. FP16 gives ~8× error reduction across the board. Per-token chosen_logprob: 0.473%→0.057% (smol) and 0.931%→0.113% (gspo), both 8.2-8.3× reduction, matching the 2^3 = 8× expected from the 7→10 mantissa-bit difference. Gradients show the same ~8× ratio across linear/norm/embedding weights. Bias collapses too: from ~0.05% in bf16 to ~0.003-0.005% in fp16.
2. FP16 + fp32_lm_head / fp32_residual add little on top. Unlike bf16 where
3. Correlation slope is essentially 1.0 for fp16. BF16 had slope deviations from 1.0 in the 0.0008-0.0011 range (very small systematic distortion); fp16's slope is 0.99979-1.00000, effectively unity. No structural distortion at all.
4. The remaining caveats from the prior bf16 analysis still stand. This is single-engine, single-step, SmolLM2-135M. The literature's reported RL collapse on bf16 still isn't visible at this scale, and we can't probe cross-engine alignment with this tool.
What this commit does settle: FP16 has the expected ~8× intrinsic precision over BF16 in our setting, which is consistent with the precision-based mechanism that the FP16 paper (Liu et al. 2510.26788) invokes. The natural next step to actually attribute the literature's RL-stability claim is still a vLLM-vs-trainer log-prob diff, which is out of scope for this PR.
Committed in
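The ~8× figure is what the mantissa widths alone predict: 10 explicit bits versus 7 gives a grid that is 2^3 = 8 times finer. The sketch below simulates rounding to a given number of explicit mantissa bits in plain Python (ignoring exponent range, subnormals, and overflow) and measures the RMS rounding error over values in [1, 2); the helper names are hypothetical.

```python
import math


def round_to_mantissa_bits(x, bits):
    """Round x to `bits` explicit mantissa bits, ties to even; exponent range ignored."""
    if x == 0:
        return 0.0
    mantissa, exponent = math.frexp(x)  # x = mantissa * 2**exponent, 0.5 <= |mantissa| < 1
    scale = 2.0 ** (bits + 1)           # implicit leading bit plus explicit bits
    return math.ldexp(round(mantissa * scale) / scale, exponent)


def unit_roundoff(bits):
    """Worst-case relative rounding error for `bits` explicit mantissa bits."""
    return 2.0 ** -(bits + 1)


def rms_rounding_error(values, bits):
    errors = [v - round_to_mantissa_bits(v, bits) for v in values]
    return math.sqrt(sum(e * e for e in errors) / len(errors))


values = [1.0 + i / 4099 for i in range(1, 4099)]  # evenly spread over (1, 2)
bf16_like = rms_rounding_error(values, 7)   # bf16 has 7 explicit mantissa bits
fp16_like = rms_rounding_error(values, 10)  # fp16 has 10 explicit mantissa bits
ratio = bf16_like / fp16_like
```

This covers only a single rounding step; the measured per-layer and per-gradient ratios also include how errors accumulate through the network, so agreement with 8× is an observation rather than a guarantee.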
Summary
- `tools/evaluate_precision.py`: inherits `PretrainedGPTModelConfig` (so `model:` and `pretrained:` are real typed Config fields) and adds `variants:`, `output_dir:`, and a few optional knobs. Runs an fp32 reference plus one trainer invocation per variant in-process; captures per-layer forward activations and input gradients via the standard tensor-logs pipeline; emits per-tensor RMS / max diffs as a console table + `precision_report.json`.
- Variants are flat dicts of dotted-path overrides (same syntax as Fast-LLM CLI `key=value` args) so a variant can sweep any config knob: attention implementation, optimizer dtype, fused vs unfused, etc.
- Per-variant trainer configs are built with `TrainerConfig.get_subclass(...).from_dict(base, fp32_dtypes, variant_updates, tool_overrides)`. Tuple-keyed updates compose in precedence order: forced fp32 → variant overrides (which can re-override fp32) → tool-required debug-logging overrides (which always win).
- Tool config fields: `model`, `pretrained`, `variants`, `output_dir`, `num_samples`, `micro_batch_size`, `sequence_length`.
- Moves `compare_tensor_logs.py` from `tests/utils/` into `fast_llm/engine/config_utils/` so it's importable from `tools/`, and factors a `_compute_diff` helper out of `CompareConfig.compare_tensors` so the tool can extract numbers for every tensor, not only those that breach a tolerance. Three test callers updated; behaviour unchanged.
- Extends the Fast-LLM allowlist (`fast_llm/engine/checkpoint/huggingface.py`) with the generic `PretrainedConfig` keys newer `transformers` versions serialize: generation defaults, encoder-decoder flags, family markers, `torchscript`, `is_decoder`, etc. Without this, loading any modern HF Llama checkpoint trips the coverage walker. None are architecture knobs Fast-LLM consumes.
Usage
Fast-LLM's HF loader reads weights from a local directory, so HF Hub IDs need to be `snapshot_download`'d first. `model:` and `pretrained:` can also be combined: `pretrained` provides architecture+weights, `model:` overrides individual fields.
Test plan
- … `huggingface_hub`). Reference fp32 + bf16 variant ran end-to-end; per-layer RMS/max table populated for all 30 decoder layers + embeddings + head, fw + bw; JSON artifact round-trips through `json.load`. Output shows propagated error growing with depth, with sharp jumps at layers where activation magnitude regime changes (e.g. `ref_scale` 6 → 777 around block 11, bf16 RMS rel 10% → 0.7% → back up to 13% at block 28).
- … `compare_tensor_logs.py` and the refactored `compare_tensors`.
🤖 Generated with Claude Code