Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,26 +740,34 @@ def is_structural_or_shape_mismatch(e: Exception) -> bool:
return any(kw in msg for kw in mismatch_keywords)


def _path_entry_label(entry):
  """Returns the bare key, index or attribute name held by a JAX key-path entry.

  JAX key-path entries wrap their payload: `DictKey` and `FlattenedIndexKey`
  expose `.key`, `SequenceKey` exposes `.idx`, and `GetAttrKey` exposes `.name`.
  Anything unrecognised is returned as-is so it can still be stringified.
  """
  for attr in ("key", "idx", "name"):
    if hasattr(entry, attr):
      return getattr(entry, attr)
  return entry


def _assert_no_shaped_dtype_struct(pytree, path=()):
  """Asserts that there are no jax.ShapeDtypeStruct leaves in the restored pytree.

  A surviving ShapeDtypeStruct means Orbax skipped that parameter on a
  structural/shape mismatch; the path and shape point straight at the offending
  config (e.g. emb_dim / mlp_dim / layers) vs. the checkpoint.

  Args:
    pytree: The restored state (or any sub-tree of it) to inspect.
    path: Keys/indices already traversed to reach `pytree`. Callers normally
      omit it; it is threaded through the recursion to build the error message.

  Raises:
    ValueError: On the first ShapeDtypeStruct leaf found. The message names the
      "/"-joined path to the leaf ("<root>" when the tree itself is the leaf)
      together with its shape and dtype.
  """
  # Accept any sequence for `path` so the tuple concatenations below cannot fail.
  path = tuple(path)

  if isinstance(pytree, jax.ShapeDtypeStruct):
    key = "/".join(str(p) for p in path) or "<root>"
    raise ValueError(
        f"Parameter '{key}' remained an unmaterialized ShapeDtypeStruct after restore"
        f" ({pytree.shape} {pytree.dtype}), meaning the checkpoint had no matching array."
        " This is usually a config/checkpoint shape mismatch -- verify emb_dim, mlp_dim,"
        " num layers, and scan_layers match the saved checkpoint."
    )

  if hasattr(pytree, "keys") and hasattr(pytree, "__getitem__"):
    # Mapping-like containers (dicts and anything else exposing keys()/[]).
    for k in pytree.keys():
      _assert_no_shaped_dtype_struct(pytree[k], path + (k,))
  elif isinstance(pytree, (list, tuple)):
    for i, v in enumerate(pytree):
      _assert_no_shaped_dtype_struct(v, path + (i,))
  else:
    # Custom pytree nodes: flatten with key paths so the reported location keeps
    # the position inside the node instead of stopping at the node itself.
    leaves_with_path, _ = jax.tree_util.tree_flatten_with_path(pytree)
    # A true leaf flattens to itself; stop here to avoid infinite recursion.
    if len(leaves_with_path) == 1 and leaves_with_path[0][1] is pytree:
      return
    for key_path, leaf in leaves_with_path:
      sub_path = tuple(_path_entry_label(entry) for entry in key_path)
      _assert_no_shaped_dtype_struct(leaf, path + sub_path)


@contextlib.contextmanager
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/checkpointing_nnx_load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,5 +265,20 @@ def test_linen_layout_params_restore_into_nnx_state(self):
self.assertTrue(jnp.array_equal(pure["linear"]["bias"], weights["linear"]["bias"]))


class TestShapeDtypeStructPath(unittest.TestCase):
  """Checks that the ShapeDtypeStruct guard reports which parameter it tripped on."""

  def test_names_path_of_surviving_sds(self):
    # A single abstract placeholder nested three dict levels deep.
    placeholder = jax.ShapeDtypeStruct((2, 1), jnp.float32)
    tree = {"model": {"linear": {"kernel": placeholder}}}

    with self.assertRaises(ValueError) as caught:
      checkpointing._assert_no_shaped_dtype_struct(tree)  # pylint: disable=protected-access

    # Both the "/"-joined location and the placeholder's shape must be reported.
    message = str(caught.exception)
    for expected in ("model/linear/kernel", "(2, 1)"):
      self.assertIn(expected, message)

  def test_passes_for_concrete_arrays(self):
    # Same layout with a real array as the leaf: the call must simply return.
    concrete = {"model": {"linear": {"kernel": jnp.ones((2, 1))}}}
    checkpointing._assert_no_shaped_dtype_struct(concrete)  # pylint: disable=protected-access


if __name__ == "__main__":
unittest.main()
158 changes: 158 additions & 0 deletions tests/unit/checkpointing_nnx_sds_error_roundtrip_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright 2025-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""End-to-end proof that the ShapeDtypeStruct restore error names its parameter path.

Drives the real create_orbax_checkpoint_manager -> maybe_save_checkpoint ->
load_state_if_possible stack (no mocks). A weight present in the model but absent
from the checkpoint survives Orbax's partial restore as an unmaterialized
ShapeDtypeStruct; the guard must fail naming the exact parameter path and shape,
instead of letting the placeholder reach the first train_step as a cryptic compile
error.

