Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,32 @@ def load_checkpoint_metadata(checkpoint_dir_path: str) -> dict[str, Any]:
return {}


def _uses_local_checkpoint_period(config):
return config.enable_emergency_checkpoint or config.enable_multi_tier_checkpointing


def _should_save_checkpoint_at_step(checkpoint_manager, step, config, force):
"""Returns whether MaxText should build and dispatch checkpoint args."""
if force:
return True
if config.enable_continuous_checkpointing:
base_checkpoint_due = bool(checkpoint_manager.should_save(step))
else:
base_checkpoint_due = step % config.checkpoint_period == 0
local_checkpoint_due = _uses_local_checkpoint_period(config) and step % config.local_checkpoint_period == 0
autocheckpoint_due = config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step)
return base_checkpoint_due or local_checkpoint_due or autocheckpoint_due


def _handle_post_checkpoint_preemption(checkpoint_manager, step, force_ckpt_save):
"""Waits on final/preemption saves and raises if preempted."""
reached_preemption = checkpoint_manager.reached_preemption(step)
if force_ckpt_save or reached_preemption:
checkpoint_manager.wait_until_finished()
if reached_preemption:
raise exceptions.StopTraining("Job is preempted.")


def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None):
"""Save checkpoint if checkpointing is enabled."""
if checkpoint_manager is None:
Expand All @@ -1100,6 +1126,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
# Linen TrainState has .step attribute
actual_step = int(state.step) - 1

# Determine if a checkpoint save should be forced, overriding the usual
# `config.checkpoint_period` logic.
# This occurs if this function was called:
# without an explicit 'step' (implying it's a checkpoint save for final step),
# AND the 'actual_step' is a valid step,
# AND it's not a step that would normally trigger a checkpoint save.
force_ckpt_save = step is None and actual_step != -1 and (actual_step % config.checkpoint_period != 0)

if not _should_save_checkpoint_at_step(checkpoint_manager, actual_step, config, force_ckpt_save):
_handle_post_checkpoint_preemption(checkpoint_manager, actual_step, force_ckpt_save)
return

if checkpoint_manager.latest_step() == actual_step:
max_logging.log(f"Checkpoint for step {actual_step} already exists, skipping save.")
return
Expand All @@ -1114,13 +1152,6 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
else:
state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict())

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
# without an explicit 'step' (implying it's a checkpoint save for final step),
# AND the 'actual_step' is a valid step,
# AND it's not a step that would normally trigger a checkpoint save.
force_ckpt_save = step is None and actual_step != -1 and (actual_step % config.checkpoint_period != 0)

try:
checkpoint_saved = save_checkpoint(checkpoint_manager, actual_step, state, config, data_iterator, force_ckpt_save)
if checkpoint_saved:
Expand All @@ -1142,13 +1173,9 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
except Exception as e:
raise exceptions.StopTraining(f"Checkpointing failed. {str(e)}") from e

# Wait for any pending checkpoint save to finish during preemption or final step save
if force_ckpt_save or checkpoint_manager.reached_preemption(actual_step):
checkpoint_manager.wait_until_finished()

# Raise exception upon preemption
if checkpoint_manager.reached_preemption(actual_step):
raise exceptions.StopTraining("Job is preempted.")
# Wait for any pending checkpoint save to finish during preemption or final
# step save, then raise upon preemption.
_handle_post_checkpoint_preemption(checkpoint_manager, actual_step, force_ckpt_save)


def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=None, force=False):
Expand All @@ -1157,7 +1184,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
if (
force
or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing)
or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0)
or (_uses_local_checkpoint_period(config) and step % config.local_checkpoint_period == 0)
or (config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step))
):
blocking_until_ready_start = time.time()
Expand Down
132 changes: 114 additions & 18 deletions tests/unit/train_state_nnx_checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,22 @@ class TestMaybeSaveCheckpointStepAlignment(unittest.TestCase):
def setUp(self):
self.tx = optax.adam(1e-3)

def _config(self, **overrides):
"""Builds a minimal checkpoint config for maybe_save_checkpoint tests."""
values = {
"pure_nnx": True,
"checkpoint_period": 10,
"async_checkpointing": False,
"enable_diloco": False,
"enable_continuous_checkpointing": False,
"enable_emergency_checkpoint": False,
"enable_multi_tier_checkpointing": False,
"local_checkpoint_period": 0,
"enable_autocheckpoint": False,
}
values.update(overrides)
return SimpleNamespace(**values)

def _build_nnx_state(self, num_steps):
"""Build an nnx.State flattened from TrainStateNNX after num_steps gradient applications."""
model = MockModel(rngs=nnx.Rngs(0))
Expand Down Expand Up @@ -380,12 +396,7 @@ def _build_linen_state(self, num_steps):
def _invoke_maybe_save(self, state, pure_nnx):
"""Call maybe_save_checkpoint with save_checkpoint patched, return {step, state} captured."""
# checkpoint_period=1 keeps force_ckpt_save False regardless of actual_step.
config = SimpleNamespace(
pure_nnx=pure_nnx,
checkpoint_period=1,
async_checkpointing=False,
enable_diloco=False,
)
config = self._config(pure_nnx=pure_nnx, checkpoint_period=1)
mgr = mock.MagicMock()
mgr.reached_preemption.return_value = False

Expand Down Expand Up @@ -451,12 +462,7 @@ def test_maybe_save_checkpoint_skips_if_already_saved(self):
state = self._build_nnx_state(self.N_STEPS)
actual_step = self.N_STEPS - 1

config = SimpleNamespace(
pure_nnx=True,
checkpoint_period=1,
async_checkpointing=False,
enable_diloco=False,
)
config = self._config(checkpoint_period=1)
mgr = mock.MagicMock()
mgr.reached_preemption.return_value = False
# Mock latest_step to return the same actual_step
Expand All @@ -475,12 +481,7 @@ def test_maybe_save_checkpoint_saves_if_not_already_saved(self):
state = self._build_nnx_state(self.N_STEPS)
actual_step = self.N_STEPS - 1

config = SimpleNamespace(
pure_nnx=True,
checkpoint_period=1,
async_checkpointing=False,
enable_diloco=False,
)
config = self._config(checkpoint_period=1)
mgr = mock.MagicMock()
mgr.reached_preemption.return_value = False
# Mock latest_step to return a different step (or None)
Expand All @@ -495,6 +496,101 @@ def test_maybe_save_checkpoint_saves_if_not_already_saved(self):
# Assert that save_checkpoint WAS called!
save_checkpoint_mock.assert_called_once()

def test_maybe_save_checkpoint_skips_non_checkpoint_step_before_state_work(
self,
):
"""Non-checkpoint steps should not query latest_step or build save args."""
state = mock.Mock()
config = self._config()
mgr = mock.MagicMock()
mgr.reached_preemption.return_value = False

with mock.patch.object(checkpointing, "save_checkpoint") as save_checkpoint_mock:
checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=3)

mgr.latest_step.assert_not_called()
mgr.reached_preemption.assert_called_once_with(3)
mgr.wait_until_finished.assert_not_called()
state.to_pure_dict.assert_not_called()
save_checkpoint_mock.assert_not_called()

def test_maybe_save_checkpoint_handles_preemption_on_non_checkpoint_step(
self,
):
"""Non-checkpoint steps must still honor preemption handling."""
state = mock.Mock()
config = self._config()
mgr = mock.MagicMock()
mgr.reached_preemption.return_value = True

with mock.patch.object(checkpointing, "save_checkpoint") as save_checkpoint_mock:
with self.assertRaises(checkpointing.exceptions.StopTraining):
checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=3)

mgr.latest_step.assert_not_called()
mgr.reached_preemption.assert_called_once_with(3)
mgr.wait_until_finished.assert_called_once_with()
state.to_pure_dict.assert_not_called()
save_checkpoint_mock.assert_not_called()

def test_maybe_save_checkpoint_allows_local_checkpoint_period(self):
"""Emergency and multi-tier local checkpoint periods dispatch save work."""
for checkpoint_flag in (
"enable_emergency_checkpoint",
"enable_multi_tier_checkpointing",
):
with self.subTest(checkpoint_flag=checkpoint_flag):
state = mock.Mock()
state.to_pure_dict.return_value = {
"model": {},
"optimizer": {"step": 5},
}
config = self._config(
checkpoint_period=100,
local_checkpoint_period=5,
**{checkpoint_flag: True},
)
mgr = mock.MagicMock()
mgr.latest_step.return_value = None
mgr.reached_preemption.return_value = False
save_checkpoint_mock = mock.MagicMock(return_value=False)

with mock.patch.object(checkpointing, "save_checkpoint", save_checkpoint_mock):
checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=5)

mgr.latest_step.assert_called_once_with()
mgr.reached_preemption.assert_called_once_with(5)
mgr.wait_until_finished.assert_not_called()
state.to_pure_dict.assert_called_once_with()
save_checkpoint_mock.assert_called_once()

def test_maybe_save_checkpoint_allows_mtc_period_with_continuous_policy(
self,
):
"""Continuous checkpointing should not suppress MTC local saves."""
state = mock.Mock()
state.to_pure_dict.return_value = {"model": {}, "optimizer": {"step": 5}}
config = self._config(
checkpoint_period=100,
enable_continuous_checkpointing=True,
enable_multi_tier_checkpointing=True,
local_checkpoint_period=5,
)
mgr = mock.MagicMock()
mgr.should_save.return_value = False
mgr.latest_step.return_value = None
mgr.reached_preemption.return_value = False
save_checkpoint_mock = mock.MagicMock(return_value=False)

with mock.patch.object(checkpointing, "save_checkpoint", save_checkpoint_mock):
checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=5)

mgr.should_save.assert_called_once_with(5)
mgr.latest_step.assert_called_once_with()
mgr.reached_preemption.assert_called_once_with(5)
state.to_pure_dict.assert_called_once_with()
save_checkpoint_mock.assert_called_once()


class TestLinenCheckpointFormatConverters(unittest.TestCase):
"""to_linen_checkpoint_dict / from_linen_checkpoint_dict (NNX <-> Linen on-disk layout)."""
Expand Down
Loading