From 02fd39cb6404a1f065a6f4652f90b365003a36e0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 26 May 2026 16:58:52 -0400 Subject: [PATCH 01/31] Add tool to evaluate layer-wise numerical-error propagation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A new `tools/evaluate_precision.py` (`RunnableConfig`) drives a fp32 reference run plus one one-iteration trainer run per named variant from a Fast-LLM training YAML, then extracts per-layer forward activations and input gradients from the saved tensor logs and reports per-tensor RMS and max diffs (absolute and scaled). Variants are flat dicts of dotted-path overrides, the same syntax as Fast-LLM CLI key=value args, so they can sweep arbitrary configuration knobs (dtype, attention implementation, optimizer dtype, etc.) — not just compute_dtype. Also moves `compare_tensor_logs.py` into the `fast_llm` package so it is importable from `tools/` (the test tree isn't on sys.path for script entry points), and factors a `_compute_diff` helper out of `CompareConfig.compare_tensors` so the tool can extract numbers for every tensor rather than only those that breach a tolerance. Existing test callers are unaffected. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../config_utils}/compare_tensor_logs.py | 55 +++-- tests/models/test_checkpoint.py | 2 +- tests/models/test_match_megatron.py | 2 +- tests/utils/distributed_configs.py | 2 +- tools/evaluate_precision.py | 201 ++++++++++++++++++ 5 files changed, 243 insertions(+), 19 deletions(-) rename {tests/utils => fast_llm/engine/config_utils}/compare_tensor_logs.py (79%) create mode 100644 tools/evaluate_precision.py diff --git a/tests/utils/compare_tensor_logs.py b/fast_llm/engine/config_utils/compare_tensor_logs.py similarity index 79% rename from tests/utils/compare_tensor_logs.py rename to fast_llm/engine/config_utils/compare_tensor_logs.py index f02d62c79..080510036 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/fast_llm/engine/config_utils/compare_tensor_logs.py @@ -87,6 +87,30 @@ def _compare_dict_keys(self, dict_ref, dict_test, errors, name): # Avoid set to preserve ordering. return [key for key in dict_test if key in dict_ref] + def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict | None: + # Returns per-tensor error metrics, or None on shape/sampling mismatch. + if tensor_ref["shape"] != tensor_test["shape"]: + return None + if tensor_ref["step"] != tensor_test["step"]: + return None + sub_config = self._get_sub_config(step_name, tensor_name) + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale + scale_unreg = (samples_ref**2).mean() ** 0.5 + rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 + rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 + max_diff = (samples_ref - samples_test).abs().max() + return { + "rms_abs": rms.item(), + "rms_rel": (rms / rms_scale).item(), + "max_abs": max_diff.item(), + "max_rel": (max_diff / rms_scale).item(), + "ref_scale": scale_unreg.item(), + "ref_scale_regularized": rms_scale.item(), + } + def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_name): sub_config = self._get_sub_config(step_name, tensor_name) if tensor_ref["shape"] != tensor_test["shape"]: @@ -108,34 +132,33 @@ def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_nam ) return - samples_ref = tensor_ref["samples"].flatten().float() - samples_test = tensor_test["samples"].flatten().float() - if sub_config.scale != 1.0: - samples_test = samples_test / sub_config.scale - scale_unreg = (samples_ref**2).mean() ** 0.5 - rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 - rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 - max_diff = (samples_ref - samples_test).abs().max() + metrics = self._compute_diff(tensor_ref, tensor_test, step_name, tensor_name) + rms_scale = metrics["ref_scale_regularized"] + scale_unreg = metrics["ref_scale"] tensor_errors = [] - if rms > sub_config.rms_abs_tolerance: - tensor_errors.append(f" * RMS diff absolute = {rms} > {sub_config.rms_abs_tolerance}") + if metrics["rms_abs"] > sub_config.rms_abs_tolerance: + tensor_errors.append(f" * RMS diff absolute = {metrics['rms_abs']} > {sub_config.rms_abs_tolerance}") - if rms / rms_scale > sub_config.rms_rel_tolerance: + if metrics["rms_rel"] > sub_config.rms_rel_tolerance: tensor_errors.append( - f" * RMS diff scaled = {rms / rms_scale} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * RMS diff scaled = {metrics['rms_rel']} > {sub_config.rms_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) - if max_diff > sub_config.max_abs_tolerance: - tensor_errors.append(f" * Max diff absolute = {max_diff} > {sub_config.max_abs_tolerance}") + if metrics["max_abs"] > sub_config.max_abs_tolerance: + tensor_errors.append(f" * Max diff absolute = {metrics['max_abs']} > {sub_config.max_abs_tolerance}") - if max_diff / rms_scale > sub_config.max_rel_tolerance: + if metrics["max_rel"] > sub_config.max_rel_tolerance: tensor_errors.append( - f" * Max diff scaled = {max_diff / rms_scale} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" + f" * Max diff scaled = {metrics['max_rel']} > {sub_config.max_rel_tolerance} (scale={rms_scale}, unregularized={scale_unreg})" ) if tensor_errors: + samples_ref = tensor_ref["samples"].flatten().float() + samples_test = tensor_test["samples"].flatten().float() + if sub_config.scale != 1.0: + samples_test = samples_test / sub_config.scale tensor_errors.extend( [ f" Test samples: " + "".join(f"{x:12.4e}" for x in samples_test[: self.show_samples].tolist()), diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 0b4dbafc1..f3febae4b 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -18,9 +18,9 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.utils import Assert, header -from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 03ebac757..3c95d0dea 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -18,9 +18,9 @@ from fast_llm.data.dataset.sampled import logger from fast_llm.data.document.language_model import LanguageModelDocument from fast_llm.data.preparation.tokenizer import TokenizerConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert -from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_common_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_NAME diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index f3bbbac8d..d08b023b9 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -4,7 +4,7 @@ import torch -from tests.utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig logger = logging.getLogger(__name__) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py new file mode 100644 index 000000000..782ff996d --- /dev/null +++ b/tools/evaluate_precision.py @@ -0,0 +1,201 @@ +import json +import logging +import pathlib +import typing + +import yaml + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.training.config import TrainerConfig + +# Populate the trainer dynamic-type registry. +import fast_llm.data.auto # noqa: F401 # isort:skip +import fast_llm.engine.checkpoint.convert # noqa: F401 # isort:skip +import fast_llm.models.auto # noqa: F401 # isort:skip + +logger = logging.getLogger(__name__) + + +# Tensor-log verbosity level. 13 gives 2**(13-3)=1024 sampled values per tensor, +# matching the convention in the existing layer-comparison tests. +_LOG_LEVEL = 13 +_REFERENCE_NAME = "reference" + + +@config_class() +class EvaluatePrecisionConfig(RunnableConfig): + training_config: pathlib.Path = Field( + desc="Path to a Fast-LLM training YAML serving as the fp32 reference configuration.", + hint=FieldHint.core, + ) + model_type: str = Field( + desc="Trainer dynamic-type name (e.g. 'gpt') used to dispatch to the right TrainerConfig subclass.", + hint=FieldHint.core, + ) + variants: dict[str, typing.Any] = Field( + desc="Named override bundles to evaluate against the fp32 reference." + " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.", + hint=FieldHint.core, + ) + output_dir: pathlib.Path = Field( + desc="Directory for per-run tensor-log artifacts and the final JSON report.", + hint=FieldHint.core, + ) + num_samples: int = Field( + default=1024, + desc="Number of sampled values stored per logged tensor.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + super()._validate() + assert self.training_config.is_file(), f"Training config not found: {self.training_config}" + assert _REFERENCE_NAME not in self.variants, f"'{_REFERENCE_NAME}' is reserved for the fp32 baseline." + for name, overrides in self.variants.items(): + assert isinstance(overrides, dict) and all( + isinstance(k, str) for k in overrides + ), f"Variant {name!r} must be a flat dict of dotted-path string keys." + + def run(self) -> None: + base_dict = yaml.safe_load(self.training_config.read_text()) + for field_name in ("compute_dtype", "optimization_dtype"): + current = _get_nested(base_dict, ("model", "distributed", field_name)) + if current is not None and DataType(current) is not DataType.float32: + logger.warning( + f"Base config sets model.distributed.{field_name}={current!r};" + f" overriding to float32 for the reference run." + ) + + runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} + runs.update(self.variants) + for name, variant_overrides in runs.items(): + self._run_one(name, variant_overrides) + + ref_artifacts = self._artifact_path(_REFERENCE_NAME) + results = {name: self._compare(ref_artifacts, self._artifact_path(name)) for name in self.variants} + + report_path = self.output_dir / "precision_report.json" + report_path.parent.mkdir(parents=True, exist_ok=True) + report_path.write_text(json.dumps(results, indent=2)) + logger.info(f"Wrote report to {report_path}") + + for name, rows in results.items(): + _print_table(name, rows) + + def _artifact_path(self, name: str) -> pathlib.Path: + return self.output_dir / name / "runs" / "0" / "artifacts" + + def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: + experiment_dir = (self.output_dir / name).resolve() + forced_fp32 = { + "model.distributed.compute_dtype": "float32", + "model.distributed.optimization_dtype": "float32", + } + tool_overrides = { + "training.train_iters": 1, + "training.checkpoint.interval": None, + "run.tensor_logs.save": True, + "run.tensor_logs.show": False, + "run.tensor_logs.max_elements": self.num_samples, + "run.experiment_dir": str(experiment_dir), + "model.multi_stage.debug_layer_outputs": _LOG_LEVEL, + "model.multi_stage.debug_layer_gradients": _LOG_LEVEL, + } + # Compose: forced fp32 first so a variant can override it (e.g. compute_dtype=bfloat16); + # tool overrides last so logging and single-iteration mode always win. + combined = {**forced_fp32, **variant_overrides, **tool_overrides} + cli_overrides = [f"{key}={yaml.safe_dump(value).strip()}" for key, value in combined.items()] + logger.info(f"=== Running {name!r} ===") + if variant_overrides: + logger.info(f"Variant overrides: {variant_overrides}") + TrainerConfig.parse_and_run([self.model_type, "-c", str(self.training_config), *cli_overrides]) + + def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: + compare_config = CompareConfig() + errors: list[str] = [] + ref_logs = compare_config._extract_tensor_logs(ref_path, errors) + test_logs = compare_config._extract_tensor_logs(test_path, errors) + for error in errors: + logger.warning(error) + rows: list[dict[str, typing.Any]] = [] + for step_name in sorted(ref_logs): + if step_name not in test_logs: + logger.warning(f"Step {step_name!r} missing from test logs") + continue + step_ref = ref_logs[step_name] + step_test = test_logs[step_name] + for tensor_name, ref in step_ref.items(): + if tensor_name not in step_test: + continue + metrics = compare_config._compute_diff(ref, step_test[tensor_name], step_name, tensor_name) + if metrics is None: + continue + rows.append( + { + "step": step_name, + "tensor_name": tensor_name, + "kind": _classify(tensor_name), + "shape": ref["shape"], + **metrics, + } + ) + return rows + + +def _classify(tensor_name: str) -> str: + # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" + # and " bw[, mb=…]"; log_distributed_tensor may prefix the name + # with "Global " and append a ": " suffix when reconstructing a + # tensor-parallel-global tensor. + for kind in ("fw", "bw"): + if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): + return kind + return "other" + + +def _get_nested(d: typing.Any, keys: tuple[str, ...]) -> typing.Any: + for k in keys: + if not isinstance(d, dict) or k not in d: + return None + d = d[k] + return d + + +def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: + print(f"\n=== Variant: {name} ===") + if not rows: + print("(no matching tensors)") + return + columns = [ + ("step", "step", 6), + ("kind", "kind", 6), + ("tensor_name", "tensor", 48), + ("shape", "shape", 22), + ("ref_scale", "ref_scale", 12), + ("rms_abs", "rms_abs", 12), + ("rms_rel", "rms_rel", 12), + ("max_abs", "max_abs", 12), + ("max_rel", "max_rel", 12), + ] + header = " ".join(f"{title:<{width}}" for _, title, width in columns) + print(header) + print("-" * len(header)) + for row in rows: + parts = [] + for key, _, width in columns: + value = row[key] + if isinstance(value, float): + cell = f"{value:.4e}" + elif isinstance(value, list): + cell = "x".join(str(x) for x in value) + else: + cell = str(value) + parts.append(f"{cell:<{width}}") + print(" ".join(parts)) + + +if __name__ == "__main__": + EvaluatePrecisionConfig.parse_and_run() From 4dd6c1498b4a39d733a0702c0ab381e64cfafd0a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 14:57:14 -0400 Subject: [PATCH 02/31] Collapse to a single config; require a checkpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tool now takes a single YAML containing `pretrained:` (the checkpoint that defines the model architecture + weights), `variants:`, `output_dir:` and a few optional knobs (`model_type`, `num_samples`, `micro_batch_size`, `sequence_length`). The training/optimizer/data sections of the underlying training config are hardcoded — they have no bearing on the propagation measurement (1 iteration, no checkpoint save, random tokens, dummy learning rate, optimization dtype forced to float32 alongside compute dtype). A variant can still override any of the hardcoded fields via the dotted-path mechanism if needed. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 115 ++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 44 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 782ff996d..016744468 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -7,7 +7,6 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig -from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.training.config import TrainerConfig @@ -27,12 +26,9 @@ @config_class() class EvaluatePrecisionConfig(RunnableConfig): - training_config: pathlib.Path = Field( - desc="Path to a Fast-LLM training YAML serving as the fp32 reference configuration.", - hint=FieldHint.core, - ) - model_type: str = Field( - desc="Trainer dynamic-type name (e.g. 'gpt') used to dispatch to the right TrainerConfig subclass.", + pretrained: dict[str, typing.Any] = Field( + desc="Fast-LLM `CheckpointLoadConfig` dict (e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`)." + " The model architecture and weights are loaded from this checkpoint.", hint=FieldHint.core, ) variants: dict[str, typing.Any] = Field( @@ -44,15 +40,29 @@ class EvaluatePrecisionConfig(RunnableConfig): desc="Directory for per-run tensor-log artifacts and the final JSON report.", hint=FieldHint.core, ) + model_type: str = Field( + default="gpt", + desc="Trainer dynamic-type name used to dispatch to the right `TrainerConfig` subclass.", + hint=FieldHint.optional, + ) num_samples: int = Field( default=1024, desc="Number of sampled values stored per logged tensor.", hint=FieldHint.feature, ) + micro_batch_size: int = Field( + default=1, + desc="Micro-batch size for the single forward+backward pass.", + hint=FieldHint.feature, + ) + sequence_length: int = Field( + default=2048, + desc="Sequence length (maximum document length) for the random input.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() - assert self.training_config.is_file(), f"Training config not found: {self.training_config}" assert _REFERENCE_NAME not in self.variants, f"'{_REFERENCE_NAME}' is reserved for the fp32 baseline." for name, overrides in self.variants.items(): assert isinstance(overrides, dict) and all( @@ -60,15 +70,7 @@ def _validate(self) -> None: ), f"Variant {name!r} must be a flat dict of dotted-path string keys." def run(self) -> None: - base_dict = yaml.safe_load(self.training_config.read_text()) - for field_name in ("compute_dtype", "optimization_dtype"): - current = _get_nested(base_dict, ("model", "distributed", field_name)) - if current is not None and DataType(current) is not DataType.float32: - logger.warning( - f"Base config sets model.distributed.{field_name}={current!r};" - f" overriding to float32 for the reference run." - ) - + self.output_dir.mkdir(parents=True, exist_ok=True) runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} runs.update(self.variants) for name, variant_overrides in runs.items(): @@ -78,7 +80,6 @@ def run(self) -> None: results = {name: self._compare(ref_artifacts, self._artifact_path(name)) for name in self.variants} report_path = self.output_dir / "precision_report.json" - report_path.parent.mkdir(parents=True, exist_ok=True) report_path.write_text(json.dumps(results, indent=2)) logger.info(f"Wrote report to {report_path}") @@ -89,29 +90,57 @@ def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: - experiment_dir = (self.output_dir / name).resolve() - forced_fp32 = { - "model.distributed.compute_dtype": "float32", - "model.distributed.optimization_dtype": "float32", - } - tool_overrides = { - "training.train_iters": 1, - "training.checkpoint.interval": None, - "run.tensor_logs.save": True, - "run.tensor_logs.show": False, - "run.tensor_logs.max_elements": self.num_samples, - "run.experiment_dir": str(experiment_dir), - "model.multi_stage.debug_layer_outputs": _LOG_LEVEL, - "model.multi_stage.debug_layer_gradients": _LOG_LEVEL, - } - # Compose: forced fp32 first so a variant can override it (e.g. compute_dtype=bfloat16); - # tool overrides last so logging and single-iteration mode always win. - combined = {**forced_fp32, **variant_overrides, **tool_overrides} - cli_overrides = [f"{key}={yaml.safe_dump(value).strip()}" for key, value in combined.items()] + config_dict = self._build_config_dict(name) + # Apply variant overrides on top of the forced-fp32 baseline so a variant can set + # `model.distributed.compute_dtype: bfloat16` (etc.) and have it win. + for dotted_key, value in variant_overrides.items(): + _set_nested(config_dict, dotted_key.split("."), value) + config_yaml = self.output_dir / f"{name}_config.yaml" + config_yaml.write_text(yaml.safe_dump(config_dict)) logger.info(f"=== Running {name!r} ===") if variant_overrides: logger.info(f"Variant overrides: {variant_overrides}") - TrainerConfig.parse_and_run([self.model_type, "-c", str(self.training_config), *cli_overrides]) + TrainerConfig.parse_and_run([self.model_type, "-c", str(config_yaml)]) + + def _build_config_dict(self, name: str) -> dict[str, typing.Any]: + return { + "pretrained": self.pretrained, + "training": { + "train_iters": 1, + "num_workers": 0, + "logs": {"interval": 1}, + }, + "optimizer": { + "learning_rate": { + "base": 0.0, + "decay_style": "constant", + "warmup_iterations": 0, + }, + }, + "data": { + "datasets": {"training": {"type": "random"}}, + "micro_batch_size": self.micro_batch_size, + "maximum_document_length": self.sequence_length, + }, + "run": { + "experiment_dir": str((self.output_dir / name).resolve()), + "tensor_logs": { + "save": True, + "show": False, + "max_elements": self.num_samples, + }, + }, + "model": { + "distributed": { + "compute_dtype": "float32", + "optimization_dtype": "float32", + }, + "multi_stage": { + "debug_layer_outputs": _LOG_LEVEL, + "debug_layer_gradients": _LOG_LEVEL, + }, + }, + } def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: compare_config = CompareConfig() @@ -156,12 +185,10 @@ def _classify(tensor_name: str) -> str: return "other" -def _get_nested(d: typing.Any, keys: tuple[str, ...]) -> typing.Any: - for k in keys: - if not isinstance(d, dict) or k not in d: - return None - d = d[k] - return d +def _set_nested(d: dict[str, typing.Any], keys: list[str], value: typing.Any) -> None: + for key in keys[:-1]: + d = d.setdefault(key, {}) + d[keys[-1]] = value def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: From 5ebea3374483330ccb8507b1fbebd5515a91c7c3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 15:07:47 -0400 Subject: [PATCH 03/31] Expose `model:` alongside `pretrained:` in the tool config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tool's input mirrors the trainer config's top-level shape: both `model:` (FastLLMModelConfig dict) and `pretrained:` are user-facing, and either or both may be set. Pretrained-from-HF is one config choice among many — a user can also specify the architecture inline, or load from HF and override individual fields. The forced fp32 dtypes and tool-required debug levels are now applied as overrides on top of whatever the user supplies, instead of being hardcoded into the model section. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 016744468..09925f263 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -26,11 +26,19 @@ @config_class() class EvaluatePrecisionConfig(RunnableConfig): - pretrained: dict[str, typing.Any] = Field( - desc="Fast-LLM `CheckpointLoadConfig` dict (e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`)." - " The model architecture and weights are loaded from this checkpoint.", + model: dict[str, typing.Any] = Field( + default_factory=dict, + desc="`FastLLMModelConfig` dict (`base_model`, `distributed`, `multi_stage`)." + " Forwarded into the trainer config as-is alongside `pretrained`. Either or both" + " can be set: `pretrained` to load architecture/weights from a checkpoint," + " `model` to specify the architecture inline or override pretrained fields.", hint=FieldHint.core, ) + pretrained: dict[str, typing.Any] = Field( + default_factory=dict, + desc="`CheckpointLoadConfig` dict, e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`.", + hint=FieldHint.optional, + ) variants: dict[str, typing.Any] = Field( desc="Named override bundles to evaluate against the fp32 reference." " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.", @@ -91,10 +99,14 @@ def _artifact_path(self, name: str) -> pathlib.Path: def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: config_dict = self._build_config_dict(name) - # Apply variant overrides on top of the forced-fp32 baseline so a variant can set - # `model.distributed.compute_dtype: bfloat16` (etc.) and have it win. + # Force fp32 on the reference baseline (variants apply on top and can re-override). + _set_nested(config_dict, ["model", "distributed", "compute_dtype"], "float32") + _set_nested(config_dict, ["model", "distributed", "optimization_dtype"], "float32") for dotted_key, value in variant_overrides.items(): _set_nested(config_dict, dotted_key.split("."), value) + # Tool-required overrides always win — variants must not silently disable tensor logging. + _set_nested(config_dict, ["model", "multi_stage", "debug_layer_outputs"], _LOG_LEVEL) + _set_nested(config_dict, ["model", "multi_stage", "debug_layer_gradients"], _LOG_LEVEL) config_yaml = self.output_dir / f"{name}_config.yaml" config_yaml.write_text(yaml.safe_dump(config_dict)) logger.info(f"=== Running {name!r} ===") @@ -105,6 +117,7 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: def _build_config_dict(self, name: str) -> dict[str, typing.Any]: return { "pretrained": self.pretrained, + "model": self.model, "training": { "train_iters": 1, "num_workers": 0, @@ -130,16 +143,6 @@ def _build_config_dict(self, name: str) -> dict[str, typing.Any]: "max_elements": self.num_samples, }, }, - "model": { - "distributed": { - "compute_dtype": "float32", - "optimization_dtype": "float32", - }, - "multi_stage": { - "debug_layer_outputs": _LOG_LEVEL, - "debug_layer_gradients": _LOG_LEVEL, - }, - }, } def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: From 4c444d81f52955dad0c07c2b91adbfc7e38ac6aa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 15:15:59 -0400 Subject: [PATCH 04/31] Inherit PretrainedGPTModelConfig; use Config update mechanism MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The tool now inherits from `PretrainedGPTModelConfig` so `model` and `pretrained` are typed `FastLLMModelConfig` / `CheckpointLoadConfig` fields rather than loose dicts — validated, autocompleted, and introspectable like any other Fast-LLM config block. Per-variant trainer configs are built with `TrainerConfig.get_subclass(...) .from_dict(base, *updates)` instead of mutating a dict and round-tripping through YAML. Updates use tuple-keyed dotted paths so forced-fp32, variant overrides, and tool-required debug-logging overrides compose cleanly in the right precedence. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 93 ++++++++++++++----------------------- 1 file changed, 36 insertions(+), 57 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 09925f263..4c56848b6 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -3,12 +3,11 @@ import pathlib import typing -import yaml - from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.config_utils.compare_tensor_logs import CompareConfig from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.training.config import TrainerConfig +from fast_llm.models.gpt.config import PretrainedGPTModelConfig # Populate the trainer dynamic-type registry. import fast_llm.data.auto # noqa: F401 # isort:skip @@ -22,23 +21,20 @@ # matching the convention in the existing layer-comparison tests. _LOG_LEVEL = 13 _REFERENCE_NAME = "reference" +_MODEL_TYPE = "gpt" @config_class() -class EvaluatePrecisionConfig(RunnableConfig): - model: dict[str, typing.Any] = Field( - default_factory=dict, - desc="`FastLLMModelConfig` dict (`base_model`, `distributed`, `multi_stage`)." - " Forwarded into the trainer config as-is alongside `pretrained`. Either or both" - " can be set: `pretrained` to load architecture/weights from a checkpoint," - " `model` to specify the architecture inline or override pretrained fields.", - hint=FieldHint.core, - ) - pretrained: dict[str, typing.Any] = Field( - default_factory=dict, - desc="`CheckpointLoadConfig` dict, e.g. `{path: HuggingFaceTB/SmolLM2-135M, format: llama}`.", - hint=FieldHint.optional, - ) +class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): + """Evaluate layer-wise numerical-error propagation against an fp32 reference. + + Inherits `model` and `pretrained` from `PretrainedGPTModelConfig`: either or both + can be set in the YAML. The tool runs one fp32 reference + one trainer invocation + per variant, captures per-layer forward activations and input gradients via the + standard tensor-logs pipeline, and reports per-tensor RMS / max diffs. + """ + + _abstract = False variants: dict[str, typing.Any] = Field( desc="Named override bundles to evaluate against the fp32 reference." " Each value is a flat dict mapping dotted-path keys (same syntax as the Fast-LLM CLI) to values.", @@ -48,11 +44,6 @@ class EvaluatePrecisionConfig(RunnableConfig): desc="Directory for per-run tensor-log artifacts and the final JSON report.", hint=FieldHint.core, ) - model_type: str = Field( - default="gpt", - desc="Trainer dynamic-type name used to dispatch to the right `TrainerConfig` subclass.", - hint=FieldHint.optional, - ) num_samples: int = Field( default=1024, desc="Number of sampled values stored per logged tensor.", @@ -98,37 +89,18 @@ def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: - config_dict = self._build_config_dict(name) - # Force fp32 on the reference baseline (variants apply on top and can re-override). - _set_nested(config_dict, ["model", "distributed", "compute_dtype"], "float32") - _set_nested(config_dict, ["model", "distributed", "optimization_dtype"], "float32") - for dotted_key, value in variant_overrides.items(): - _set_nested(config_dict, dotted_key.split("."), value) - # Tool-required overrides always win — variants must not silently disable tensor logging. - _set_nested(config_dict, ["model", "multi_stage", "debug_layer_outputs"], _LOG_LEVEL) - _set_nested(config_dict, ["model", "multi_stage", "debug_layer_gradients"], _LOG_LEVEL) - config_yaml = self.output_dir / f"{name}_config.yaml" - config_yaml.write_text(yaml.safe_dump(config_dict)) - logger.info(f"=== Running {name!r} ===") - if variant_overrides: - logger.info(f"Variant overrides: {variant_overrides}") - TrainerConfig.parse_and_run([self.model_type, "-c", str(config_yaml)]) - - def _build_config_dict(self, name: str) -> dict[str, typing.Any]: - return { - "pretrained": self.pretrained, - "model": self.model, + # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained. + # Forced fp32 on the reference baseline lives in here too so a variant can override it. + base_dict: dict[str, typing.Any] = { + "pretrained": self.pretrained.to_dict(), + "model": self.model.to_dict(), "training": { "train_iters": 1, "num_workers": 0, "logs": {"interval": 1}, }, "optimizer": { - "learning_rate": { - "base": 0.0, - "decay_style": "constant", - "warmup_iterations": 0, - }, + "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, }, "data": { "datasets": {"training": {"type": "random"}}, @@ -137,13 +109,26 @@ def _build_config_dict(self, name: str) -> dict[str, typing.Any]: }, "run": { "experiment_dir": str((self.output_dir / name).resolve()), - "tensor_logs": { - "save": True, - "show": False, - "max_elements": self.num_samples, - }, + "tensor_logs": {"save": True, "show": False, "max_elements": self.num_samples}, }, } + fp32_dtypes = { + ("model", "distributed", "compute_dtype"): "float32", + ("model", "distributed", "optimization_dtype"): "float32", + } + variant_updates = {tuple(key.split(".")): value for key, value in variant_overrides.items()} + # Tool-required overrides win over variants — a variant must not silently disable tensor logging. + tool_overrides = { + ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, + ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, + } + logger.info(f"=== Running {name!r} ===") + if variant_overrides: + logger.info(f"Variant overrides: {variant_overrides}") + trainer_class = TrainerConfig.get_subclass(_MODEL_TYPE) + trainer_config = trainer_class.from_dict(base_dict, fp32_dtypes, variant_updates, tool_overrides) + trainer_config.configure_logging() + trainer_config._get_runnable()() def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: compare_config = CompareConfig() @@ -188,12 +173,6 @@ def _classify(tensor_name: str) -> str: return "other" -def _set_nested(d: dict[str, typing.Any], keys: list[str], value: typing.Any) -> None: - for key in keys[:-1]: - d = d.setdefault(key, {}) - d[keys[-1]] = value - - def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: print(f"\n=== Variant: {name} ===") if not rows: From 35206a6c2a37e2ee0d669bcabbed6b5c0cd885cc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 15:46:47 -0400 Subject: [PATCH 05/31] Expand HF metadata allowlist for newer transformers configs `transformers.PretrainedConfig.to_dict()` serializes a growing set of generic defaults (generation knobs, family markers, encoder-decoder flags). The Fast-LLM allowlist covered only a subset, so loading any modern HF Llama checkpoint via `pretrained.format: llama` tripped the coverage walker on keys like `torchscript`, `is_decoder`, `is_llama_config`, `rope_interleaved`, and the full set of generation defaults. Fill in the missing entries, grouped by category. None of them are architecture knobs that Fast-LLM consumes. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/huggingface.py | 41 +++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index c055a7f2c..a4810dc1a 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -128,20 +128,32 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: { # transformers PretrainedConfig "_name_or_path", + "add_cross_attention", "architectures", "auto_map", "chunk_size_feed_forward", + "cross_attention_hidden_size", "dtype", + "finetuning_task", "id2label", + "is_decoder", "is_encoder_decoder", "label2id", "model_type", "output_attentions", "output_hidden_states", + "prefix", "problem_type", + "pruned_heads", "return_dict", + "task_specific_params", + "tf_legacy_loss", + "tie_encoder_decoder", + "tokenizer_class", "torch_dtype", + "torchscript", "transformers_version", + "use_bfloat16", "use_cache", # Token ids — generation/inference, not architecture. "bos_token_id", @@ -149,10 +161,39 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: "eos_token_id", "pad_token_id", "sep_token_id", + # Generation defaults — never architecture. + "bad_words_ids", + "begin_suppress_tokens", + "diversity_penalty", + "do_sample", + "early_stopping", + "encoder_no_repeat_ngram_size", + "exponential_decay_length_penalty", + "forced_bos_token_id", + "forced_eos_token_id", + "length_penalty", + "max_length", + "min_length", + "no_repeat_ngram_size", + "num_beam_groups", + "num_beams", + "num_return_sequences", + "output_scores", + "remove_invalid_values", + "repetition_penalty", + "return_dict_in_generate", + "suppress_tokens", + "temperature", + "top_k", + "top_p", + "typical_p", # Initialization / pretraining metadata Fast-LLM does not consume. "initializer_range", "max_position_embeddings", "pretraining_tp", + # Family markers / default-valued knobs serialized by recent transformers versions. + "is_llama_config", + "rope_interleaved", } ) From bde1efa903288e06bb4d67386a11811aeabf033c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:14:51 -0400 Subject: [PATCH 06/31] Reshape console table for readability Drop step / shape / max_rel columns, shorten the tensor name to the description after the colon, reorder to Tensor / Kind / Relative / Absolute / Max / Scale, format Relative as percent and the rest with `.3g`. The JSON report keeps every field. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 4c56848b6..0e7f70707 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -178,32 +178,19 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: if not rows: print("(no matching tensors)") return - columns = [ - ("step", "step", 6), - ("kind", "kind", 6), - ("tensor_name", "tensor", 48), - ("shape", "shape", 22), - ("ref_scale", "ref_scale", 12), - ("rms_abs", "rms_abs", 12), - ("rms_rel", "rms_rel", 12), - ("max_abs", "max_abs", 12), - ("max_rel", "max_rel", 12), + columns: list[tuple[str, str, int, typing.Callable[[typing.Any], str]]] = [ + ("tensor_name", "Tensor", 28, lambda v: v.split(":", 1)[-1].strip()), + ("kind", "Kind", 4, str), + ("rms_rel", "Relative", 9, lambda v: f"{v * 100:.3g}%"), + ("rms_abs", "Absolute", 10, lambda v: f"{v:.3g}"), + ("max_abs", "Max", 10, lambda v: f"{v:.3g}"), + ("ref_scale", "Scale", 10, lambda v: f"{v:.3g}"), ] - header = " ".join(f"{title:<{width}}" for _, title, width in columns) + header = " ".join(f"{title:<{width}}" for _, title, width, _ in columns) print(header) print("-" * len(header)) for row in rows: - parts = [] - for key, _, width in columns: - value = row[key] - if isinstance(value, float): - cell = f"{value:.4e}" - elif isinstance(value, list): - cell = "x".join(str(x) for x in value) - else: - cell = str(value) - parts.append(f"{cell:<{width}}") - print(" ".join(parts)) + print(" ".join(f"{format_fn(row[key]):<{width}}" for key, _, width, format_fn in columns)) if __name__ == "__main__": From 8099b51cae606914c3a6c3c08b443165b061519d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:22:48 -0400 Subject: [PATCH 07/31] Merge tensor+kind, fix decimal precision in console table Drop the separate Kind column and append `(fw)` / `(bw)` to the shortened tensor name. Switch numeric formatting to fixed precision: Relative shows `.2f` percent, Absolute / Max / Scale show `.2e` scientific. Every column now lines up on a consistent digit count. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 0e7f70707..f74a31649 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -178,19 +178,18 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: if not rows: print("(no matching tensors)") return - columns: list[tuple[str, str, int, typing.Callable[[typing.Any], str]]] = [ - ("tensor_name", "Tensor", 28, lambda v: v.split(":", 1)[-1].strip()), - ("kind", "Kind", 4, str), - ("rms_rel", "Relative", 9, lambda v: f"{v * 100:.3g}%"), - ("rms_abs", "Absolute", 10, lambda v: f"{v:.3g}"), - ("max_abs", "Max", 10, lambda v: f"{v:.3g}"), - ("ref_scale", "Scale", 10, lambda v: f"{v:.3g}"), + columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ + ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), + ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.2e}"), + ("Max", 10, lambda r: f"{r['max_abs']:.2e}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.2e}"), ] - header = " ".join(f"{title:<{width}}" for _, title, width, _ in columns) + header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) print("-" * len(header)) for row in rows: - print(" ".join(f"{format_fn(row[key]):<{width}}" for key, _, width, format_fn in columns)) + print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) if __name__ == "__main__": From dbd7702f7db29e85a9a31021475c48f1bce508e8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:26:36 -0400 Subject: [PATCH 08/31] Switch back to fixed-decimal formatting in the table Scientific notation was overkill for values that mostly land between 0.01 and a few hundred. `.3f` is more readable while keeping the per-column digit count consistent. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index f74a31649..6cd30a224 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -181,9 +181,9 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), - ("Absolute", 10, lambda r: f"{r['rms_abs']:.2e}"), - ("Max", 10, lambda r: f"{r['max_abs']:.2e}"), - ("Scale", 10, lambda r: f"{r['ref_scale']:.2e}"), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.3f}"), + ("Max", 10, lambda r: f"{r['max_abs']:.3f}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.3f}"), ] header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) From 152ffc36df08da66f301052286f9e93ed3d7f056 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 27 May 2026 16:43:19 -0400 Subject: [PATCH 09/31] Wipe per-variant experiment dir before each run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fast-LLM's `Run.__init__` picks the next free `runs/` subdirectory based on what already exists, but `_artifact_path` reads `runs/0` unconditionally. Without this wipe, re-running the tool against the same `output_dir` reads stale artifacts from the first invocation and silently reports old numbers — even though the trainer correctly ran with the new config. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 6cd30a224..a9a17b2e4 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -1,6 +1,7 @@ import json import logging import pathlib +import shutil import typing from fast_llm.config import Field, FieldHint, config_class @@ -89,6 +90,12 @@ def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: + # The trainer's Run picks the next `runs/` subdir based on what already exists; wipe + # any prior contents so each invocation lands in `runs/0` and stale artifacts can't be + # read by `_artifact_path` below. + experiment_dir = self.output_dir / name + if experiment_dir.exists(): + shutil.rmtree(experiment_dir) # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained. # Forced fp32 on the reference baseline lives in here too so a variant can override it. base_dict: dict[str, typing.Any] = { From 7e98500d85c3f20d44f410d7d8cd07ca07a4abed Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:04:45 -0400 Subject: [PATCH 10/31] Support pre-generated memmap dataset; misc table-format polish MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a `data_path` field to the tool. When set, the tool lazily generates a tokenized memmap dataset with random advantages and old_logprobs at the given path (via the test helper `tests/utils/dataset._get_test_dataset`) and uses it as the training input. Required for policy-gradient losses like GSPO/GRPO that consume those fields. Without it, the tool falls back to the random token generator as before. Console table now formats numeric columns with `.4g` so 1e-7-scale GSPO gradients aren't rounded to zero while normal CE-magnitude values still read as fixed-point numbers. Rename `download_santacoder_tokenizer` to `download_test_tokenizer` — it actually downloads the GPT-2 tokenizer. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/data/test_tokenizer.py | 4 ++-- tests/models/test_lm_eval.py | 4 ++-- tests/utils/dataset.py | 4 ++-- tools/evaluate_precision.py | 39 ++++++++++++++++++++++++++++++++---- 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/tests/data/test_tokenizer.py b/tests/data/test_tokenizer.py index 184294551..04a24e2ae 100644 --- a/tests/data/test_tokenizer.py +++ b/tests/data/test_tokenizer.py @@ -2,13 +2,13 @@ from fast_llm.data.preparation.tokenizer import Tokenizer, TokenizerConfig from fast_llm.utils import Assert -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.global_variables import TOKENIZER_PATH @pytest.fixture(scope="session") def common_tokenizer() -> Tokenizer: - download_santacoder_tokenizer() + download_test_tokenizer() return TokenizerConfig(path=TOKENIZER_PATH).get_tokenizer() diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index 7ae26c2d6..c8b5fd004 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -3,7 +3,7 @@ import pytest -from tests.utils.dataset import download_santacoder_tokenizer +from tests.utils.dataset import download_test_tokenizer from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup @@ -15,7 +15,7 @@ @pytest.fixture(scope="module") def tokenizer_path(): - download_santacoder_tokenizer() + download_test_tokenizer() return TOKENIZER_PATH diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index a2ea2f46e..e7b206cf5 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -14,7 +14,7 @@ from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH -def download_santacoder_tokenizer(): +def download_test_tokenizer(): if not TOKENIZER_FILE.is_file(): import transformers @@ -218,7 +218,7 @@ def _get_test_dataset( if has_grpo_data: source_schema["advantages"] = "advantages" - download_santacoder_tokenizer() + download_test_tokenizer() preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( { "dataset": { diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index a9a17b2e4..55d01cc1f 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -60,6 +60,13 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): desc="Sequence length (maximum document length) for the random input.", hint=FieldHint.feature, ) + data_path: pathlib.Path | None = Field( + default=None, + desc="If set, prepare a tokenized memmap dataset with advantages and `old_log_probabilities`" + " at this path (using the test helper `_get_test_dataset`) and use it as the training" + " input — required for policy-gradient losses like GSPO/GRPO. If unset, uses random tokens.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() @@ -71,6 +78,7 @@ def _validate(self) -> None: def run(self) -> None: self.output_dir.mkdir(parents=True, exist_ok=True) + self._prepare_data() runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} runs.update(self.variants) for name, variant_overrides in runs.items(): @@ -86,6 +94,23 @@ def run(self) -> None: for name, rows in results.items(): _print_table(name, rows) + def _prepare_data(self) -> None: + if self.data_path is None: + return + if (self.data_path / "fast_llm_config.yaml").is_file(): + return + # Couples `tools/` to `tests/utils/` for now — extract later if it sticks. + from tests.utils.dataset import _get_test_dataset + + self.data_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Preparing memmap dataset at {self.data_path}") + _get_test_dataset( + self.data_path, + seed=42, + has_grpo_data=True, + max_vocab_size=self.model.base_model.embeddings.vocab_size, + ) + def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" @@ -110,7 +135,13 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, }, "data": { - "datasets": {"training": {"type": "random"}}, + "datasets": { + "training": ( + {"type": "file", "path": str(self.data_path / "fast_llm_config.yaml")} + if self.data_path is not None + else {"type": "random"} + ) + }, "micro_batch_size": self.micro_batch_size, "maximum_document_length": self.sequence_length, }, @@ -188,9 +219,9 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), - ("Absolute", 10, lambda r: f"{r['rms_abs']:.3f}"), - ("Max", 10, lambda r: f"{r['max_abs']:.3f}"), - ("Scale", 10, lambda r: f"{r['ref_scale']:.3f}"), + ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), + ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), + ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"), ] header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) From 173ae0de6e3f6d9c5adf2ea6fe7fde8c6d5ca4f3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:19:20 -0400 Subject: [PATCH 11/31] Print per-variant summary at the end of the run After the per-tensor tables, emit a short summary block per variant showing first/last/max/median for forward and backward separately. Aggregates over the intermediate layers per metric column (max and median are computed per-column, so each row is a per-metric envelope of the intermediate band rather than the metrics of any single layer). Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 55d01cc1f..beed05564 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -2,6 +2,7 @@ import logging import pathlib import shutil +import statistics import typing from fast_llm.config import Field, FieldHint, config_class @@ -93,6 +94,9 @@ def run(self) -> None: for name, rows in results.items(): _print_table(name, rows) + print("\n=== Summary ===") + for name, rows in results.items(): + _print_table(name, _summary_rows(rows)) def _prepare_data(self) -> None: if self.data_path is None: @@ -200,6 +204,26 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows +def _summary_rows(rows: list[dict[str, typing.Any]]) -> list[dict[str, typing.Any]]: + out: list[dict[str, typing.Any]] = [] + metric_keys = ("rms_rel", "rms_abs", "max_abs", "ref_scale") + for kind in ("fw", "bw"): + group = [r for r in rows if r["kind"] == kind] + if not group: + continue + first, last = group[0], group[-1] + intermediate = group[1:-1] + out.append({**first, "tensor_name": "first", "kind": kind}) + out.append({**last, "tensor_name": "last", "kind": kind}) + if intermediate: + for agg_name, agg in (("max", max), ("median", statistics.median)): + aggregated = {"tensor_name": agg_name, "kind": kind} + for key in metric_keys: + aggregated[key] = agg(r[key] for r in intermediate) + out.append(aggregated) + return out + + def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name From 005fd6222b07e6be5751cea8e6134940eecbb9d9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:21:02 -0400 Subject: [PATCH 12/31] =?UTF-8?q?Reshape=20end-of-run=20summary:=20variant?= =?UTF-8?q?s=20=C3=97=20aggregations,=20relative=20only?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Single compact table with one row per variant and columns for fw/bw first/last/max/median Relative %. Max/median are over intermediate layers (excluding first/last) when there is at least one intermediate row. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 46 ++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index beed05564..3b40c7e17 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -94,9 +94,7 @@ def run(self) -> None: for name, rows in results.items(): _print_table(name, rows) - print("\n=== Summary ===") - for name, rows in results.items(): - _print_table(name, _summary_rows(rows)) + _print_summary(results) def _prepare_data(self) -> None: if self.data_path is None: @@ -204,24 +202,30 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows -def _summary_rows(rows: list[dict[str, typing.Any]]) -> list[dict[str, typing.Any]]: - out: list[dict[str, typing.Any]] = [] - metric_keys = ("rms_rel", "rms_abs", "max_abs", "ref_scale") - for kind in ("fw", "bw"): - group = [r for r in rows if r["kind"] == kind] - if not group: - continue - first, last = group[0], group[-1] - intermediate = group[1:-1] - out.append({**first, "tensor_name": "first", "kind": kind}) - out.append({**last, "tensor_name": "last", "kind": kind}) - if intermediate: - for agg_name, agg in (("max", max), ("median", statistics.median)): - aggregated = {"tensor_name": agg_name, "kind": kind} - for key in metric_keys: - aggregated[key] = agg(r[key] for r in intermediate) - out.append(aggregated) - return out +def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + columns = [(f"{kind} {agg}", kind, agg) for kind in ("fw", "bw") for agg in ("first", "last", "max", "median")] + name_width = max((len(name) for name in results), default=7) + 2 + cell_width = 10 + print("\n=== Summary (Relative %) ===") + header = f"{'Variant':<{name_width}}" + "".join(f"{h:<{cell_width}}" for h, _, _ in columns) + print(header) + print("-" * len(header)) + for name, rows in results.items(): + cells = [] + for _, kind, agg in columns: + group = [r["rms_rel"] for r in rows if r["kind"] == kind] + if not group: + cells.append("n/a") + continue + if agg == "first": + value = group[0] + elif agg == "last": + value = group[-1] + else: + intermediate = group[1:-1] or group + value = max(intermediate) if agg == "max" else statistics.median(intermediate) + cells.append(f"{value * 100:.2f}%") + print(f"{name:<{name_width}}" + "".join(f"{c:<{cell_width}}" for c in cells)) def _classify(tensor_name: str) -> str: From c59465889ca16e141c85035e270f33657b861920 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:27:08 -0400 Subject: [PATCH 13/31] Clarify intermediate aggregation in summary header Rename `max`/`median` columns to `mid max`/`mid med` and add a header note (`mid = excluding first/last`) so it's clear the aggregation excludes the boundary layers. Also fix a column-collision bug where labels at exactly the cell width touched without separator. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 3b40c7e17..4feb10792 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -203,11 +203,12 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - columns = [(f"{kind} {agg}", kind, agg) for kind in ("fw", "bw") for agg in ("first", "last", "max", "median")] + agg_labels = {"first": "first", "last": "last", "max": "mid max", "median": "mid med"} + columns = [(f"{kind} {agg_labels[agg]}", kind, agg) for kind in ("fw", "bw") for agg in agg_labels] name_width = max((len(name) for name in results), default=7) + 2 - cell_width = 10 - print("\n=== Summary (Relative %) ===") - header = f"{'Variant':<{name_width}}" + "".join(f"{h:<{cell_width}}" for h, _, _ in columns) + cell_width = max(len(label) for label, _, _ in columns) + 1 + print("\n=== Summary (Relative %; mid = excluding first/last) ===") + header = f"{'Variant':<{name_width}}" + " ".join(f"{h:<{cell_width}}" for h, _, _ in columns) print(header) print("-" * len(header)) for name, rows in results.items(): @@ -225,7 +226,7 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: intermediate = group[1:-1] or group value = max(intermediate) if agg == "max" else statistics.median(intermediate) cells.append(f"{value * 100:.2f}%") - print(f"{name:<{name_width}}" + "".join(f"{c:<{cell_width}}" for c in cells)) + print(f"{name:<{name_width}}" + " ".join(f"{c:<{cell_width}}" for c in cells)) def _classify(tensor_name: str) -> str: From 3159f73efeb99564d6c42df550c4ee8537e6df94 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:31:25 -0400 Subject: [PATCH 14/31] Split summary across fw/bw rows; one extra precision digit Each variant now occupies two rows in the summary (fw on the first, bw on the second), with the metric columns shared. Reads more naturally and keeps the table half as wide. Percent precision goes from .2f to .3f so single-digit-percent differences between variants are visible. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 4feb10792..7b4d1fd05 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -204,29 +204,30 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: agg_labels = {"first": "first", "last": "last", "max": "mid max", "median": "mid med"} - columns = [(f"{kind} {agg_labels[agg]}", kind, agg) for kind in ("fw", "bw") for agg in agg_labels] name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(label) for label, _, _ in columns) + 1 + cell_width = max(len(label) for label in agg_labels.values()) + 2 print("\n=== Summary (Relative %; mid = excluding first/last) ===") - header = f"{'Variant':<{name_width}}" + " ".join(f"{h:<{cell_width}}" for h, _, _ in columns) + header = f"{'Variant':<{name_width}}{'':<4}" + " ".join(f"{label:<{cell_width}}" for label in agg_labels.values()) print(header) print("-" * len(header)) for name, rows in results.items(): - cells = [] - for _, kind, agg in columns: + for index, kind in enumerate(("fw", "bw")): group = [r["rms_rel"] for r in rows if r["kind"] == kind] - if not group: - cells.append("n/a") - continue - if agg == "first": - value = group[0] - elif agg == "last": - value = group[-1] - else: - intermediate = group[1:-1] or group - value = max(intermediate) if agg == "max" else statistics.median(intermediate) - cells.append(f"{value * 100:.2f}%") - print(f"{name:<{name_width}}" + " ".join(f"{c:<{cell_width}}" for c in cells)) + cells = [] + for agg in agg_labels: + if not group: + cells.append("n/a") + continue + if agg == "first": + value = group[0] + elif agg == "last": + value = group[-1] + else: + intermediate = group[1:-1] or group + value = max(intermediate) if agg == "max" else statistics.median(intermediate) + cells.append(f"{value * 100:.3f}%") + name_cell = name if index == 0 else "" + print(f"{name_cell:<{name_width}}{kind:<4}" + " ".join(f"{c:<{cell_width}}" for c in cells)) def _classify(tensor_name: str) -> str: From 6ef153e154ed86cffb609307c7b3e32febf60bf9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:34:20 -0400 Subject: [PATCH 15/31] Two-row column header in summary; chronological column order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Top header line groups columns under `fw` / `bw`; the second line lists the per-pass aggregations. Aggregations are ordered chronologically along the pass — first → mid med → mid max → last — so reading left to right traces the propagation. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 42 ++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 7b4d1fd05..f5d3d8e4e 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -203,31 +203,43 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - agg_labels = {"first": "first", "last": "last", "max": "mid max", "median": "mid med"} + # Chronological column order: first → intermediate (median, max) → last. + aggs = ("first", "median", "max", "last") + agg_labels = {"first": "first", "median": "mid med", "max": "mid max", "last": "last"} + kinds = ("fw", "bw") name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(label) for label in agg_labels.values()) + 2 + cell_width = max(len(label) for label in agg_labels.values()) + 1 + group_sep = " " + group_width = len(" ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs)) print("\n=== Summary (Relative %; mid = excluding first/last) ===") - header = f"{'Variant':<{name_width}}{'':<4}" + " ".join(f"{label:<{cell_width}}" for label in agg_labels.values()) - print(header) - print("-" * len(header)) + top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_width}}" for kind in kinds) + bottom = f"{'Variant':<{name_width}}" + group_sep.join( + " ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs) for _ in kinds + ) + print(top) + print(bottom) + print("-" * len(bottom)) for name, rows in results.items(): - for index, kind in enumerate(("fw", "bw")): - group = [r["rms_rel"] for r in rows if r["kind"] == kind] + groups = [] + for kind in kinds: + values = [r["rms_rel"] for r in rows if r["kind"] == kind] + intermediate = values[1:-1] or values cells = [] - for agg in agg_labels: - if not group: + for agg in aggs: + if not values: cells.append("n/a") continue if agg == "first": - value = group[0] + value = values[0] elif agg == "last": - value = group[-1] + value = values[-1] + elif agg == "max": + value = max(intermediate) else: - intermediate = group[1:-1] or group - value = max(intermediate) if agg == "max" else statistics.median(intermediate) + value = statistics.median(intermediate) cells.append(f"{value * 100:.3f}%") - name_cell = name if index == 0 else "" - print(f"{name_cell:<{name_width}}{kind:<4}" + " ".join(f"{c:<{cell_width}}" for c in cells)) + groups.append(" ".join(f"{c:<{cell_width}}" for c in cells)) + print(f"{name:<{name_width}}" + group_sep.join(groups)) def _classify(tensor_name: str) -> str: From 7327932e47f781405e7e02862fe52662780ea74b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 12:40:20 -0400 Subject: [PATCH 16/31] Add fp32_lm_head flag for vLLM precision parity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an `fp32_lm_head` field on `LanguageModelHeadConfig`. When `True`, the LM head linear's input and weight are upcast to FP32 before the matmul, matching vLLM's `bf16_last_layer_fp32` quantization. This lets the trainer compute log-probabilities at the same numerical precision as the actor's sampling, so the importance-sampling ratio starts near 1.0 instead of being artificially inflated by a trainer/actor precision mismatch. The detached FP32 weight has `requires_grad=False`, which makes `output_parallel_linear_backward` skip the weight-grad path. The FSDP gradient contract is restored by computing `grad_weight = grad.t() @ saved_input` explicitly and accumulating into the original BF16 param's `grad_buffer` via `accumulate_gradient`. Off by default — disabled path is byte-identical to before. Cherry-picked from #526 to unblock the precision-evaluation tool's GSPO smoke test, which compares fp32_lm_head=true vs false. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/config.py | 7 +++++ fast_llm/layers/language_model/head.py | 34 +++++++++++++++++++----- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index bde33f297..6a0bfcfd6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -131,6 +131,13 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.architecture, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + fp32_lm_head: bool = Field( + default=False, + desc="Upcast input and weight to float32 before the lm_head linear. " + "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs " + "are computed at the same numerical precision, keeping the IS ratio near 1 at init.", + hint=FieldHint.feature, + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 22c750082..eb67cd553 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,7 +22,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.tensor import TensorMeta +from fast_llm.tensor import TensorMeta, accumulate_gradient from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -252,9 +252,17 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if self._config.fp32_lm_head: + input_dtype = input_.dtype + input_ = input_.to(torch.float32) + # detach → requires_grad=False → output_parallel_linear_backward skips weight grad + weight = self.output_weights.detach().to(torch.float32) + else: + weight = self.output_weights + logits, context = output_parallel_linear_forward( input_=input_, - weight=self.output_weights, + weight=weight, bias=None, group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -285,12 +293,26 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) - if grad is not None and self._config.final_logit_softcap is not None: + if not self.training or grad is None: + return sum(losses_) if losses_ else None, None + + if self._config.final_logit_softcap is not None: grad = _softcap_backward(grad, logits, self._config.final_logit_softcap) - return sum(losses_) if losses_ else None, ( - output_parallel_linear_backward(grad, context) if self.training else None - ) + input_grad = output_parallel_linear_backward(grad, context) + if self._config.fp32_lm_head: + # Weight grad was skipped because weight.requires_grad=False; accumulate manually. + # context: (input_, weight, bias, group, sequence_parallel, ...) + saved_input = context[0] + if context[4]: # sequence_parallel + from fast_llm.core.ops import gather_op + + saved_input = gather_op(saved_input, context[3], dim=0) + grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2)) + accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype)) + input_grad = input_grad.to(input_dtype) + + return sum(losses_) if losses_ else None, input_grad def get_loss_definitions(self) -> list[LossDef]: return [ From 76335dffd75b96c6baa7d008301bf87d2e7c5170 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:06:53 -0400 Subject: [PATCH 17/31] Extract layer-name labels for summary first/last columns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of generic `first` / `last` headers in the summary, use the actual layer name pulled from the matching tensor's `Global :` prefix. For the SmolLM2 smoke run that surfaces as `embeddings` / `head` on fw and `head` / `decoder.0` on bw — directly showing which layer the boundary values come from rather than making the reader guess. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 38 ++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index f5d3d8e4e..e27909333 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -202,19 +202,47 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows +def _layer_name(tensor_name: str) -> str: + # Stage hooks name tensors `Global fw: ...` / `Global bw: ...`; + # extract the layer to use as a meaningful column label. + prefix = tensor_name.split(":", 1)[0].strip().split() + if prefix and prefix[0] == "Global": + prefix = prefix[1:] + if prefix and prefix[-1] in ("fw", "bw"): + prefix = prefix[:-1] + return " ".join(prefix) if prefix else "?" + + def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: # Chronological column order: first → intermediate (median, max) → last. aggs = ("first", "median", "max", "last") - agg_labels = {"first": "first", "median": "mid med", "max": "mid max", "last": "last"} + # Per-pass labels for `first`/`last` come from the actual layer name on the matching row. + sample = next(iter(results.values())) + endpoint_labels: dict[tuple[str, str], str] = { + ("fw", "first"): "first", + ("fw", "last"): "last", + ("bw", "first"): "first", + ("bw", "last"): "last", + } + for kind in ("fw", "bw"): + group = [r for r in sample if r["kind"] == kind] + if group: + endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) + endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) + mid_labels = {"median": "mid med", "max": "mid max"} kinds = ("fw", "bw") + + def _label(kind: str, agg: str) -> str: + return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] + name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(label) for label in agg_labels.values()) + 1 + cell_width = max(len(_label(k, a)) for k in kinds for a in aggs) + 1 group_sep = " " - group_width = len(" ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs)) + group_widths = {kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs)) for kind in kinds} print("\n=== Summary (Relative %; mid = excluding first/last) ===") - top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_width}}" for kind in kinds) + top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) bottom = f"{'Variant':<{name_width}}" + group_sep.join( - " ".join(f"{agg_labels[a]:<{cell_width}}" for a in aggs) for _ in kinds + " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs) for kind in kinds ) print(top) print(bottom) From 8122946df08040dd3cc2ab2640b0cf18d5f1b619 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:24:29 -0400 Subject: [PATCH 18/31] Add `debug_hidden_states_log` to capture named tensors via output_hidden_states MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the only way to get a non-layer-output tensor (e.g. the LM head's logits) into `tensor_logs` was to crank `model_debug_level`, which logs every single `_debug`-emitted tensor (~700 per step for a 30-layer model). Add a `MultiStageConfig.debug_hidden_states_log: list[str]` field — regex patterns that get appended to each model input's `output_hidden_states` set. Matching tensors are still populated into `kwargs[hidden_states]` (existing contract for the HF inference wrapper); now they're also written to `tensor_logs` so the precision tool can compare them across variants. `_debug` already had the `output_hidden_state`-matched branch but only used it to populate `kwargs[hidden_states]`. Extending it to also call `log_distributed_tensor` at a fixed verbosity (13, matching the test convention so samples are recorded) is a small gating change. Plumbed through `GPTModel.get_preprocessing_config` → `LanguageModelBatchPreprocessingConfig.output_hidden_states` → `LanguageModelBatch.get_model_inputs`, which compiles the patterns and unions them into each `LanguageModelInput.output_hidden_states`. The precision tool now sets `[r"head\.logits"]` and surfaces logits as a dedicated `logits` column on the fw side of the summary table. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/data/document/config.py | 6 +++++ fast_llm/data/document/language_model.py | 7 +++++ fast_llm/engine/multi_stage/config.py | 8 ++++++ fast_llm/layers/block/block.py | 15 ++++++++--- fast_llm/models/gpt/model.py | 1 + tools/evaluate_precision.py | 34 +++++++++++++++++++----- 6 files changed, 61 insertions(+), 10 deletions(-) diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index a90bcdebc..fbfe60ac3 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -80,6 +80,12 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): use_preference_spans: bool = Field(default=False) use_grpo_data: bool = Field(default=False) return_label_counts: bool = Field(default=False) + output_hidden_states: list[str] = Field( + default_factory=list, + desc="Regex patterns to add to each model input's `output_hidden_states` set." + " Matching `_debug`-named tensors get populated into `kwargs[hidden_states]`" + " and (when running under a `Run` context) emitted into `tensor_logs`.", + ) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 16114cb80..000fcc01d 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -161,6 +161,13 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis self._set_target_inputs(model_inputs, config) + if config.output_hidden_states: + import re + + patterns = {re.compile(pattern) for pattern in config.output_hidden_states} + for model_input in model_inputs: + model_input.output_hidden_states.update(patterns) + return model_inputs def _set_target_inputs( diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 958a3d228..96cb52f09 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -139,6 +139,14 @@ class StageConfig(Config): desc="Check for tensor-parallel desyncs and log an error if a desync is found. High overhead", hint=FieldHint.logging, ) + debug_hidden_states_log: list[str] = Field( + default_factory=list, + desc="Regex patterns for `_debug`-named tensors (`.`, e.g. `head.logits`," + " `decoder.0.norm_1`) to log to `tensor_logs`. Patterns are appended to each model" + " input's `output_hidden_states` set, so matching tensors are both populated into" + " `kwargs[hidden_states]` for downstream consumers and emitted into `tensor_logs`.", + hint=FieldHint.logging, + ) @config_class() diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 805eae1e5..0476a8107 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -18,6 +18,12 @@ logger = logging.getLogger(__name__) +# Verbosity used for `output_hidden_states`-driven tensor logging. `log_tensor` collects sampled +# tensor values only at level >= 3; 13 matches the convention in the layer-comparison tests +# (1024 sampled values per tensor). +_HIDDEN_STATE_LOG_LEVEL = 13 + + class DebugLayer: """ A debugging utility for blocks. @@ -55,11 +61,14 @@ def __call__( if level > 1: log_pipeline_parallel_main_rank(lambda: log_memory_usage(name, str)) - if level > 0 and tensor is not None: + # `output_hidden_state` requests full-fidelity capture even when `model_debug_level` is + # off — clamp the log level so samples are saved alongside summary stats. + log_level = max(level, _HIDDEN_STATE_LOG_LEVEL) if output_hidden_state else level + if log_level > 0 and tensor is not None: log_distributed_tensor( "", tensor, - level=level, + level=log_level, meta=meta, **logging_kwargs, ) @@ -67,7 +76,7 @@ def __call__( log_distributed_grad( "", tensor, - level=level, + level=log_level, meta=self._get_meta(tensor, f"{name}.grad", dims), **logging_kwargs, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2e9b4365b..f4d4b286a 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -112,6 +112,7 @@ def get_preprocessing_config( return LanguageModelBatchPreprocessingConfig( phase=phase, micro_batch_splits=micro_batch_splits, + output_hidden_states=list(self._config.multi_stage.debug_hidden_states_log), **self._base_model.get_preprocessing_config(), ) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index e27909333..afcb818c0 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -161,6 +161,9 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: tool_overrides = { ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, + # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's + # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. + ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], } logger.info(f"=== Running {name!r} ===") if variant_overrides: @@ -213,9 +216,14 @@ def _layer_name(tensor_name: str) -> str: return " ".join(prefix) if prefix else "?" +def _logits_row(rows: list[dict[str, typing.Any]]) -> dict[str, typing.Any] | None: + return next( + (r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == "head.logits"), + None, + ) + + def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - # Chronological column order: first → intermediate (median, max) → last. - aggs = ("first", "median", "max", "last") # Per-pass labels for `first`/`last` come from the actual layer name on the matching row. sample = next(iter(results.values())) endpoint_labels: dict[tuple[str, str], str] = { @@ -229,31 +237,41 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: if group: endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) - mid_labels = {"median": "mid med", "max": "mid max"} + mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} + # Logits show up on the fw side via `output_hidden_states` ("Global : head.logits"); + # add a dedicated column for it (chronologically just before the head output / loss). + has_logits = _logits_row(sample) is not None + aggs_per_kind = { + "fw": ("first", "median", "max", "logits", "last") if has_logits else ("first", "median", "max", "last"), + "bw": ("first", "median", "max", "last"), + } kinds = ("fw", "bw") def _label(kind: str, agg: str) -> str: return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(_label(k, a)) for k in kinds for a in aggs) + 1 + cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + 1 group_sep = " " - group_widths = {kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs)) for kind in kinds} + group_widths = { + kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds + } print("\n=== Summary (Relative %; mid = excluding first/last) ===") top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) bottom = f"{'Variant':<{name_width}}" + group_sep.join( - " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs) for kind in kinds + " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds ) print(top) print(bottom) print("-" * len(bottom)) for name, rows in results.items(): + logits_value = _logits_row(rows)["rms_rel"] if _logits_row(rows) else float("nan") groups = [] for kind in kinds: values = [r["rms_rel"] for r in rows if r["kind"] == kind] intermediate = values[1:-1] or values cells = [] - for agg in aggs: + for agg in aggs_per_kind[kind]: if not values: cells.append("n/a") continue @@ -261,6 +279,8 @@ def _label(kind: str, agg: str) -> str: value = values[0] elif agg == "last": value = values[-1] + elif agg == "logits": + value = logits_value elif agg == "max": value = max(intermediate) else: From 4633bfde1d45ff78b8b3674d9678cac53d0dc0e4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:29:42 -0400 Subject: [PATCH 19/31] Capture logit gradients; expose them in the summary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The head's `logits` tensor has `requires_grad=False` (output of a custom-autograd Function), so the existing `_debug(logits, ...)` could only capture the forward value. Add a second `_debug(grad, "logits.grad", ...)` call right after the loss returns the explicit `dL/d_logits` so the gradient is captured at the same fidelity. With the precision tool's `output_hidden_states` pattern `r"head\.logits"`, both `head.logits` and `head.logits.grad` end up in tensor_logs. Tool summary surfaces both via dedicated `logits` columns — placed at end-of-fw and start-of-bw chronologically. For GSPO the bw-logits column reveals that the dL/dlogits computation itself is extremely precise (~0.001% relative error), and the apparent backward noise actually enters through the head matmul further downstream. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/head.py | 12 +++++++++ tools/evaluate_precision.py | 35 +++++++++++++++----------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index eb67cd553..8dd511480 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -293,6 +293,18 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) + if grad is not None: + # `logits` has `requires_grad=False` (custom-autograd), so the existing + # `_debug(logits, ...)` can't auto-capture the gradient. Log it explicitly here + # so `output_hidden_states` patterns covering `head.logits` also catch the grad. + self._debug( + grad, + f"logits.grad{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._vocab_dim), + kwargs, + scale=self._config.logits_scale_factor, + ) + if not self.training or grad is None: return sum(losses_) if losses_ else None, None diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index afcb818c0..61d612895 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -216,11 +216,8 @@ def _layer_name(tensor_name: str) -> str: return " ".join(prefix) if prefix else "?" -def _logits_row(rows: list[dict[str, typing.Any]]) -> dict[str, typing.Any] | None: - return next( - (r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == "head.logits"), - None, - ) +def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing.Any] | None: + return next((r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == name), None) def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: @@ -238,13 +235,14 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} - # Logits show up on the fw side via `output_hidden_states` ("Global : head.logits"); - # add a dedicated column for it (chronologically just before the head output / loss). - has_logits = _logits_row(sample) is not None - aggs_per_kind = { - "fw": ("first", "median", "max", "logits", "last") if has_logits else ("first", "median", "max", "last"), - "bw": ("first", "median", "max", "last"), - } + # Logits show up via `output_hidden_states` (`Global : head.logits` on the fw side and + # `Global : head.logits.grad` on the bw side once the loss has computed dL/dlogits). + # Each gets a dedicated column placed chronologically: end-of-fw and start-of-bw. + has_fw_logits = _named_row(sample, "head.logits") is not None + has_bw_logits = _named_row(sample, "head.logits.grad") is not None + fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) + bw_aggs = (("logits",) if has_bw_logits else ()) + ("first", "median", "max", "last") + aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs} kinds = ("fw", "bw") def _label(kind: str, agg: str) -> str: @@ -265,7 +263,12 @@ def _label(kind: str, agg: str) -> str: print(bottom) print("-" * len(bottom)) for name, rows in results.items(): - logits_value = _logits_row(rows)["rms_rel"] if _logits_row(rows) else float("nan") + logits_fw = _named_row(rows, "head.logits") + logits_bw = _named_row(rows, "head.logits.grad") + logits_value = { + "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), + "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), + } groups = [] for kind in kinds: values = [r["rms_rel"] for r in rows if r["kind"] == kind] @@ -280,7 +283,7 @@ def _label(kind: str, agg: str) -> str: elif agg == "last": value = values[-1] elif agg == "logits": - value = logits_value + value = logits_value[kind] elif agg == "max": value = max(intermediate) else: @@ -294,7 +297,9 @@ def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name # with "Global " and append a ": " suffix when reconstructing a - # tensor-parallel-global tensor. + # tensor-parallel-global tensor. Other entries (e.g. `Global : head.logits`, + # `Global : head.logits.grad`) come from the `_debug` / `output_hidden_states` path + # and are surfaced via dedicated logits columns in the summary. for kind in ("fw", "bw"): if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): return kind From 9ca17115835b53b26a79da21a7cb4b9f188fb6ba Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:35:48 -0400 Subject: [PATCH 20/31] Place logits after head in bw summary; widen format for sub-percent values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `.3f%` was rounding the bw-logits values down to 0.001%-0.000%, hiding real signal. Switch to `.4g%` so values across 5 orders of magnitude (0.0001% to ~20%) all render with meaningful precision; large values keep 4 significant figures, tiny ones spell out their leading non-zero digits or fall back to scientific. Bw column order is now first / logits / mid med / mid max / last so `logits` sits right after `head` (the first bw row) — semantically the gradient at logits is what the head's backward consumes before producing the gradient at its input. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 61d612895..5f23d206e 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -241,7 +241,7 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: has_fw_logits = _named_row(sample, "head.logits") is not None has_bw_logits = _named_row(sample, "head.logits.grad") is not None fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) - bw_aggs = (("logits",) if has_bw_logits else ()) + ("first", "median", "max", "last") + bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs} kinds = ("fw", "bw") @@ -288,7 +288,7 @@ def _label(kind: str, agg: str) -> str: value = max(intermediate) else: value = statistics.median(intermediate) - cells.append(f"{value * 100:.3f}%") + cells.append(f"{value * 100:.4g}%") groups.append(" ".join(f"{c:<{cell_width}}" for c in cells)) print(f"{name:<{name_width}}" + group_sep.join(groups)) From f2655f39223ea13555bf9b7aec4efd96dbcc5eac Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:40:11 -0400 Subject: [PATCH 21/31] =?UTF-8?q?Pick=20per-column=20decimals=20to=20guara?= =?UTF-8?q?ntee=20=E2=89=A52=20sig=20figs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Keep the prior `.3f%` default in the summary so most columns still show `0.000%` / `12.672%` style values, but compute a per-column decimal count based on the smallest non-zero value in that column — bumping up just enough that every cell carries at least two significant figures. Decimal count is uniform within a column. For the GSPO run, only the bw-logits column hits the threshold and gets bumped from 3 to 5 decimals, surfacing values like `0.00095%` that previously rounded to `0.001%` or worse. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 49 +++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 5f23d206e..6f6d86657 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -1,5 +1,6 @@ import json import logging +import math import pathlib import shutil import statistics @@ -262,6 +263,9 @@ def _label(kind: str, agg: str) -> str: print(top) print(bottom) print("-" * len(bottom)) + # Collect raw values first so we can pick a per-column decimal count: keep the previous + # .3f% default, but bump up just enough to give every cell in a column ≥ 2 sig figs. + raw: dict[str, dict[tuple[str, str], float | None]] = {} for name, rows in results.items(): logits_fw = _named_row(rows, "head.logits") logits_bw = _named_row(rows, "head.logits.grad") @@ -269,30 +273,55 @@ def _label(kind: str, agg: str) -> str: "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), } - groups = [] + cells: dict[tuple[str, str], float | None] = {} for kind in kinds: values = [r["rms_rel"] for r in rows if r["kind"] == kind] intermediate = values[1:-1] or values - cells = [] for agg in aggs_per_kind[kind]: if not values: - cells.append("n/a") + cells[(kind, agg)] = None continue if agg == "first": - value = values[0] + cells[(kind, agg)] = values[0] elif agg == "last": - value = values[-1] + cells[(kind, agg)] = values[-1] elif agg == "logits": - value = logits_value[kind] + cells[(kind, agg)] = logits_value[kind] elif agg == "max": - value = max(intermediate) + cells[(kind, agg)] = max(intermediate) else: - value = statistics.median(intermediate) - cells.append(f"{value * 100:.4g}%") - groups.append(" ".join(f"{c:<{cell_width}}" for c in cells)) + cells[(kind, agg)] = statistics.median(intermediate) + raw[name] = cells + + column_decimals: dict[tuple[str, str], int] = {} + for kind in kinds: + for agg in aggs_per_kind[kind]: + column_decimals[(kind, agg)] = _column_decimals( + cells[(kind, agg)] for cells in raw.values() if cells[(kind, agg)] is not None + ) + for name, cells in raw.items(): + groups = [] + for kind in kinds: + formatted = [] + for agg in aggs_per_kind[kind]: + value = cells[(kind, agg)] + if value is None: + formatted.append("n/a") + else: + formatted.append(f"{value * 100:.{column_decimals[(kind, agg)]}f}%") + groups.append(" ".join(f"{c:<{cell_width}}" for c in formatted)) print(f"{name:<{name_width}}" + group_sep.join(groups)) +def _column_decimals(values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3) -> int: + # Keep the previous default precision, but bump up so the smallest non-zero value + # carries at least `min_sig_figs` significant digits when formatted as percent. + smallest = min((abs(v) * 100 for v in values if v != 0), default=None) + if smallest is None or smallest >= 10 ** -(default - min_sig_figs + 1): + return default + return max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + + def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name From 7f8ef96cebce3b0e4ad7dfa76eb224ad8d1ebfae Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 13:48:10 -0400 Subject: [PATCH 22/31] Tighten summary table spacing Cell width drops from `max_label + 1` to `max_label`, inter-cell sep from two spaces to one, group sep from four spaces to three. About 18 chars narrower on the GSPO smoke run with no loss of alignment or readability. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 6f6d86657..ab35e3e2f 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -249,16 +249,17 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: def _label(kind: str, agg: str) -> str: return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] - name_width = max((len(name) for name in results), default=7) + 2 - cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + 1 - group_sep = " " + name_width = max((len(name) for name in results), default=7) + 1 + cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + cell_sep = " " + group_sep = " " group_widths = { - kind: len(" ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds + kind: len(cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds } print("\n=== Summary (Relative %; mid = excluding first/last) ===") top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) bottom = f"{'Variant':<{name_width}}" + group_sep.join( - " ".join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds + cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds ) print(top) print(bottom) @@ -309,7 +310,7 @@ def _label(kind: str, agg: str) -> str: formatted.append("n/a") else: formatted.append(f"{value * 100:.{column_decimals[(kind, agg)]}f}%") - groups.append(" ".join(f"{c:<{cell_width}}" for c in formatted)) + groups.append(cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) print(f"{name:<{name_width}}" + group_sep.join(groups)) From 08b163745529db50fbeaa81c0aa51ad1162c082f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:07:18 -0400 Subject: [PATCH 23/31] Support HF Hub model ids in pretrained.path Lets `pretrained.path: org/model-id` resolve via huggingface_hub.snapshot_download when not a local directory, matching transformers' from_pretrained behavior. Local paths pass through unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/checkpoint/huggingface.py | 39 +++++++++++++++-------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index a4810dc1a..4c99798c5 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -100,6 +100,18 @@ def _get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name + @classmethod + def _resolve_path(cls, path: pathlib.Path) -> pathlib.Path: + """Resolve a local directory or HF Hub model id (e.g. ``meta-llama/Llama-3.2-1B``) to a + local snapshot directory. Local directories pass through unchanged; everything else is + materialized via :func:`huggingface_hub.snapshot_download` (cached on subsequent calls). + """ + if path.is_dir(): + return path + import huggingface_hub + + return pathlib.Path(huggingface_hub.snapshot_download(str(path))) + # Use custom config instead of relying on the transformers library @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: @@ -222,28 +234,29 @@ def _load_weights( import transformers Assert.eq(self.get_shard_names(config), ("weights",)) - if (config.path / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.SAFE_WEIGHTS_NAME} - elif (config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") + directory = self._resolve_path(config.path) + if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME} + elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } - elif (config.path / transformers.utils.WEIGHTS_NAME).is_file(): - paths = {config.path / transformers.utils.WEIGHTS_NAME} - elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): - logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") + elif (directory / transformers.utils.WEIGHTS_NAME).is_file(): + paths = {directory / transformers.utils.WEIGHTS_NAME} + elif (directory / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): + logger.info(f"Loading index from {directory / transformers.utils.WEIGHTS_INDEX_NAME}") paths = { - config.path / path - for path in json.loads((config.path / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ + directory / path + for path in json.loads((directory / transformers.utils.WEIGHTS_INDEX_NAME).read_text())[ "weight_map" ].values() } else: - raise FileNotFoundError(f"No compatible checkpoint found in {config.path}") + raise FileNotFoundError(f"No compatible checkpoint found in {directory}") for path in paths: logger.info(f"Loading from {path}") From 77eae22226dcb46d258d22822754979ca6d977f4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:07:25 -0400 Subject: [PATCH 24/31] Add example precision-evaluation configs Two ready-to-run configs for tools/evaluate_precision: smol.yaml sweeps precision-stability features (full_precision_gradients, full_precision_residual, fp32_lm_head) on SmolLM2-135M; smol_gspo.yaml repeats the sweep with the GSPO policy-gradient loss enabled. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 34 ++++++++++++++++++++ examples/evaluate_precision/smol_gspo.yaml | 37 ++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 examples/evaluate_precision/smol.yaml create mode 100644 examples/evaluate_precision/smol_gspo.yaml diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml new file mode 100644 index 000000000..2d443d3ba --- /dev/null +++ b/examples/evaluate_precision/smol.yaml @@ -0,0 +1,34 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M. +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). +pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +output_dir: /tmp/fast_llm_tests/evaluate_precision/features +sequence_length: 128 +num_samples: 512 +variants: + # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). + bf16: + model.distributed.compute_dtype: bfloat16 + # Turn OFF the default fp32 gradient accumulation — gradients accumulate in bf16. + bf16_no_fp32_gradients: + model.distributed.compute_dtype: bfloat16 + model.multi_stage.full_precision_gradients: false + # Turn ON full-precision residual stream. + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + # Turn ON fp32 LM head matmul (PR #526). + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + # Both stability features on (most precise bf16-compute configuration). + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml new file mode 100644 index 000000000..9e7188529 --- /dev/null +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -0,0 +1,37 @@ +# Example precision-evaluation config: sweep precision-stability features on SmolLM2-135M +# with the GSPO policy-gradient loss (uses advantages and old log-probabilities). +# +# Run with: +# python -m tools.evaluate_precision -c examples/evaluate_precision/smol_gspo.yaml +# +# `pretrained.path` accepts either a local checkpoint directory or a HF Hub model id +# (auto-downloaded via `huggingface_hub.snapshot_download` on first use). +pretrained: + path: HuggingFaceTB/SmolLM2-135M + format: llama +model: + base_model: + head: + losses: + gspo: + type: gspo +output_dir: /tmp/fast_llm_tests/evaluate_precision/gspo +data_path: /tmp/fast_llm_tests/evaluate_precision/gspo_data +sequence_length: 128 +num_samples: 512 +variants: + bf16: + model.distributed.compute_dtype: bfloat16 + bf16_no_fp32_gradients: + model.distributed.compute_dtype: bfloat16 + model.multi_stage.full_precision_gradients: false + bf16_fp32_residual: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + bf16_fp32_lm_head: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + bf16_max_precision: + model.distributed.compute_dtype: bfloat16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true From efa95b1fd4ee609581ed3e698a68b99a8cb39a90 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:15:05 -0400 Subject: [PATCH 25/31] Drop bf16_no_fp32_gradients variant from example configs A single forward+backward pass with micro_batch_size=1 has no gradient accumulation, so toggling full_precision_gradients produces bit-identical results to the bf16 baseline. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 4 ---- examples/evaluate_precision/smol_gspo.yaml | 3 --- 2 files changed, 7 deletions(-) diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml index 2d443d3ba..8e052cbef 100644 --- a/examples/evaluate_precision/smol.yaml +++ b/examples/evaluate_precision/smol.yaml @@ -15,10 +15,6 @@ variants: # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). bf16: model.distributed.compute_dtype: bfloat16 - # Turn OFF the default fp32 gradient accumulation — gradients accumulate in bf16. - bf16_no_fp32_gradients: - model.distributed.compute_dtype: bfloat16 - model.multi_stage.full_precision_gradients: false # Turn ON full-precision residual stream. bf16_fp32_residual: model.distributed.compute_dtype: bfloat16 diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml index 9e7188529..c64276bdd 100644 --- a/examples/evaluate_precision/smol_gspo.yaml +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -22,9 +22,6 @@ num_samples: 512 variants: bf16: model.distributed.compute_dtype: bfloat16 - bf16_no_fp32_gradients: - model.distributed.compute_dtype: bfloat16 - model.multi_stage.full_precision_gradients: false bf16_fp32_residual: model.distributed.compute_dtype: bfloat16 model.base_model.embeddings.full_precision_residual: true From 46bc5b8ea74957e5e553dfd3e3397560b7138c07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:49:27 -0400 Subject: [PATCH 26/31] Add weight gradients to per-variant report tables Enables debug_all_param_gradients so every parameter's reduced gradient is captured in tensor_logs alongside the existing layer activations and input gradients. New rows are tagged with kind 'grad' and appear in the per-variant table but stay out of the fw/bw summary table. Also makes the per-variant table's Tensor column width fit the longest name (parameter gradients can be 40+ chars) and bumps the Relative column to adaptive precision (capped at 5 decimals) so legitimately tiny values stay legible. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 38 +++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index ab35e3e2f..e38ce347a 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -162,6 +162,7 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: tool_overrides = { ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, + ("model", "multi_stage", "debug_all_param_gradients"): _LOG_LEVEL, # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], @@ -314,22 +315,32 @@ def _label(kind: str, agg: str) -> str: print(f"{name:<{name_width}}" + group_sep.join(groups)) -def _column_decimals(values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3) -> int: - # Keep the previous default precision, but bump up so the smallest non-zero value - # carries at least `min_sig_figs` significant digits when formatted as percent. +def _column_decimals( + values: typing.Iterable[float], min_sig_figs: int = 2, default: int = 3, max_decimals: int | None = None +) -> int: + # Keep the default precision, but bump up so the smallest non-zero value carries at least + # `min_sig_figs` significant digits when formatted as percent. `max_decimals` caps the + # bump so a single tiny noisy value doesn't widen the whole column. smallest = min((abs(v) * 100 for v in values if v != 0), default=None) if smallest is None or smallest >= 10 ** -(default - min_sig_figs + 1): - return default - return max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + result = default + else: + result = max(default, -math.floor(math.log10(smallest)) + min_sig_figs - 1) + return min(result, max_decimals) if max_decimals is not None else result def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name # with "Global " and append a ": " suffix when reconstructing a - # tensor-parallel-global tensor. Other entries (e.g. `Global : head.logits`, - # `Global : head.logits.grad`) come from the `_debug` / `output_hidden_states` path - # and are surfaced via dedicated logits columns in the summary. + # tensor-parallel-global tensor. Per-parameter gradient logs come from + # `Fsdp.log_shard(name="gradient", ...)` and are tagged "grad" so they appear + # in the per-variant table but stay out of the fw/bw summary aggregation. + # Other entries (e.g. `Global : head.logits`, `Global : head.logits.grad`) come + # from the `_debug` / `output_hidden_states` path and are surfaced via dedicated + # logits columns in the summary. + if "gradient:" in tensor_name: + return "grad" for kind in ("fw", "bw"): if f" {kind}:" in tensor_name or f" {kind}," in tensor_name or tensor_name.endswith(f" {kind}"): return kind @@ -341,9 +352,16 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: if not rows: print("(no matching tensors)") return + name_fn = lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})" + name_width = max(len("Tensor"), max(len(name_fn(r)) for r in rows)) + # Adaptive precision for the relative column: bump decimals so small but real values + # (typical for weight gradients) stay legible, capped at 5 to bound column width. + relative_decimals = _column_decimals((r["rms_rel"] for r in rows), default=2, max_decimals=5) + relative_fn = lambda r: f"{r['rms_rel'] * 100:.{relative_decimals}f}%" + relative_width = max(len("Relative"), max(len(relative_fn(r)) for r in rows)) columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ - ("Tensor", 26, lambda r: f"{r['tensor_name'].split(':', 1)[-1].strip()} ({r['kind']})"), - ("Relative", 8, lambda r: f"{r['rms_rel'] * 100:.2f}%"), + ("Tensor", name_width, name_fn), + ("Relative", relative_width, relative_fn), ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"), From bef2f0db6a94b31501767c97d56e9c265c40e270 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 14:56:03 -0400 Subject: [PATCH 27/31] Separate fw/bw/grad rows in per-variant tables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Group rows in the per-variant tables by display group with blank lines between fw, bw, and grad. The reduce_gradients hook emits parameter gradients chronologically interleaved with the backward pass, which made the previous table hard to scan. Display grouping is independent of `kind` so the summary aggregation is unaffected — head.logits.grad just moves to the bw block visually. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index e38ce347a..f220df2e1 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -329,6 +329,17 @@ def _column_decimals( return min(result, max_decimals) if max_decimals is not None else result +def _display_group(row: dict[str, typing.Any]) -> str: + # Map each row to one of "fw"/"bw"/"grad" for the per-variant table, independent + # of `kind`: head.logits is a forward activation, head.logits.grad is a backward + # quantity, parameter gradients are their own group. + if row["kind"] == "grad": + return "grad" + if row["kind"] == "bw" or row["tensor_name"].endswith(".grad"): + return "bw" + return "fw" + + def _classify(tensor_name: str) -> str: # Stage._log_layer_forward / _log_layer_backward produce " fw[, mb=…]" # and " bw[, mb=…]"; log_distributed_tensor may prefix the name @@ -369,8 +380,22 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: header = " ".join(f"{title:<{width}}" for title, width, _ in columns) print(header) print("-" * len(header)) + # Display grouping (fw / bw / grad) separates the chronologically-interleaved + # backward and reduce_gradients hooks. Independent of `kind` so the summary + # aggregation isn't affected. + groups = ("fw", "bw", "grad") + grouped: dict[str, list[dict[str, typing.Any]]] = {g: [] for g in groups} for row in rows: - print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) + grouped[_display_group(row)].append(row) + first = True + for group in groups: + if not grouped[group]: + continue + if not first: + print() + first = False + for row in grouped[group]: + print(" ".join(f"{format_fn(row):<{width}}" for _, width, format_fn in columns)) if __name__ == "__main__": From 4fecad4860de22b300f0591b7484aadecef6da07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 15:15:59 -0400 Subject: [PATCH 28/31] Split summary into three tables (fw, bw, grad) Each pass gets its own self-contained Variant x columns table with labels picked from the actual first/last logged tensor. Weight gradients get a head/mid med/mid max/embeddings layout mirroring the bw structure; the grad table makes large norm_1 outliers (>200% relative) immediately visible at a glance. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 126 ++++++++++++++++-------------------- 1 file changed, 54 insertions(+), 72 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index f220df2e1..236871b35 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -209,8 +209,11 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict def _layer_name(tensor_name: str) -> str: # Stage hooks name tensors `Global fw: ...` / `Global bw: ...`; - # extract the layer to use as a meaningful column label. + # Fsdp.log_shard names weight gradients `Global gradient: `. prefix = tensor_name.split(":", 1)[0].strip().split() + if prefix == ["Global", "gradient"]: + param = tensor_name.split(":", 1)[1].strip() + return param.split(".")[0] if prefix and prefix[0] == "Global": prefix = prefix[1:] if prefix and prefix[-1] in ("fw", "bw"): @@ -223,51 +226,38 @@ def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: - # Per-pass labels for `first`/`last` come from the actual layer name on the matching row. sample = next(iter(results.values())) - endpoint_labels: dict[tuple[str, str], str] = { - ("fw", "first"): "first", - ("fw", "last"): "last", - ("bw", "first"): "first", - ("bw", "last"): "last", - } - for kind in ("fw", "bw"): - group = [r for r in sample if r["kind"] == kind] - if group: - endpoint_labels[(kind, "first")] = _layer_name(group[0]["tensor_name"]) - endpoint_labels[(kind, "last")] = _layer_name(group[-1]["tensor_name"]) - mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} - # Logits show up via `output_hidden_states` (`Global : head.logits` on the fw side and - # `Global : head.logits.grad` on the bw side once the loss has computed dL/dlogits). - # Each gets a dedicated column placed chronologically: end-of-fw and start-of-bw. has_fw_logits = _named_row(sample, "head.logits") is not None has_bw_logits = _named_row(sample, "head.logits.grad") is not None + # Each kind's aggregation columns are listed chronologically (left-to-right matches + # the order tensors are logged). Logits show up via `output_hidden_states` on the + # fw/bw boundary; weight gradients have no logits hook. fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") - aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs} - kinds = ("fw", "bw") + grad_aggs = ("first", "median", "max", "last") + aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} + for kind in ("fw", "bw", "grad"): + _print_summary_table(results, kind, aggs_per_kind[kind]) + - def _label(kind: str, agg: str) -> str: - return endpoint_labels[(kind, agg)] if agg in ("first", "last") else mid_labels[agg] +def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: str, aggs: tuple[str, ...]) -> None: + sample = next(iter(results.values())) + group = [r for r in sample if r["kind"] == kind] + if not group: + return + endpoint_labels = { + "first": _layer_name(group[0]["tensor_name"]), + "last": _layer_name(group[-1]["tensor_name"]), + } + mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} + + def _label(agg: str) -> str: + return endpoint_labels[agg] if agg in endpoint_labels else mid_labels[agg] name_width = max((len(name) for name in results), default=7) + 1 - cell_width = max(len(_label(k, a)) for k in kinds for a in aggs_per_kind[k]) + cell_width = max(len(_label(a)) for a in aggs) cell_sep = " " - group_sep = " " - group_widths = { - kind: len(cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind])) for kind in kinds - } - print("\n=== Summary (Relative %; mid = excluding first/last) ===") - top = f"{'':<{name_width}}" + group_sep.join(f"{kind:^{group_widths[kind]}}" for kind in kinds) - bottom = f"{'Variant':<{name_width}}" + group_sep.join( - cell_sep.join(f"{_label(kind, a):<{cell_width}}" for a in aggs_per_kind[kind]) for kind in kinds - ) - print(top) - print(bottom) - print("-" * len(bottom)) - # Collect raw values first so we can pick a per-column decimal count: keep the previous - # .3f% default, but bump up just enough to give every cell in a column ≥ 2 sig figs. - raw: dict[str, dict[tuple[str, str], float | None]] = {} + raw: dict[str, dict[str, float | None]] = {} for name, rows in results.items(): logits_fw = _named_row(rows, "head.logits") logits_bw = _named_row(rows, "head.logits.grad") @@ -275,44 +265,36 @@ def _label(kind: str, agg: str) -> str: "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), } - cells: dict[tuple[str, str], float | None] = {} - for kind in kinds: - values = [r["rms_rel"] for r in rows if r["kind"] == kind] - intermediate = values[1:-1] or values - for agg in aggs_per_kind[kind]: - if not values: - cells[(kind, agg)] = None - continue - if agg == "first": - cells[(kind, agg)] = values[0] - elif agg == "last": - cells[(kind, agg)] = values[-1] - elif agg == "logits": - cells[(kind, agg)] = logits_value[kind] - elif agg == "max": - cells[(kind, agg)] = max(intermediate) - else: - cells[(kind, agg)] = statistics.median(intermediate) + values = [r["rms_rel"] for r in rows if r["kind"] == kind] + intermediate = values[1:-1] or values + cells: dict[str, float | None] = {} + for agg in aggs: + if not values: + cells[agg] = None + elif agg == "first": + cells[agg] = values[0] + elif agg == "last": + cells[agg] = values[-1] + elif agg == "logits": + cells[agg] = logits_value[kind] + elif agg == "max": + cells[agg] = max(intermediate) + else: + cells[agg] = statistics.median(intermediate) raw[name] = cells - column_decimals: dict[tuple[str, str], int] = {} - for kind in kinds: - for agg in aggs_per_kind[kind]: - column_decimals[(kind, agg)] = _column_decimals( - cells[(kind, agg)] for cells in raw.values() if cells[(kind, agg)] is not None - ) + column_decimals = { + agg: _column_decimals(cells[agg] for cells in raw.values() if cells[agg] is not None) for agg in aggs + } + print(f"\n=== Summary: {kind} (Relative %; mid = excluding first/last) ===") + header = f"{'Variant':<{name_width}}" + cell_sep.join(f"{_label(a):<{cell_width}}" for a in aggs) + print(header) + print("-" * len(header)) for name, cells in raw.items(): - groups = [] - for kind in kinds: - formatted = [] - for agg in aggs_per_kind[kind]: - value = cells[(kind, agg)] - if value is None: - formatted.append("n/a") - else: - formatted.append(f"{value * 100:.{column_decimals[(kind, agg)]}f}%") - groups.append(cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) - print(f"{name:<{name_width}}" + group_sep.join(groups)) + formatted = [ + f"{cells[agg] * 100:.{column_decimals[agg]}f}%" if cells[agg] is not None else "n/a" for agg in aggs + ] + print(f"{name:<{name_width}}" + cell_sep.join(f"{c:<{cell_width}}" for c in formatted)) def _column_decimals( From 4f47dc045bab9def7f4e6a0286ea253c5c8262f0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 15:26:01 -0400 Subject: [PATCH 29/31] Split grad summary by parameter category Replace the chronological first/last columns in the grad table with named lookups (lm_head / embeddings) and split the intermediate aggregation by category: linear weights, norm weights, biases. The bias columns appear only when biases exist. lm_head shows n/a when the LM head weight is tied to the embedding (e.g. SmolLM2), since the combined gradient is recorded under the embedding parameter. Co-Authored-By: Claude Opus 4.7 (1M context) --- tools/evaluate_precision.py | 80 ++++++++++++++++++++++++++++++++----- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 236871b35..abbc1e2da 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -225,21 +225,41 @@ def _named_row(rows: list[dict[str, typing.Any]], name: str) -> dict[str, typing return next((r for r in rows if r["tensor_name"].split(":", 1)[-1].strip() == name), None) +_LM_HEAD_NAME = "head.output_weights" +_EMBEDDINGS_NAME = "embeddings.word_embeddings_weight" + + def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: sample = next(iter(results.values())) has_fw_logits = _named_row(sample, "head.logits") is not None has_bw_logits = _named_row(sample, "head.logits.grad") is not None + has_bias = any( + r["kind"] == "grad" and r["tensor_name"].split(":", 1)[-1].strip().endswith(".bias") for r in sample + ) # Each kind's aggregation columns are listed chronologically (left-to-right matches # the order tensors are logged). Logits show up via `output_hidden_states` on the # fw/bw boundary; weight gradients have no logits hook. fw_aggs = ("first", "median", "max") + (("logits",) if has_fw_logits else ()) + ("last",) bw_aggs = ("first",) + (("logits",) if has_bw_logits else ()) + ("median", "max", "last") - grad_aggs = ("first", "median", "max", "last") + grad_aggs = ( + ("lm_head", "linear_med", "linear_max", "norm_med", "norm_max") + + (("bias_med", "bias_max") if has_bias else ()) + + ("embeddings",) + ) aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} for kind in ("fw", "bw", "grad"): _print_summary_table(results, kind, aggs_per_kind[kind]) +def _grad_category(tensor_name: str) -> str: + name = tensor_name.split(":", 1)[-1].strip() + if name.endswith(".bias"): + return "bias" + if ".norm_" in name or name.endswith(".norm.weight"): + return "norm" + return "linear" + + def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: str, aggs: tuple[str, ...]) -> None: sample = next(iter(results.values())) group = [r for r in sample if r["kind"] == kind] @@ -249,7 +269,19 @@ def _print_summary_table(results: dict[str, list[dict[str, typing.Any]]], kind: "first": _layer_name(group[0]["tensor_name"]), "last": _layer_name(group[-1]["tensor_name"]), } - mid_labels = {"median": "mid med", "max": "mid max", "logits": "logits"} + mid_labels = { + "median": "mid med", + "max": "mid max", + "logits": "logits", + "lm_head": "lm head", + "embeddings": "embeddings", + "linear_med": "linear med", + "linear_max": "linear max", + "norm_med": "norm med", + "norm_max": "norm max", + "bias_med": "bias med", + "bias_max": "bias max", + } def _label(agg: str) -> str: return endpoint_labels[agg] if agg in endpoint_labels else mid_labels[agg] @@ -265,28 +297,54 @@ def _label(agg: str) -> str: "fw": logits_fw["rms_rel"] if logits_fw else float("nan"), "bw": logits_bw["rms_rel"] if logits_bw else float("nan"), } - values = [r["rms_rel"] for r in rows if r["kind"] == kind] + kind_rows = [r for r in rows if r["kind"] == kind] + values = [r["rms_rel"] for r in kind_rows] + if kind == "grad": + decoder_rows = [r for r in kind_rows if r["tensor_name"].split(":", 1)[-1].strip().startswith("decoder.")] + category_values: dict[str, list[float]] = {"linear": [], "norm": [], "bias": []} + for r in decoder_rows: + category_values[_grad_category(r["tensor_name"])].append(r["rms_rel"]) + lm_head_row = _named_row(kind_rows, _LM_HEAD_NAME) + embeddings_row = _named_row(kind_rows, _EMBEDDINGS_NAME) + else: + category_values = {} + lm_head_row = embeddings_row = None intermediate = values[1:-1] or values cells: dict[str, float | None] = {} for agg in aggs: - if not values: - cells[agg] = None - elif agg == "first": - cells[agg] = values[0] + if agg == "first": + cells[agg] = values[0] if values else None elif agg == "last": - cells[agg] = values[-1] + cells[agg] = values[-1] if values else None elif agg == "logits": cells[agg] = logits_value[kind] + elif agg == "lm_head": + cells[agg] = lm_head_row["rms_rel"] if lm_head_row else None + elif agg == "embeddings": + cells[agg] = embeddings_row["rms_rel"] if embeddings_row else None + elif "_" in agg and agg.split("_", 1)[0] in category_values: + cat, stat = agg.split("_", 1) + cat_values = category_values[cat] + if not cat_values: + cells[agg] = None + elif stat == "max": + cells[agg] = max(cat_values) + else: + cells[agg] = statistics.median(cat_values) elif agg == "max": - cells[agg] = max(intermediate) + cells[agg] = max(intermediate) if intermediate else None else: - cells[agg] = statistics.median(intermediate) + cells[agg] = statistics.median(intermediate) if intermediate else None raw[name] = cells column_decimals = { agg: _column_decimals(cells[agg] for cells in raw.values() if cells[agg] is not None) for agg in aggs } - print(f"\n=== Summary: {kind} (Relative %; mid = excluding first/last) ===") + if kind == "grad": + subtitle = " (Relative %)" + else: + subtitle = " (Relative %; mid = excluding first/last)" + print(f"\n=== Summary: {kind}{subtitle} ===") header = f"{'Variant':<{name_width}}" + cell_sep.join(f"{_label(a):<{cell_width}}" for a in aggs) print(header) print("-" * len(header)) From 5198c2551ebef9e6f53d273243037daf13c2763c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 28 May 2026 15:50:18 -0400 Subject: [PATCH 30/31] Per-tensor sample-density overrides in TensorLogsConfig Add `sample_level_overrides: dict[str, int]` (regex pattern -> level) to `TensorLogsConfig`. `log_tensor` raises the effective level for any tensor whose logged name matches a pattern, so callers can collect more samples for specific tensors without changing the default. Useful for sparsely-non-zero tensors like embedding-weight gradients, where the default uniform stride misses every non-zero row. evaluate_precision: switch `num_samples` to actually drive the level (was only cropping the text log), bump default to 8192, default sequence length to 2048 in the example yamls, and add a 1M-sample override for `Global gradient: embeddings.*` to make embedding-grad errors measurable on small batches. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 3 +-- examples/evaluate_precision/smol_gspo.yaml | 3 +-- fast_llm/engine/config_utils/logging.py | 9 +++++++ fast_llm/logging.py | 9 +++++++ tools/evaluate_precision.py | 29 +++++++++++++++------- 5 files changed, 40 insertions(+), 13 deletions(-) diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml index 8e052cbef..cf0f8554b 100644 --- a/examples/evaluate_precision/smol.yaml +++ b/examples/evaluate_precision/smol.yaml @@ -9,8 +9,7 @@ pretrained: path: HuggingFaceTB/SmolLM2-135M format: llama output_dir: /tmp/fast_llm_tests/evaluate_precision/features -sequence_length: 128 -num_samples: 512 +sequence_length: 2048 variants: # Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 gradient accumulation, bf16 residual, bf16 lm_head). bf16: diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml index c64276bdd..5e3545573 100644 --- a/examples/evaluate_precision/smol_gspo.yaml +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -17,8 +17,7 @@ model: type: gspo output_dir: /tmp/fast_llm_tests/evaluate_precision/gspo data_path: /tmp/fast_llm_tests/evaluate_precision/gspo_data -sequence_length: 128 -num_samples: 512 +sequence_length: 2048 variants: bf16: model.distributed.compute_dtype: bfloat16 diff --git a/fast_llm/engine/config_utils/logging.py b/fast_llm/engine/config_utils/logging.py index 32deb4562..b82d4c847 100644 --- a/fast_llm/engine/config_utils/logging.py +++ b/fast_llm/engine/config_utils/logging.py @@ -76,6 +76,15 @@ class TensorLogsConfig(Config): valid=check_field(Assert.gt, 0), ) full_tensors: bool = Field(default=False, desc="Save and/or print entire tensors.") + sample_level_overrides: dict[str, int] = Field( + default_factory=dict, + desc="Per-tensor sample-density overrides (regex pattern -> level)." + " For tensors whose logged name matches a pattern, the effective `log_tensor` level is" + " raised to the matching override (samples = 2 ** (level - 3))." + " Useful for sparse tensors like embedding-weight gradients where the default sampling" + " stride misses most non-zero rows.", + hint=FieldHint.logging, + ) class TensorLogs: diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 2619883d6..6326e7e4b 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -131,6 +131,15 @@ def log_tensor[T]( ) -> T | None: if level < 1: return + # Per-tensor sample-density override: lets users boost the effective level for specific + # tensors (e.g. sparse embedding-weight gradients) via `TensorLogsConfig`. + overrides = TensorLogs.config.sample_level_overrides if TensorLogs.config else None + if overrides: + import re + + for pattern, override in overrides.items(): + if re.search(pattern, name): + level = max(level, override) tensor = tensor.detach() if tensor.ndim == 0: tensor = tensor[None] diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index abbc1e2da..02131cd6a 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -20,11 +20,14 @@ logger = logging.getLogger(__name__) -# Tensor-log verbosity level. 13 gives 2**(13-3)=1024 sampled values per tensor, -# matching the convention in the existing layer-comparison tests. -_LOG_LEVEL = 13 _REFERENCE_NAME = "reference" _MODEL_TYPE = "gpt" +# Embedding-weight gradients are row-sparse (only input-token rows non-zero), so a +# uniformly-spaced sample of vocab_size entries usually misses all of them. The pattern +# is applied via `TensorLogsConfig.sample_level_overrides` and picked up inside +# `log_tensor` (samples = 2 ** (level - 3) -> level 23 yields ~1M samples per tensor). +_SPARSE_GRAD_LEVEL = 23 +_SPARSE_GRAD_OVERRIDES = {r"Global gradient: embeddings\.": _SPARSE_GRAD_LEVEL} @config_class() @@ -48,8 +51,10 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): hint=FieldHint.core, ) num_samples: int = Field( - default=1024, - desc="Number of sampled values stored per logged tensor.", + default=8192, + desc="Number of sampled values stored per logged tensor (rounded up to next power of 2)." + " Sparse tensors (e.g. embedding-weight gradients) get a higher level via" + " `TensorLogsConfig.sample_level_overrides`.", hint=FieldHint.feature, ) micro_batch_size: int = Field( @@ -150,9 +155,15 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: }, "run": { "experiment_dir": str((self.output_dir / name).resolve()), - "tensor_logs": {"save": True, "show": False, "max_elements": self.num_samples}, + "tensor_logs": { + "save": True, + "show": False, + "sample_level_overrides": _SPARSE_GRAD_OVERRIDES, + }, }, } + # Translate `num_samples` to a `log_tensor` level: 2**(level-3) = samples. + log_level = math.ceil(math.log2(max(self.num_samples, 1))) + 3 fp32_dtypes = { ("model", "distributed", "compute_dtype"): "float32", ("model", "distributed", "optimization_dtype"): "float32", @@ -160,9 +171,9 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: variant_updates = {tuple(key.split(".")): value for key, value in variant_overrides.items()} # Tool-required overrides win over variants — a variant must not silently disable tensor logging. tool_overrides = { - ("model", "multi_stage", "debug_layer_outputs"): _LOG_LEVEL, - ("model", "multi_stage", "debug_layer_gradients"): _LOG_LEVEL, - ("model", "multi_stage", "debug_all_param_gradients"): _LOG_LEVEL, + ("model", "multi_stage", "debug_layer_outputs"): log_level, + ("model", "multi_stage", "debug_layer_gradients"): log_level, + ("model", "multi_stage", "debug_all_param_gradients"): log_level, # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], From 312343e7cdca3b56699f4494c25b0908bfe7d447 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 15:06:16 -0400 Subject: [PATCH 31/31] Chosen-logprob loss, per-variant grad-scale auto-calibration, fp16 variants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - New `chosen_logprob` LM loss: logs `log_softmax(logits)[label]` per position with no gradient contribution. Tool auto-adds it and surfaces a dedicated summary with bias, correlation, slope, and residual-after-linear-fit. - `_compute_diff` reports bias_abs/rel, correlation, slope, residual_rms_abs/rel — the linear decomposition separates systematic shift/scale from per-position noise. - Per-variant auto-calibrated power-of-2 gradient scale: each variant runs a calibration pass at scale=1 to measure max unscaled gradient, then the real run picks the largest power-of-2 scale that fits within fp16 range (with a small safety factor for fused-kernel partial sums). `_compare` unscales per variant. - Tool: backend-override mechanism (`_torch_backend.*`) and `_torch_matmul_precision` variant keys for diagnostic variants. New variants: `bf16_in_fp32_out` (probes whether `fp32_lm_head`'s gain is from output dtype vs matmul precision), `bf16_reduced_reduction` (probes the split-K reduction path), and a full fp16 sweep mirroring the bf16 variants. - Fix: `data.micro_batch_size` in Fast-LLM is the per-sample sequence length, not the batch dim. Tool was passing 1 → every prior run was on 1-token inputs. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/evaluate_precision/smol.yaml | 30 +++ examples/evaluate_precision/smol_gspo.yaml | 19 ++ .../config_utils/compare_tensor_logs.py | 26 +- .../language_model/loss/chosen_logprob.py | 41 ++++ fast_llm/layers/language_model/loss/config.py | 25 ++ tools/evaluate_precision.py | 222 ++++++++++++++++-- 6 files changed, 342 insertions(+), 21 deletions(-) create mode 100644 fast_llm/layers/language_model/loss/chosen_logprob.py diff --git a/examples/evaluate_precision/smol.yaml b/examples/evaluate_precision/smol.yaml index cf0f8554b..cc17c19e0 100644 --- a/examples/evaluate_precision/smol.yaml +++ b/examples/evaluate_precision/smol.yaml @@ -27,3 +27,33 @@ variants: model.distributed.compute_dtype: bfloat16 model.base_model.embeddings.full_precision_residual: true model.base_model.head.fp32_lm_head: true + # Diagnostic: enable bf16 reduced-precision reductions in cuBLAS GEMMs. Tests whether the + # within-engine bf16-vs-fp32 gap is sensitive to the partial-sum reduction precision (the + # MMA accumulator is fp32 by hardware on H100/A100; this flag affects split-K reductions). + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + # Diagnostic: simulate a "bf16 inputs, fp32 output" lm-head matmul kernel. fp32_lm_head=True + # upcasts inputs+weights to fp32, then matmul_precision='medium' runs the matmul through + # bf16 Tensor Cores anyway, then logits stay fp32. Tests whether fp32_lm_head's gain comes + # from input precision or from skipping the bf16 output cast. + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + # fp16 sweep: probes whether the precision-vs-noise picture (rms noise ~0.1 nats per token + # for bf16) shrinks ~8× for fp16 (10 mantissa bits vs 7), as the literature's "switch to + # fp16" recommendation implies. Default dynamic grad-scaler (initial 2^16) is uniform + # across variants, so relative comparisons stay meaningful. + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/examples/evaluate_precision/smol_gspo.yaml b/examples/evaluate_precision/smol_gspo.yaml index 5e3545573..b0e8e319d 100644 --- a/examples/evaluate_precision/smol_gspo.yaml +++ b/examples/evaluate_precision/smol_gspo.yaml @@ -31,3 +31,22 @@ variants: model.distributed.compute_dtype: bfloat16 model.base_model.embeddings.full_precision_residual: true model.base_model.head.fp32_lm_head: true + bf16_reduced_reduction: + model.distributed.compute_dtype: bfloat16 + _torch_backend.cuda.matmul.allow_bf16_reduced_precision_reduction: true + bf16_in_fp32_out: + model.distributed.compute_dtype: bfloat16 + model.base_model.head.fp32_lm_head: true + _torch_matmul_precision: medium + fp16: + model.distributed.compute_dtype: float16 + fp16_fp32_residual: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + fp16_fp32_lm_head: + model.distributed.compute_dtype: float16 + model.base_model.head.fp32_lm_head: true + fp16_max_precision: + model.distributed.compute_dtype: float16 + model.base_model.embeddings.full_precision_residual: true + model.base_model.head.fp32_lm_head: true diff --git a/fast_llm/engine/config_utils/compare_tensor_logs.py b/fast_llm/engine/config_utils/compare_tensor_logs.py index 080510036..dbad78a25 100644 --- a/fast_llm/engine/config_utils/compare_tensor_logs.py +++ b/fast_llm/engine/config_utils/compare_tensor_logs.py @@ -100,8 +100,24 @@ def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict samples_test = samples_test / sub_config.scale scale_unreg = (samples_ref**2).mean() ** 0.5 rms_scale = (scale_unreg**2 + sub_config.rms_eps**2) ** 0.5 - rms = ((samples_ref - samples_test) ** 2).mean() ** 0.5 - max_diff = (samples_ref - samples_test).abs().max() + diff = samples_test - samples_ref + rms = (diff**2).mean() ** 0.5 + max_diff = diff.abs().max() + bias = diff.mean() + # Linear-regression decomposition: `test ≈ slope * ref + intercept + residual`. + # Useful for separating systematic distortion (slope ≠ 1) from per-position decorrelated + # noise (residual). For RL importance ratios, slope ≠ 1 indicates likely-token-dependent + # bias which is more dangerous than a uniform shift. + centered_test = samples_test - samples_test.mean() + centered_ref = samples_ref - samples_ref.mean() + var_ref = (centered_ref**2).mean() + var_test = (centered_test**2).mean() + cov = (centered_test * centered_ref).mean() + denom = (var_test * var_ref) ** 0.5 + correlation = (cov / denom).item() if denom > 0 else float("nan") + slope = (cov / var_ref).item() if var_ref > 0 else float("nan") + residual_var = (var_test - cov**2 / var_ref).clamp(min=0.0) if var_ref > 0 else var_test + residual_rms = residual_var**0.5 return { "rms_abs": rms.item(), "rms_rel": (rms / rms_scale).item(), @@ -109,6 +125,12 @@ def _compute_diff(self, tensor_ref, tensor_test, step_name, tensor_name) -> dict "max_rel": (max_diff / rms_scale).item(), "ref_scale": scale_unreg.item(), "ref_scale_regularized": rms_scale.item(), + "bias_abs": bias.item(), + "bias_rel": (bias / rms_scale).item(), + "correlation": correlation, + "slope": slope, + "residual_rms_abs": residual_rms.item(), + "residual_rms_rel": (residual_rms / rms_scale).item(), } def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_name): diff --git a/fast_llm/layers/language_model/loss/chosen_logprob.py b/fast_llm/layers/language_model/loss/chosen_logprob.py new file mode 100644 index 000000000..cb99e7c17 --- /dev/null +++ b/fast_llm/layers/language_model/loss/chosen_logprob.py @@ -0,0 +1,41 @@ +import math +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelChosenLogprobLossConfig +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.logging import log_tensor + + +class LanguageModelChosenLogprobLoss[ConfigType: LanguageModelChosenLogprobLossConfig](LanguageModelLoss[ConfigType]): + """Logs log π(label) per position via the tensor-log pipeline; contributes nothing to gradients.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Don't surface a "chosen_logprob: 0" line in the training metrics. + self._do_register_loss = False + + def _forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + if self._vocab_parallel: + raise NotImplementedError("chosen_logprob loss does not support vocab parallel") + labels = self._get_labels(kwargs, split_index).reshape(-1).long() + with torch.no_grad(): + log_probs = torch.log_softmax(logits.float() * self._logits_scale_factor, dim=-1) + # Mask out-of-range labels (e.g. -100 for prompt tokens in RL data) before gather to + # avoid CUDA assert. Fast-LLM convention: any label < 0 is masked. + valid = labels >= 0 + safe_labels = labels.clamp(min=0) + chosen_logprob = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1) + chosen_logprob = chosen_logprob[valid] + # Capture the full tensor: bias is the mean over all positions, not a sampled subset. + level = math.ceil(math.log2(max(chosen_logprob.numel(), 1))) + 3 + log_tensor(f"Global : {self._name}", chosen_logprob, level=level) + return torch.zeros((), dtype=logits.dtype, device=logits.device), grad_logits diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 9a220aacf..aa05fbb9a 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -9,6 +9,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss from fast_llm.layers.language_model.loss.entropy_loss import ( LanguageModelDistillationLoss, @@ -186,6 +187,30 @@ def get_reference_models(self) -> set[str]: return {self.reference_model} +@config_class(dynamic_type={LanguageModelLossConfig: "chosen_logprob"}) +class LanguageModelChosenLogprobLossConfig(LanguageModelLossConfig): + """No-gradient diagnostic loss that logs log π(label) per position via the tensor-log pipeline. + + The chosen-token log-prob is the scalar that policy-gradient importance ratios depend on, + so its precision drift is a more direct signal than bulk-logit RMS. + """ + + _abstract: typing.ClassVar[bool] = False + + weight: float = Field( + default=0.0, + hint=FieldHint.derived, + desc="Forced to 0: this loss has no gradient contribution.", + valid=check_field(Assert.eq, 0.0), + ) + + @property + def loss_class(self) -> "type[LanguageModelChosenLogprobLoss]": + from fast_llm.layers.language_model.loss.chosen_logprob import LanguageModelChosenLogprobLoss + + return LanguageModelChosenLogprobLoss + + @config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) class LanguageModelZLossConfig(LanguageModelLossConfig): """Z-loss regularization to prevent overconfidence.""" diff --git a/tools/evaluate_precision.py b/tools/evaluate_precision.py index 02131cd6a..9da8904a1 100644 --- a/tools/evaluate_precision.py +++ b/tools/evaluate_precision.py @@ -28,6 +28,23 @@ # `log_tensor` (samples = 2 ** (level - 3) -> level 23 yields ~1M samples per tensor). _SPARSE_GRAD_LEVEL = 23 _SPARSE_GRAD_OVERRIDES = {r"Global gradient: embeddings\.": _SPARSE_GRAD_LEVEL} +_CHOSEN_LOGPROB_NAME = "chosen_logprob" +# Auto-calibration of the constant gradient scaler. Each variant runs a calibration pass at +# `scale=1` (no overflow risk), then the actual run uses the largest power-of-2 scale that +# keeps logged gradient magnitudes (and a small safety factor for hidden in-kernel +# intermediates like norm partial sums) within fp16's representable range. Per-variant +# unscaling at compare time lets different variants pick different scales without polluting +# the relative metrics. +_HIDDEN_INTERMEDIATE_HEADROOM = 4.0 # safety factor for fused-kernel partial sums we don't log +_CALIBRATION_SUBDIR_PREFIX = ".calibration_" +# Variant-override keys starting with this prefix are interpreted as `torch.backends.` and +# applied before each run. Used for diagnostics (e.g. enabling bf16 reduced-precision reductions); +# entries are listed in `_TORCH_BACKEND_DEFAULTS` and reset to their defaults before applying. +_TORCH_BACKEND_PREFIX = "_torch_backend." +_TORCH_BACKEND_DEFAULTS = { + "cuda.matmul.allow_bf16_reduced_precision_reduction": False, +} +_TORCH_MATMUL_PRECISION_KEY = "_torch_matmul_precision" @config_class() @@ -57,14 +74,10 @@ class EvaluatePrecisionConfig(PretrainedGPTModelConfig, RunnableConfig): " `TensorLogsConfig.sample_level_overrides`.", hint=FieldHint.feature, ) - micro_batch_size: int = Field( - default=1, - desc="Micro-batch size for the single forward+backward pass.", - hint=FieldHint.feature, - ) sequence_length: int = Field( default=2048, - desc="Sequence length (maximum document length) for the random input.", + desc="Sequence length per micro-batch sample. Drives both `data.micro_batch_size` (the" + " per-sample token count, despite the name) and `data.maximum_document_length`.", hint=FieldHint.feature, ) data_path: pathlib.Path | None = Field( @@ -88,20 +101,50 @@ def run(self) -> None: self._prepare_data() runs: dict[str, dict[str, typing.Any]] = {_REFERENCE_NAME: {}} runs.update(self.variants) + scales: dict[str, float] = {} for name, variant_overrides in runs.items(): - self._run_one(name, variant_overrides) + scales[name] = self._calibrate_and_run(name, variant_overrides) ref_artifacts = self._artifact_path(_REFERENCE_NAME) - results = {name: self._compare(ref_artifacts, self._artifact_path(name)) for name in self.variants} + results = { + name: self._compare(ref_artifacts, self._artifact_path(name), scales[_REFERENCE_NAME], scales[name]) + for name in self.variants + } report_path = self.output_dir / "precision_report.json" - report_path.write_text(json.dumps(results, indent=2)) + report_path.write_text(json.dumps({"scales": scales, "variants": results}, indent=2)) logger.info(f"Wrote report to {report_path}") + logger.info(f"Per-variant gradient scales: {scales}") for name, rows in results.items(): _print_table(name, rows) _print_summary(results) + def _calibrate_and_run(self, name: str, variant_overrides: dict[str, typing.Any]) -> float: + """Pick a power-of-2 gradient scale for this variant via a calibration pass, then run with it. + + Calibration runs with `constant=1.0` so no overflow is possible; scanning logged gradients + then gives us `max_unscaled`. The largest safe power of 2 keeps `scale * max_unscaled` below + `fp16_max / hidden_intermediate_budget`, where the budget reserves headroom for partial sums + inside fused kernels (e.g. norm-weight grads sum over the sequence dimension). + """ + import torch + + cal_dir = self.output_dir / f"{_CALIBRATION_SUBDIR_PREFIX}{name}" + self._run_one(name, variant_overrides, constant_scale=1.0, experiment_dir=cal_dir) + max_unscaled = _scan_max_grad(cal_dir / "runs" / "0" / "artifacts") + shutil.rmtree(cal_dir) + if max_unscaled <= 0.0: + scale = 1.0 + logger.warning(f"[{name}] calibration found no nonzero gradient — falling back to scale=1.0") + else: + fp16_max = torch.finfo(torch.float16).max + optimal_unrounded = fp16_max / max_unscaled / _HIDDEN_INTERMEDIATE_HEADROOM + scale = float(2 ** max(0, math.floor(math.log2(optimal_unrounded)))) + logger.info(f"[{name}] calibration: max_unscaled={max_unscaled:.4e} -> gradient_scaler.constant={scale:g}") + self._run_one(name, variant_overrides, constant_scale=scale) + return scale + def _prepare_data(self) -> None: if self.data_path is None: return @@ -122,15 +165,28 @@ def _prepare_data(self) -> None: def _artifact_path(self, name: str) -> pathlib.Path: return self.output_dir / name / "runs" / "0" / "artifacts" - def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: + def _run_one( + self, + name: str, + variant_overrides: dict[str, typing.Any], + *, + constant_scale: float | None = None, + experiment_dir: pathlib.Path | None = None, + ) -> None: # The trainer's Run picks the next `runs/` subdir based on what already exists; wipe # any prior contents so each invocation lands in `runs/0` and stale artifacts can't be # read by `_artifact_path` below. - experiment_dir = self.output_dir / name + if experiment_dir is None: + experiment_dir = self.output_dir / name if experiment_dir.exists(): shutil.rmtree(experiment_dir) # Base config: hardcoded training/optimizer/data/run skeleton plus the user's model/pretrained. # Forced fp32 on the reference baseline lives in here too so a variant can override it. + optimizer_config: dict[str, typing.Any] = { + "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, + } + if constant_scale is not None: + optimizer_config["gradient_scaler"] = {"constant": float(constant_scale)} base_dict: dict[str, typing.Any] = { "pretrained": self.pretrained.to_dict(), "model": self.model.to_dict(), @@ -139,9 +195,7 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: "num_workers": 0, "logs": {"interval": 1}, }, - "optimizer": { - "learning_rate": {"base": 0.0, "decay_style": "constant", "warmup_iterations": 0}, - }, + "optimizer": optimizer_config, "data": { "datasets": { "training": ( @@ -150,11 +204,13 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: else {"type": "random"} ) }, - "micro_batch_size": self.micro_batch_size, + # Despite the name, Fast-LLM's `data.micro_batch_size` is the per-sample sequence + # length, not the batch dimension. Default 2048 → 2048-token sample. + "micro_batch_size": self.sequence_length, "maximum_document_length": self.sequence_length, }, "run": { - "experiment_dir": str((self.output_dir / name).resolve()), + "experiment_dir": str(experiment_dir.resolve()), "tensor_logs": { "save": True, "show": False, @@ -168,16 +224,36 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: ("model", "distributed", "compute_dtype"): "float32", ("model", "distributed", "optimization_dtype"): "float32", } - variant_updates = {tuple(key.split(".")): value for key, value in variant_overrides.items()} + # Split off torch-backend overrides before passing the rest to Fast-LLM's config system. + backend_overrides = { + key[len(_TORCH_BACKEND_PREFIX) :]: value + for key, value in variant_overrides.items() + if key.startswith(_TORCH_BACKEND_PREFIX) + } + _apply_torch_backend_overrides(backend_overrides) + matmul_precision = variant_overrides.get(_TORCH_MATMUL_PRECISION_KEY, "highest") + _apply_torch_matmul_precision(matmul_precision) + variant_updates = { + tuple(key.split(".")): value + for key, value in variant_overrides.items() + if not key.startswith(_TORCH_BACKEND_PREFIX) and key != _TORCH_MATMUL_PRECISION_KEY + } # Tool-required overrides win over variants — a variant must not silently disable tensor logging. - tool_overrides = { + tool_overrides: dict[tuple[str, ...], typing.Any] = { ("model", "multi_stage", "debug_layer_outputs"): log_level, ("model", "multi_stage", "debug_layer_gradients"): log_level, ("model", "multi_stage", "debug_all_param_gradients"): log_level, # Capture the LM-head logits via the `output_hidden_states` mechanism: the head's # `_debug(logits, ...)` call matches this pattern and emits to `tensor_logs`. ("model", "multi_stage", "debug_hidden_states_log"): [r"head\.logits"], + # Diagnostic loss that logs log π(label) per position via the tensor-log pipeline. + # Contributes no gradient (weight=0); the comparison code picks it up by name. + ("model", "base_model", "head", "losses", _CHOSEN_LOGPROB_NAME): {"type": "chosen_logprob"}, } + # When the user hasn't configured any loss, the head defaults to cross-entropy. Adding a + # loss explicitly suppresses that default, so re-add it so gradients still flow. + if not (self.model.base_model.head.losses or {}): + tool_overrides[("model", "base_model", "head", "losses", "cross_entropy")] = {"type": "label"} logger.info(f"=== Running {name!r} ===") if variant_overrides: logger.info(f"Variant overrides: {variant_overrides}") @@ -186,13 +262,23 @@ def _run_one(self, name: str, variant_overrides: dict[str, typing.Any]) -> None: trainer_config.configure_logging() trainer_config._get_runnable()() - def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict[str, typing.Any]]: + def _compare( + self, + ref_path: pathlib.Path, + test_path: pathlib.Path, + ref_scale: float, + test_scale: float, + ) -> list[dict[str, typing.Any]]: compare_config = CompareConfig() errors: list[str] = [] ref_logs = compare_config._extract_tensor_logs(ref_path, errors) test_logs = compare_config._extract_tensor_logs(test_path, errors) for error in errors: logger.warning(error) + # Each variant's gradient logs are scaled by its own `constant` factor (auto-calibrated). + # Undo per-variant scaling so the relative comparison reflects unscaled gradient diffs. + _unscale_gradients_in_place(ref_logs, ref_scale) + _unscale_gradients_in_place(test_logs, test_scale) rows: list[dict[str, typing.Any]] = [] for step_name in sorted(ref_logs): if step_name not in test_logs: @@ -218,6 +304,66 @@ def _compare(self, ref_path: pathlib.Path, test_path: pathlib.Path) -> list[dict return rows +def _is_gradient_like(tensor_name: str) -> bool: + # Anything affected by the loss-scaling multiplier: parameter gradients from `Fsdp.log_shard`, + # backward activations from layer hooks, and explicit `.grad` debug entries (e.g. logits.grad). + return ("gradient:" in tensor_name) or (" bw" in tensor_name) or (".grad" in tensor_name) + + +def _scan_max_grad(artifact_path: pathlib.Path) -> float: + max_abs = 0.0 + compare_config = CompareConfig() + errors: list[str] = [] + logs = compare_config._extract_tensor_logs(artifact_path, errors) + for step_logs in logs.values(): + for tensor_name, entry in step_logs.items(): + if not _is_gradient_like(tensor_name): + continue + # Saved stats include min/max; fall back to samples if absent. + if "max" in entry and "min" in entry: + value = max(abs(float(entry["max"])), abs(float(entry["min"]))) + else: + value = float(entry["samples"].abs().max().item()) + if math.isfinite(value) and value > max_abs: + max_abs = value + return max_abs + + +def _unscale_gradients_in_place(logs: dict, scale: float) -> None: + if scale == 1.0: + return + inv = 1.0 / scale + for step_logs in logs.values(): + for tensor_name, entry in step_logs.items(): + if not _is_gradient_like(tensor_name): + continue + entry["samples"] = entry["samples"].float() * inv + for key in ("min", "max", "mu", "std"): + if key in entry and entry[key] is not None: + entry[key] = float(entry[key]) * inv + + +def _apply_torch_backend_overrides(overrides: dict[str, typing.Any]) -> None: + import torch + + unknown = set(overrides) - set(_TORCH_BACKEND_DEFAULTS) + if unknown: + logger.warning(f"Unknown torch backend overrides (ignored): {sorted(unknown)}") + for path, default in _TORCH_BACKEND_DEFAULTS.items(): + value = overrides.get(path, default) + obj: typing.Any = torch.backends + parts = path.split(".") + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + +def _apply_torch_matmul_precision(precision: str) -> None: + import torch + + torch.set_float32_matmul_precision(precision) + + def _layer_name(tensor_name: str) -> str: # Stage hooks name tensors `Global fw: ...` / `Global bw: ...`; # Fsdp.log_shard names weight gradients `Global gradient: `. @@ -260,6 +406,40 @@ def _print_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: aggs_per_kind = {"fw": fw_aggs, "bw": bw_aggs, "grad": grad_aggs} for kind in ("fw", "bw", "grad"): _print_summary_table(results, kind, aggs_per_kind[kind]) + if _named_row(sample, _CHOSEN_LOGPROB_NAME) is not None: + _print_chosen_logprob_summary(results) + + +def _print_chosen_logprob_summary(results: dict[str, list[dict[str, typing.Any]]]) -> None: + rows_by_variant = {name: _named_row(rows, _CHOSEN_LOGPROB_NAME) for name, rows in results.items()} + # log π(label) is the scalar that policy-gradient importance ratios depend on. Bias persists + # under per-document averaging where RMS shrinks ~1/√T, so for RL stability it's the more + # informative signal — surface it alongside RMS, slope and residual. + rms_rel_decimals = _column_decimals((r["rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + bias_rel_decimals = _column_decimals((r["bias_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5) + resid_rel_decimals = _column_decimals( + (r["residual_rms_rel"] for r in rows_by_variant.values()), default=3, max_decimals=5 + ) + name_width = max((len(name) for name in results), default=7) + 1 + cols = [ + ("RMS rel", lambda r: f"{r['rms_rel'] * 100:.{rms_rel_decimals}f}%"), + ("Bias rel", lambda r: f"{r['bias_rel'] * 100:+.{bias_rel_decimals}f}%"), + ("Resid rel", lambda r: f"{r['residual_rms_rel'] * 100:.{resid_rel_decimals}f}%"), + ("Corr", lambda r: f"{r['correlation']:.5f}"), + ("Slope", lambda r: f"{r['slope']:+.5f}"), + ("Max abs", lambda r: f"{r['max_abs']:.4g}"), + ("Scale", lambda r: f"{r['ref_scale']:.4g}"), + ] + widths = [max(len(label), max(len(fn(r)) for r in rows_by_variant.values())) for label, fn in cols] + print(f"\n=== Summary: chosen_logprob (per-token) ===") + header = f"{'Variant':<{name_width}}" + " ".join( + f"{label:<{w}}" for (label, _), w in zip(cols, widths, strict=True) + ) + print(header) + print("-" * len(header)) + for name, row in rows_by_variant.items(): + cells = [fn(row) for _, fn in cols] + print(f"{name:<{name_width}}" + " ".join(f"{c:<{w}}" for c, w in zip(cells, widths, strict=True))) def _grad_category(tensor_name: str) -> str: @@ -420,10 +600,14 @@ def _print_table(name: str, rows: list[dict[str, typing.Any]]) -> None: # (typical for weight gradients) stay legible, capped at 5 to bound column width. relative_decimals = _column_decimals((r["rms_rel"] for r in rows), default=2, max_decimals=5) relative_fn = lambda r: f"{r['rms_rel'] * 100:.{relative_decimals}f}%" + bias_decimals = _column_decimals((r["bias_rel"] for r in rows), default=2, max_decimals=5) + bias_fn = lambda r: f"{r['bias_rel'] * 100:+.{bias_decimals}f}%" relative_width = max(len("Relative"), max(len(relative_fn(r)) for r in rows)) + bias_width = max(len("Bias"), max(len(bias_fn(r)) for r in rows)) columns: list[tuple[str, int, typing.Callable[[dict[str, typing.Any]], str]]] = [ ("Tensor", name_width, name_fn), ("Relative", relative_width, relative_fn), + ("Bias", bias_width, bias_fn), ("Absolute", 10, lambda r: f"{r['rms_abs']:.4g}"), ("Max", 10, lambda r: f"{r['max_abs']:.4g}"), ("Scale", 10, lambda r: f"{r['ref_scale']:.4g}"),