Scope: this covers the *surviving-SDS* case (a weight with no matching array on
disk). A same-key/different-shape mismatch (e.g. a changed emb_dim on an existing
weight) does not reach this guard -- Orbax raises during restore and that path is
reported by the scan_layers mismatch handler.
"""

import os
import shutil
import tempfile
import unittest
from types import SimpleNamespace

from flax import nnx
import jax
from maxtext.common import checkpointing
from maxtext.common import train_state_nnx
import optax


class _Model(nnx.Module):
  """Toy network: a 2->1 linear layer followed by dropout (rate 0.5).

  Dropout is included so the state carries rngs/dropout entries (which get
  stripped on save) alongside the linear weights.
  """

  def __init__(self, rngs: nnx.Rngs):
    # Construction order is kept: both layers draw from the same `rngs`.
    self.linear = nnx.Linear(2, 1, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

  def __call__(self, x, deterministic=False):
    projected = self.linear(x)
    return self.dropout(projected, deterministic=deterministic)


class _ModelExtraLayer(nnx.Module):
  """Same layers as `_Model`, followed by one additional 1->1 linear layer.

  The `extra` layer is absent from a `_Model` checkpoint, i.e. it is a weight
  with no array on disk when restoring from such a checkpoint.
  """

  def __init__(self, rngs: nnx.Rngs):
    # Construction order is kept: all layers draw from the same `rngs`.
    self.linear = nnx.Linear(2, 1, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
    self.extra = nnx.Linear(1, 1, rngs=rngs)

  def __call__(self, x, deterministic=False):
    projected = self.linear(x)
    dropped = self.dropout(projected, deterministic=deterministic)
    return self.extra(dropped)


# Shared optax Adam transform (learning rate 1e-3) handed to every nnx.Optimizer built in this file.
_TX = optax.adam(1e-3)


def _config():
  """Builds the minimal config namespace that save/restore reads for a pure_nnx run.

  Returns:
    A fresh `SimpleNamespace` on every call, with checkpointing enabled,
    synchronous saves every step, and all optional checkpoint modes switched off.
  """
  fields = {
      "pure_nnx": True,
      "enable_diloco": False,
      "enable_checkpointing": True,
      "enable_continuous_checkpointing": False,
      "enable_emergency_checkpoint": False,
      "enable_autocheckpoint": False,
      "checkpoint_period": 1,
      "local_checkpoint_period": 0,
      "async_checkpointing": False,
      "dataset_type": "tfds",
      "lora": None,
      "checkpoint_storage_target_data_file_size_bytes": checkpointing.DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE,
      "elastic_enabled": False,
  }
  return SimpleNamespace(**fields)


def _abstract_state(model_cls):
  """Builds the abstract (ShapeDtypeStruct) nnx.State used as the restore blueprint.

  Args:
    model_cls: Module class to instantiate; it is called with `nnx.Rngs(9)`.

  Returns:
    The state of a `TrainStateNNX` wrapping the model and its optimizer, traced
    with `nnx.eval_shape`, where every leaf that has a `shape` attribute is
    replaced by a `jax.ShapeDtypeStruct` carrying a `NamedSharding` over a 1-D
    mesh of all devices with an empty `PartitionSpec`. Other leaves pass through.
  """
  devices = jax.devices()
  mesh = jax.sharding.Mesh(devices, ("x",))
  leaf_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

  def build_state():
    model = model_cls(nnx.Rngs(9))
    optimizer = nnx.Optimizer(model, _TX, wrt=nnx.Param)
    return nnx.state(train_state_nnx.TrainStateNNX(model, optimizer))

  def to_placeholder(leaf):
    # Leaves without a shape are returned untouched.
    if not hasattr(leaf, "shape"):
      return leaf
    return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=leaf_sharding)

  return jax.tree.map(to_placeholder, nnx.eval_shape(build_state))


class TestShapeDtypeStructErrorRoundTrip(unittest.TestCase):
  """Saves a real checkpoint and restores it to exercise the ShapeDtypeStruct error."""

  def setUp(self):
    # Scratch directory per test; removed afterwards even if the test fails.
    self._dir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, self._dir, ignore_errors=True)

  def _save_model(self):
    """Writes a step-1 checkpoint of a fresh `_Model` and returns the manager used."""
    checkpoint_dir = os.path.join(self._dir, "ckpt")
    manager = checkpointing.create_orbax_checkpoint_manager(
        checkpoint_dir,
        enable_checkpointing=True,
        use_async=False,
        save_interval_steps=1,
        dataset_type="tfds",
    )
    model = _Model(nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, _TX, wrt=nnx.Param)
    train_state = train_state_nnx.TrainStateNNX(model, optimizer)
    checkpointing.maybe_save_checkpoint(manager, nnx.state(train_state), _config(), data_iterator=None, step=1)
    manager.wait_until_finished()
    return manager

  def _restore(self, manager, model_cls):
    """Loads the saved checkpoint against the abstract state of `model_cls`."""
    blueprint = _abstract_state(model_cls)
    restored, _ = checkpointing.load_state_if_possible(
        manager,
        data_iterator=None,
        load_parameters_from_path="",
        load_full_state_from_path="",
        checkpoint_storage_concurrent_gb=8,
        abstract_unboxed_pre_state=blueprint,
        dataset_type="tfds",
        maxtext_config=_config(),
    )
    return restored["items"]

  def test_surviving_shape_dtype_struct_error_names_path_and_shape(self):
    manager = self._save_model()
    with self.assertRaises(ValueError) as caught:
      self._restore(manager, _ModelExtraLayer)

    message = str(caught.exception)
    expected_fragments = (
        "unmaterialized ShapeDtypeStruct after restore",  # the SDS guard, not a compile traceback
        "model/extra",  # the exact offending parameter path
        "(1,)",  # its shape
        "emb_dim",  # points at the config knobs to check
    )
    for fragment in expected_fragments:
      self.assertIn(fragment, message)

  def test_matching_config_restores_without_surviving_sds(self):
    """Negative control: restoring into the same architecture yields concrete arrays."""
    restored = self._restore(self._save_model(), _Model)
    kernel = restored["model"]["linear"]["kernel"]
    self.assertNotIsInstance(kernel, jax.ShapeDtypeStruct)


if __name__ == "__main__":
unittest.main()
Loading