Skip to content

[PyTorch Debug] Add scale_inv_std stat and skip NVFP4 layers in LogFp8TensorStats#3044

Open
pggPL wants to merge 3 commits into
NVIDIA:mainfrom
pggPL:debug_log_fp_fixes
Open

[PyTorch Debug] Add scale_inv_std stat and skip NVFP4 layers in LogFp8TensorStats#3044
pggPL wants to merge 3 commits into
NVIDIA:mainfrom
pggPL:debug_log_fp_fixes

Conversation

@pggPL
Copy link
Copy Markdown
Collaborator

@pggPL pggPL commented May 26, 2026

Description

Addresses items 1 and 2 from #2801:

1. Adds `scale_inv_std` for every recipe (`fp8_delayed_scaling`, `fp8_current_scaling`,
   `mxfp8`, `fp8_block_scaling`, `nvfp4`) plus `_columnwise` variants. Min/max alone
   can hide serious clipping when scale spread is wide. Reduction across
   microbatches/ranks uses Welford (population variance, so `std=0` instead of NaN
   for the single-scalar delayed/current-scaling case). NVFP4 also gains
   `scale_inv_min` / `scale_inv_max`.

2. `LogFp8TensorStats` no longer crashes on layers using `NVFP4Quantizer` — bare
   stats are filtered with a warning, FP8-recipe-prefixed stats (e.g. `mxfp8_mse`
   for what-if) are preserved. Combined with #2296 and #2652, the BioNeMo ESM2
   YAML workaround referenced in the issue can now be removed.

Additionally, this PR fixes a pre-existing sign error in `compute_variance`
(parallel-axis Welford combination): the formula subtracted `(mean_i - mean)^2`
instead of adding, yielding negative variance whenever >=2 buffer groups had
different means (multi-microbatch or multi-rank reductions). Single-group
buffers were unaffected (`mean_i - mean = 0`), which is why existing `variance`
/ `std` stats never tripped this in tests. With `scale_inv_std` now reduced
across microbatches/ranks the bug becomes user-visible (`sqrt(negative)` -> NaN).

Item 3 from #2801 (per-block dump for MXFP8/NVFP4) is out of scope here.

Fixes #2801 (partial)

## Type of change

- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactoring

## Changes

- `stats_computation.py`: `add_scale_inv_stats` registers `scale_inv_std` plus
  Welford helpers; NVFP4 added to the registration loop.
- `stats_computation.py`: fix sign in `compute_variance` — Welford parallel
  combination needs `+ (mean_i - mean)^2`, not `-`.
- `log_fp8_tensor_stats.py`: `inspect_tensor` filters NVFP4-resolved bare stats
  with a warning instead of raising.
- `log_nvfp4_tensor_stats.py`: docstring + `supported_stats` updated for
  `scale_inv_min` / `scale_inv_max` / `scale_inv_std`.
- `test_log.py`: `scale_inv_std` added to `bare_stats`; numerics test validates
  min/max/std against `torch.std(scale_inv, unbiased=False)`.

# Checklist:

- [x] I have read and followed the [contributing guidelines](https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst)
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes

pggPL and others added 2 commits May 26, 2026 15:08
…8TensorStats (NVIDIA#2801)

- Register scale_inv_std (plus helper variance/numel/sum buffers using
  Welford reduction) for all FP8 recipes and NVFP4 in
  add_scale_inv_stats. Population variance keeps std=0 for delayed/current
  scaling where scale_inv is a single scalar.
- Also wire scale_inv_min/max/std for NVFP4 (was previously only FP8
  recipes).
- LogFp8TensorStats.inspect_tensor now filters bare stats on NVFP4 layers
  with a warning instead of raising, so dual LogFp8TensorStats +
  LogNvfp4TensorStats configs work with overlapping (or catch-all) layer
  regexes. Recipe-prefixed FP8 stats (e.g. mxfp8_mse) are preserved for
  what-if comparisons.
- Numerics test extended to validate scale_inv_min/max/std against
  torch.std(scale_inv, unbiased=False).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Copy link
Copy Markdown
Collaborator Author

pggPL commented May 26, 2026

/te-ci pytorch

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 26, 2026

Greptile Summary

This PR adds a scale_inv_std statistic (population standard deviation of inverse scaling factors) to all quantization recipes including NVFP4, and prevents LogFp8TensorStats from crashing on NVFP4 layers by filtering unsupported bare stats with a warning while preserving FP8-recipe-prefixed what-if stats.

  • Welford formula corrected (-+ in compute_variance): fixes the parallel combination formula for both the new scale_inv_std stat and the pre-existing global variance/std stats — the old formula yielded incorrect (potentially negative) variance whenever per-rank/microbatch means differed.
  • NVFP4 filtering in LogFp8TensorStats.inspect_tensor: bare stats are now dropped with a warnings.warn for NVFP4 layers instead of raising an unhandled ValueError; FP8-recipe-prefixed stats (e.g. mxfp8_mse) are preserved for what-if comparisons.
  • scale_inv_std registered for all recipes: helper buffer stats (scale_inv_variance, scale_inv_numel, scale_inv_sum) are registered alongside min/max/std for Welford reduction; numerics test validates min/max/std against torch.std(..., unbiased=False).

Confidence Score: 5/5

Safe to merge — the Welford variance formula is now correct, NVFP4 filtering is well-guarded, and the new stat is consistently registered across all recipes.

The most impactful change — correcting the Welford combination formula from subtraction to addition — was the right fix and is verified by the updated numeric test. All new stat registrations follow the established helper-buffer pattern, and the NVFP4 path degrades gracefully with a warning rather than a crash.

No files require special attention; all changes are localized to the debug stats subsystem and do not touch training paths.

Important Files Changed

Filename Overview
transformer_engine/debug/features/utils/stats_computation.py Fixes long-standing Welford sign bug and adds scale_inv_std with correct helper-buffer dependencies; implementation is consistent with the existing std/variance pattern.
transformer_engine/debug/features/log_fp8_tensor_stats.py NVFP4 bare-stat filtering added to inspect_tensor with warning; scale_inv_std added to supported_stats list; assert guard replaces raise ValueError (noted in prior review thread).
transformer_engine/debug/features/log_nvfp4_tensor_stats.py scale_inv_min/max/std added to supported_stats and docstring; _columnwise suffix handling added to check_if_stat_is_supported.
tests/pytorch/debug/test_log.py scale_inv_std added to bare_stats; numerics assertions for min/max/std against torch.std(unbiased=False) added; single-rank scenario correctly validates Welford path.

Reviews (2): Last reviewed commit: "fix sign in parallel-axis Welford varian..." | Re-trigger Greptile

Comment on lines +212 to +213
# NVFP4-resolved stats are filtered out before this point in inspect_tensor().
assert recipe_from_stat != "nvfp4"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Replacing a user-facing raise ValueError with a bare assert weakens the defensive guard. Python's -O flag silently disables all assert statements, so if this path is ever reached in an optimised build, execution would continue silently and produce a confusing failure deep in the quantization path instead of a clear error message.

Suggested change
# NVFP4-resolved stats are filtered out before this point in inspect_tensor().
assert recipe_from_stat != "nvfp4"
# NVFP4-resolved stats are filtered out before this point in inspect_tensor().
if recipe_from_stat == "nvfp4":
raise ValueError(
f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in "
"LogFp8TensorStats. This is an internal error: bare NVFP4 stats should "
"have been filtered in inspect_tensor() before reaching this point."
)

Parallel-group variance is Sigma n_i*(var_i + (mean_i - mean)^2) / N -
the between-group term must be added, not subtracted. Single-group
buffers hide the bug (mean_i = mean_global so the term is 0); it
surfaces with scale_inv_std reduced across microbatches/ranks, where
negative variance flows into sqrt() and yields NaN.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant