From d416dc11ed79f97589203023ba86096b6e5d89f6 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Thu, 2 Jul 2026 18:43:53 +0000 Subject: [PATCH] [NNX] Checkpoint restore: name the parameter path on ShapeDtypeStruct mismatch When Orbax skips a parameter on a structural/shape mismatch it leaves an unmaterialized ShapeDtypeStruct in the restored state, which otherwise surfaces as a cryptic compile error deep in the first train_step. Report the offending parameter path and shape from _assert_no_shaped_dtype_struct so the mismatch points straight at the config (emb_dim/mlp_dim/layers/scan_layers). --- src/maxtext/common/checkpointing.py | 24 ++- tests/unit/checkpointing_nnx_load_test.py | 15 ++ ...ckpointing_nnx_sds_error_roundtrip_test.py | 158 ++++++++++++++++++ 3 files changed, 189 insertions(+), 8 deletions(-) create mode 100644 tests/unit/checkpointing_nnx_sds_error_roundtrip_test.py diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 0316cd29b5..c7e2891731 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -740,26 +740,34 @@ def is_structural_or_shape_mismatch(e: Exception) -> bool: return any(kw in msg for kw in mismatch_keywords) -def _assert_no_shaped_dtype_struct(pytree): - """Asserts that there are no jax.ShapeDtypeStruct leaves in the restored pytree.""" +def _assert_no_shaped_dtype_struct(pytree, path=()): + """Asserts that there are no jax.ShapeDtypeStruct leaves in the restored pytree. + + A surviving ShapeDtypeStruct means Orbax skipped that parameter on a + structural/shape mismatch; the path and shape point straight at the offending + config (e.g. emb_dim / mlp_dim / layers) vs. the checkpoint. + """ if isinstance(pytree, jax.ShapeDtypeStruct): + key = "/".join(str(p) for p in path) or "" raise ValueError( - "Some parameters in the restored state remained as ShapeDtypeStruct" - f" (indicating structure mismatch): {pytree}." + f"Parameter '{key}' remained an unmaterialized ShapeDtypeStruct after restore" + f" ({pytree.shape} {pytree.dtype}), meaning the checkpoint had no matching array." + " This is usually a config/checkpoint shape mismatch -- verify emb_dim, mlp_dim," + " num layers, and scan_layers match the saved checkpoint." ) if hasattr(pytree, "keys") and hasattr(pytree, "__getitem__"): for k in pytree.keys(): - _assert_no_shaped_dtype_struct(pytree[k]) + _assert_no_shaped_dtype_struct(pytree[k], path + (k,)) elif isinstance(pytree, (list, tuple)): - for v in pytree: - _assert_no_shaped_dtype_struct(v) + for i, v in enumerate(pytree): + _assert_no_shaped_dtype_struct(v, path + (i,)) else: leaves = jax.tree_util.tree_leaves(pytree) if len(leaves) == 1 and leaves[0] is pytree: return for leaf in leaves: - _assert_no_shaped_dtype_struct(leaf) + _assert_no_shaped_dtype_struct(leaf, path) @contextlib.contextmanager diff --git a/tests/unit/checkpointing_nnx_load_test.py b/tests/unit/checkpointing_nnx_load_test.py index 5af3f9b0b8..fa439530a6 100644 --- a/tests/unit/checkpointing_nnx_load_test.py +++ b/tests/unit/checkpointing_nnx_load_test.py @@ -265,5 +265,20 @@ def test_linen_layout_params_restore_into_nnx_state(self): self.assertTrue(jnp.array_equal(pure["linear"]["bias"], weights["linear"]["bias"])) +class TestShapeDtypeStructPath(unittest.TestCase): + """The SDS guard must name the offending parameter path.""" + + def test_names_path_of_surviving_sds(self): + tree = {"model": {"linear": {"kernel": jax.ShapeDtypeStruct((2, 1), jnp.float32)}}} + with self.assertRaises(ValueError) as ctx: + checkpointing._assert_no_shaped_dtype_struct(tree) # pylint: disable=protected-access + self.assertIn("model/linear/kernel", str(ctx.exception)) + self.assertIn("(2, 1)", str(ctx.exception)) + + def test_passes_for_concrete_arrays(self): + tree = {"model": {"linear": {"kernel": jnp.ones((2, 1))}}} + checkpointing._assert_no_shaped_dtype_struct(tree) # pylint: disable=protected-access + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/checkpointing_nnx_sds_error_roundtrip_test.py b/tests/unit/checkpointing_nnx_sds_error_roundtrip_test.py new file mode 100644 index 0000000000..5cc309281c --- /dev/null +++ b/tests/unit/checkpointing_nnx_sds_error_roundtrip_test.py @@ -0,0 +1,158 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end proof that the ShapeDtypeStruct restore error names its parameter path. + +Drives the real create_orbax_checkpoint_manager -> maybe_save_checkpoint -> +load_state_if_possible stack (no mocks). A weight present in the model but absent +from the checkpoint survives Orbax's partial restore as an unmaterialized +ShapeDtypeStruct; the guard must fail naming the exact parameter path and shape, +instead of letting the placeholder reach the first train_step as a cryptic compile +error. + +Scope: this covers the *surviving-SDS* case (a weight with no matching array on +disk). A same-key/different-shape mismatch (e.g. a changed emb_dim on an existing +weight) does not reach this guard -- Orbax raises during restore and that path is +reported by the scan_layers mismatch handler. +""" + +import os +import shutil +import tempfile +import unittest +from types import SimpleNamespace + +from flax import nnx +import jax +from maxtext.common import checkpointing +from maxtext.common import train_state_nnx +import optax + + +class _Model(nnx.Module): + """Linear + dropout, so the state carries rngs/dropout that gets stripped on save.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) + + def __call__(self, x, deterministic=False): + return self.dropout(self.linear(x), deterministic=deterministic) + + +class _ModelExtraLayer(nnx.Module): + """`_Model` plus a layer absent from a `_Model` checkpoint, i.e. a weight with no array on disk.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) + self.extra = nnx.Linear(1, 1, rngs=rngs) + + def __call__(self, x, deterministic=False): + return self.extra(self.dropout(self.linear(x), deterministic=deterministic)) + + +_TX = optax.adam(1e-3) + + +def _config(): + """Minimal config with the fields save/restore reads for a pure_nnx run.""" + return SimpleNamespace( + pure_nnx=True, + enable_diloco=False, + enable_checkpointing=True, + enable_continuous_checkpointing=False, + enable_emergency_checkpoint=False, + enable_autocheckpoint=False, + checkpoint_period=1, + local_checkpoint_period=0, + async_checkpointing=False, + dataset_type="tfds", + lora=None, + checkpoint_storage_target_data_file_size_bytes=checkpointing.DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE, + elastic_enabled=False, + ) + + +def _abstract_state(model_cls): + """An abstract (ShapeDtypeStruct) nnx.State for `model_cls`, the restore blueprint.""" + mesh = jax.sharding.Mesh(jax.devices(), ("x",)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + + def make(): + model = model_cls(nnx.Rngs(9)) + return nnx.state(train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, _TX, wrt=nnx.Param))) + + abstract = nnx.eval_shape(make) + return jax.tree.map( + lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding) if hasattr(x, "shape") else x, + abstract, + ) + + +class TestShapeDtypeStructErrorRoundTrip(unittest.TestCase): + """Real save->restore proving the descriptive ShapeDtypeStruct error.""" + + def setUp(self): + self._dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self._dir, ignore_errors=True) + + def _save_model(self): + """Saves a one-step `_Model` checkpoint and returns its manager.""" + manager = checkpointing.create_orbax_checkpoint_manager( + os.path.join(self._dir, "ckpt"), + enable_checkpointing=True, + use_async=False, + save_interval_steps=1, + dataset_type="tfds", + ) + model = _Model(nnx.Rngs(0)) + state = train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, _TX, wrt=nnx.Param)) + checkpointing.maybe_save_checkpoint(manager, nnx.state(state), _config(), data_iterator=None, step=1) + manager.wait_until_finished() + return manager + + def _restore(self, manager, model_cls): + """Restores the saved checkpoint into an abstract state for `model_cls`.""" + full, _ = checkpointing.load_state_if_possible( + manager, + data_iterator=None, + load_parameters_from_path="", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=_abstract_state(model_cls), + dataset_type="tfds", + maxtext_config=_config(), + ) + return full["items"] + + def test_surviving_shape_dtype_struct_error_names_path_and_shape(self): + manager = self._save_model() + with self.assertRaises(ValueError) as ctx: + self._restore(manager, _ModelExtraLayer) + msg = str(ctx.exception) + self.assertIn("unmaterialized ShapeDtypeStruct after restore", msg) # the SDS guard, not a compile traceback + self.assertIn("model/extra", msg) # the exact offending parameter path + self.assertIn("(1,)", msg) # its shape + self.assertIn("emb_dim", msg) # points at the config knobs to check + + def test_matching_config_restores_without_surviving_sds(self): + """Negative control: an exact-arch restore yields concrete arrays, so the guard stays quiet.""" + manager = self._save_model() + restored = self._restore(manager, _Model) + self.assertNotIsInstance(restored["model"]["linear"]["kernel"], jax.ShapeDtypeStruct) + + +if __name__ == "__main__": + unittest.main()