diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 055210f93a..14a087da5a 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -61,6 +61,7 @@ "underflows%", "scale_inv_min", "scale_inv_max", + "scale_inv_std", "mse", ] @@ -248,6 +249,10 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): debug_api.step() dequantized_tensor = quantized_tensor.dequantize() + if hasattr(quantized_tensor, "_scale_inv"): + scale_inv_rowwise = quantized_tensor._scale_inv.float() + else: + scale_inv_rowwise = quantized_tensor._rowwise_scale_inv.float() output = read_log(log_dir) for line in output.splitlines(): @@ -267,6 +272,17 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100 ) assert overflows == pytest.approx(expected.cpu(), abs=1e-4) + # Rowwise scale_inv stats only; logger formats with {:.4f} so abs<1e-4. + if "scale_inv_min" in line and "_columnwise" not in line: + value = float(line.split("value=")[1]) + assert value == pytest.approx(scale_inv_rowwise.min().cpu().item(), abs=1e-4) + if "scale_inv_max" in line and "_columnwise" not in line: + value = float(line.split("value=")[1]) + assert value == pytest.approx(scale_inv_rowwise.max().cpu().item(), abs=1e-4) + if "scale_inv_std" in line and "_columnwise" not in line: + value = float(line.split("value=")[1]) + expected = torch.std(scale_inv_rowwise, unbiased=False).cpu().item() + assert value == pytest.approx(expected, abs=1e-4) LOG_HIGH_PRECISION_CONFIG = """ @@ -403,7 +419,8 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): with open( os.path.join( - temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log" + temp_dir, + "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log", ), "r", ) as f: diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index d26f9ef7f6..f453b2a36a 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -10,19 +10,26 @@ import torch import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats +from nvdlfw_inspect.debug_features.log_tensor_stats import ( + LogTensorStats as BaseLogTensorStats, +) from nvdlfw_inspect.registry import Registry, api_method import transformer_engine_torch as tex from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS -from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter +from transformer_engine.debug.features.utils import ( + get_reduction_params, + next_enabled_iter, +) from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Quantizer, Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, +) try: from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer @@ -33,7 +40,12 @@ NVFP4Quantizer = None -ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] +ALL_RECIPE_NAMES = [ + "fp8_delayed_scaling", + "fp8_current_scaling", + "mxfp8", + "fp8_block_scaling", +] def _get_recipe_name(quantizer: Optional[Quantizer]): @@ -57,7 +69,10 @@ def _get_new_quantizer(recipe_name, fp8_dtype): return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) if recipe_name == "fp8_current_scaling": return Float8CurrentScalingQuantizer( - fp8_dtype=fp8_dtype, device=torch.device("cuda"), rowwise=True, columnwise=True + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + rowwise=True, + columnwise=True, ) if recipe_name == "mxfp8": return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) @@ -119,10 +134,13 @@ class LogFp8TensorStats(BaseLogTensorStats): - overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling, - scale_inv_min - minimum of the inverse of the scaling factors, - scale_inv_max - maximum of the inverse of the scaling factors, + - scale_inv_std - population standard deviation of the inverse of the scaling factors; + useful for spotting clipping that min/max alone can miss (degenerate to 0 for + fp8_delayed_scaling / fp8_current_scaling since those use a single scalar scale). - mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements, When collecting stats for the weight tensor with FP8 model parameters enabled, - only "scale_inv_min" and "scale_inv_max" are available. + only "scale_inv_min", "scale_inv_max" and "scale_inv_std" are available. All other statistics require access to the high precision tensor. tensors/tensors_struct: List[str] @@ -191,15 +209,8 @@ def check_if_stat_is_supported( if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") - # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work) - # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer - if recipe_from_stat == "nvfp4": - raise ValueError( - f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats." - " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for" - " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.," - " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons." - ) + # NVFP4-resolved stats are filtered out before this point in inspect_tensor(). + assert recipe_from_stat != "nvfp4" if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: raise ValueError( @@ -216,7 +227,13 @@ def check_if_stat_is_supported( if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10: raise ValueError(f"Stat {stat} needs Blackwell or later GPU.") - supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"] + supported_stats = [ + "underflows%", + "scale_inv_min", + "scale_inv_max", + "scale_inv_std", + "mse", + ] if stat_without_recipe not in supported_stats: raise ValueError( f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}" @@ -252,9 +269,14 @@ def update_aux_dict( Needs to clean after usage, because it possibly change the usage of the quantized tensor. """ fp8_dtype = tex.DType.kFloat8E4M3 - if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: + if recipe_name in [ + "fp8_delayed_scaling", + "fp8_current_scaling", + "fp8_block_scaling", + ]: assert isinstance( - quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) + quantizer, + (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer), ) fp8_dtype = quantizer.dtype @@ -280,7 +302,8 @@ def update_aux_dict( finally: if isinstance(quantized_tensor, QuantizedTensor): quantized_tensor.update_usage( - rowwise_usage=old_rowwise_usage, columnwise_usage=old_columnwise_usage + rowwise_usage=old_rowwise_usage, + columnwise_usage=old_columnwise_usage, ) @api_method @@ -338,6 +361,27 @@ def inspect_tensor( recipe_name = _get_recipe_name(quantizer) + # If the layer uses NVFP4, drop bare stats (which would target the NVFP4 + # recipe that LogFp8TensorStats can't handle) but keep stats explicitly + # prefixed with an FP8 recipe (e.g. "mxfp8_mse") for what-if FP8 comparison. + if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer): + kept_stats, dropped_stats = [], [] + for stat in config["stats"]: + if any(r in stat for r in ALL_RECIPE_NAMES): + kept_stats.append(stat) + else: + dropped_stats.append(stat) + if dropped_stats: + warnings.warn( + f"[LogFp8TensorStats] Skipping stats {dropped_stats} for layer " + f"'{layer_name}', tensor '{tensor_name}': layer uses NVFP4. Use " + "LogNvfp4TensorStats for NVFP4 stats, or prefix stats with an FP8 " + "recipe name (e.g. 'mxfp8_mse') for what-if FP8 comparisons." + ) + if not kept_stats: + return + config = {**config, "stats": kept_stats} + for stat in config["stats"]: self.check_if_stat_is_supported( stat, recipe_name, high_precision_tensor_provided=tensor is not None diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 8a76f4edcf..848dfa8ab7 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -11,14 +11,21 @@ import torch import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats +from nvdlfw_inspect.debug_features.log_tensor_stats import ( + LogTensorStats as BaseLogTensorStats, +) from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer -from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter -from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage +from transformer_engine.debug.features.utils import ( + get_reduction_params, + next_enabled_iter, +) +from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import ( + NVFP4TensorStorage, +) @Registry.register_feature(namespace="transformer_engine") @@ -45,6 +52,10 @@ class LogNvfp4TensorStats(BaseLogTensorStats): List of statistics to collect. Available stats: - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data) - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements + - scale_inv_min - minimum of the inverse of the scaling factors + - scale_inv_max - maximum of the inverse of the scaling factors + - scale_inv_std - population standard deviation of the inverse of the scaling factors; + useful for spotting clipping that min/max alone can miss tensors/tensors_struct: List[str] list of tensors to log @@ -85,13 +96,18 @@ class LogNvfp4TensorStats(BaseLogTensorStats): def check_if_stat_is_supported(self, stat: str): """Returns True if stat is supported, raises ValueError otherwise.""" + bare = stat[: -len("_columnwise")] if stat.endswith("_columnwise") else stat supported_stats = [ "underflows%", "mse", + "scale_inv_min", + "scale_inv_max", + "scale_inv_std", ] - if stat not in supported_stats: + if bare not in supported_stats: raise ValueError( f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}" + " (any of these may take an optional '_columnwise' suffix)" ) return True diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index b0002ffee6..aabb7d6959 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -125,7 +125,7 @@ def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) means = sums / numels - var = torch.sum(numels * (variances - torch.pow((means - mean), 2))) / torch.sum(numels) + var = torch.sum(numels * (variances + torch.pow((means - mean), 2))) / torch.sum(numels) return var @@ -207,15 +207,24 @@ def _get(buffers, stat_name): } STATS = { - "min": (lambda x, aux_dict: torch.min(x), lambda buffers: min(_get(buffers, "min"))), - "max": (lambda x, aux_dict: torch.max(x), lambda buffers: max(_get(buffers, "max"))), - "sum": (lambda x, aux_dict: torch.sum(x), lambda buffers: sum(_get(buffers, "sum"))), + "min": ( + lambda x, aux_dict: torch.min(x), + lambda buffers: min(_get(buffers, "min")), + ), + "max": ( + lambda x, aux_dict: torch.max(x), + lambda buffers: max(_get(buffers, "max")), + ), + "sum": ( + lambda x, aux_dict: torch.sum(x), + lambda buffers: sum(_get(buffers, "sum")), + ), "mean": ( lambda x, aux_dict: torch.mean(x), lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")), ), "numel": ( - lambda x, aux_dict: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), + lambda x, aux_dict: (x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel()), lambda buffers: sum(_get(buffers, "numel")), ), "l1_norm": ( @@ -236,7 +245,10 @@ def _get(buffers, stat_name): _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") ), ), - "cur_amax": (lambda x, aux_dict: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))), + "cur_amax": ( + lambda x, aux_dict: x.abs().max(), + lambda buffers: max(_get(buffers, "cur_amax")), + ), "dynamic_range_top": ( lambda x, aux_dict: _compute_dynamic_range_top(x), lambda buffers: max(_get(buffers, "dynamic_range_top")), @@ -335,11 +347,13 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False): def add_scale_inv_stats(recipe_name: str, columnwise: bool = False): - """Register *both* scale-inv min and max stats for a given recipe. + """Register scale-inv min/max/std stats for a given recipe. - This replaces the earlier separate helpers and avoids duplicated boilerplate. + The std uses Welford's algorithm to combine partial variances across + microbatches/ranks, so helper buffers for variance/numel/sum are also + registered. Population variance (unbiased=False) is used so single-element + scale_inv tensors (delayed/current scaling) yield std=0 rather than NaN. """ - # Determine which attribute holds the scale-inverse tensor. def get_scale_inv(quantized_tensor, columnwise): if hasattr(quantized_tensor, "_scale_inv"): @@ -348,18 +362,27 @@ def get_scale_inv(quantized_tensor, columnwise): return getattr(quantized_tensor, "_columnwise_scale_inv") return getattr(quantized_tensor, "_rowwise_scale_inv") + def _prefix(): + return f"{recipe_name}{'_' if recipe_name != '' else ''}" + columnwise_suffix = "_columnwise" if columnwise else "" - # Prepare stat names. - stat_name_min = ( - f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_min{columnwise_suffix}" - ) - stat_name_max = ( - f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_max{columnwise_suffix}" - ) + stat_name_min = f"{_prefix()}scale_inv_min{columnwise_suffix}" + stat_name_max = f"{_prefix()}scale_inv_max{columnwise_suffix}" + stat_name_std = f"{_prefix()}scale_inv_std{columnwise_suffix}" + stat_name_var = f"{_prefix()}scale_inv_variance{columnwise_suffix}" + stat_name_numel = f"{_prefix()}scale_inv_numel{columnwise_suffix}" + stat_name_sum = f"{_prefix()}scale_inv_sum{columnwise_suffix}" # Assign indices in `stats_to_num` (order matters — keep insertion order deterministic). - stats_to_num[stat_name_min] = len(stats_to_num) - stats_to_num[stat_name_max] = len(stats_to_num) + for name in ( + stat_name_min, + stat_name_max, + stat_name_std, + stat_name_var, + stat_name_numel, + stat_name_sum, + ): + stats_to_num[name] = len(stats_to_num) # Capture the attribute name inside lambdas via default args to avoid late binding. STATS[stat_name_min] = ( @@ -370,9 +393,39 @@ def get_scale_inv(quantized_tensor, columnwise): lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(), lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)), ) + STATS[stat_name_var] = ( + lambda x, aux_dict, _col=columnwise: torch.var( + get_scale_inv(aux_dict[recipe_name], _col).float(), unbiased=False + ), + lambda buffers, _sv=stat_name_var, _sn=stat_name_numel, _ss=stat_name_sum: compute_variance( + _get(buffers, _sv), _get(buffers, _sn), _get(buffers, _ss) + ), + ) + STATS[stat_name_numel] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).numel(), + lambda buffers, _sn=stat_name_numel: sum(_get(buffers, _sn)), + ) + STATS[stat_name_sum] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col) + .float() + .sum(), + lambda buffers, _ss=stat_name_sum: sum(_get(buffers, _ss)), + ) + STATS[stat_name_std] = ( + lambda x, aux_dict, _col=columnwise: torch.std( + get_scale_inv(aux_dict[recipe_name], _col).float(), unbiased=False + ), + lambda buffers, _sv=stat_name_var, _sn=stat_name_numel, _ss=stat_name_sum: compute_std( + _get(buffers, _sv), _get(buffers, _sn), _get(buffers, _ss) + ), + ) DEPENDENCIES[stat_name_min] = {stat_name_min} DEPENDENCIES[stat_name_max] = {stat_name_max} + DEPENDENCIES[stat_name_numel] = {stat_name_numel} + DEPENDENCIES[stat_name_sum] = {stat_name_sum} + DEPENDENCIES[stat_name_var] = {stat_name_var, stat_name_numel, stat_name_sum} + DEPENDENCIES[stat_name_std] = {stat_name_var, stat_name_numel, stat_name_sum} def add_mse_stats(recipe_name: str, columnwise: bool = False): @@ -505,3 +558,5 @@ def add_nvfp4_underflows_stats(): # Register NVFP4 stats add_nvfp4_underflows_stats() add_mse_stats("nvfp4") # Reuse existing MSE function +for _columnwise in [True, False]: + add_scale_inv_stats("nvfp4", _columnwise)