[PyTorch Debug] Add scale_inv_std stat and skip NVFP4 layers in LogFp8TensorStats#3044
[PyTorch Debug] Add scale_inv_std stat and skip NVFP4 layers in LogFp8TensorStats#3044pggPL wants to merge 3 commits into
Conversation
…8TensorStats (NVIDIA#2801) - Register scale_inv_std (plus helper variance/numel/sum buffers using Welford reduction) for all FP8 recipes and NVFP4 in add_scale_inv_stats. Population variance keeps std=0 for delayed/current scaling where scale_inv is a single scalar. - Also wire scale_inv_min/max/std for NVFP4 (was previously only FP8 recipes). - LogFp8TensorStats.inspect_tensor now filters bare stats on NVFP4 layers with a warning instead of raising, so dual LogFp8TensorStats + LogNvfp4TensorStats configs work with overlapping (or catch-all) layer regexes. Recipe-prefixed FP8 stats (e.g. mxfp8_mse) are preserved for what-if comparisons. - Numerics test extended to validate scale_inv_min/max/std against torch.std(scale_inv, unbiased=False). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
Greptile SummaryThis PR adds a
Confidence Score: 5/5Safe to merge — the Welford variance formula is now correct, NVFP4 filtering is well-guarded, and the new stat is consistently registered across all recipes. The most impactful change — correcting the Welford combination formula from subtraction to addition — was the right fix and is verified by the updated numeric test. All new stat registrations follow the established helper-buffer pattern, and the NVFP4 path degrades gracefully with a warning rather than a crash. No files require special attention; all changes are localized to the debug stats subsystem and do not touch training paths. Important Files Changed
Reviews (2): Last reviewed commit: "fix sign in parallel-axis Welford varian..." | Re-trigger Greptile |
| # NVFP4-resolved stats are filtered out before this point in inspect_tensor(). | ||
| assert recipe_from_stat != "nvfp4" |
There was a problem hiding this comment.
Replacing a user-facing
raise ValueError with a bare assert weakens the defensive guard. Python's -O flag silently disables all assert statements, so if this path is ever reached in an optimised build, execution would continue silently and produce a confusing failure deep in the quantization path instead of a clear error message.
| # NVFP4-resolved stats are filtered out before this point in inspect_tensor(). | |
| assert recipe_from_stat != "nvfp4" | |
| # NVFP4-resolved stats are filtered out before this point in inspect_tensor(). | |
| if recipe_from_stat == "nvfp4": | |
| raise ValueError( | |
| f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in " | |
| "LogFp8TensorStats. This is an internal error: bare NVFP4 stats should " | |
| "have been filtered in inspect_tensor() before reaching this point." | |
| ) |
Parallel-group variance is Sigma n_i*(var_i + (mean_i - mean)^2) / N - the between-group term must be added, not subtracted. Single-group buffers hide the bug (mean_i = mean_global so the term is 0); it surfaces with scale_inv_std reduced across microbatches/ranks, where negative variance flows into sqrt() and yields NaN. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Description