From 1cf2be58fc590364192bbdfecf5113ea21e8bc45 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Wed, 1 Jul 2026 12:44:09 -0700 Subject: [PATCH] Avoid checkpoint work on skipped steps --- src/maxtext/common/checkpointing.py | 57 ++++++-- tests/unit/train_state_nnx_checkpoint_test.py | 132 +++++++++++++++--- 2 files changed, 156 insertions(+), 33 deletions(-) diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 0316cd29b5..a9c0e12bf3 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -1082,6 +1082,32 @@ def load_checkpoint_metadata(checkpoint_dir_path: str) -> dict[str, Any]: return {} +def _uses_local_checkpoint_period(config): + return config.enable_emergency_checkpoint or config.enable_multi_tier_checkpointing + + +def _should_save_checkpoint_at_step(checkpoint_manager, step, config, force): + """Returns whether MaxText should build and dispatch checkpoint args.""" + if force: + return True + if config.enable_continuous_checkpointing: + base_checkpoint_due = bool(checkpoint_manager.should_save(step)) + else: + base_checkpoint_due = step % config.checkpoint_period == 0 + local_checkpoint_due = _uses_local_checkpoint_period(config) and step % config.local_checkpoint_period == 0 + autocheckpoint_due = config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step) + return base_checkpoint_due or local_checkpoint_due or autocheckpoint_due + + +def _handle_post_checkpoint_preemption(checkpoint_manager, step, force_ckpt_save): + """Waits on final/preemption saves and raises if preempted.""" + reached_preemption = checkpoint_manager.reached_preemption(step) + if force_ckpt_save or reached_preemption: + checkpoint_manager.wait_until_finished() + if reached_preemption: + raise exceptions.StopTraining("Job is preempted.") + + def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step=None): """Save checkpoint if checkpointing is enabled.""" if checkpoint_manager is None: @@ -1100,6 +1126,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Linen TrainState has .step attribute actual_step = int(state.step) - 1 + # Determine if a checkpoint save should be forced, overriding the usual + # `config.checkpoint_period` logic. + # This occurs if this function was called: + # without an explicit 'step' (implying it's a checkpoint save for final step), + # AND the 'actual_step' is a valid step, + # AND it's not a step that would normally trigger a checkpoint save. + force_ckpt_save = step is None and actual_step != -1 and (actual_step % config.checkpoint_period != 0) + + if not _should_save_checkpoint_at_step(checkpoint_manager, actual_step, config, force_ckpt_save): + _handle_post_checkpoint_preemption(checkpoint_manager, actual_step, force_ckpt_save) + return + if checkpoint_manager.latest_step() == actual_step: max_logging.log(f"Checkpoint for step {actual_step} already exists, skipping save.") return @@ -1114,13 +1152,6 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step else: state = train_state_nnx.to_linen_checkpoint_dict(state.to_pure_dict()) - # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. - # This occurs if this function was called: - # without an explicit 'step' (implying it's a checkpoint save for final step), - # AND the 'actual_step' is a valid step, - # AND it's not a step that would normally trigger a checkpoint save. - force_ckpt_save = step is None and actual_step != -1 and (actual_step % config.checkpoint_period != 0) - try: checkpoint_saved = save_checkpoint(checkpoint_manager, actual_step, state, config, data_iterator, force_ckpt_save) if checkpoint_saved: @@ -1142,13 +1173,9 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step except Exception as e: raise exceptions.StopTraining(f"Checkpointing failed. {str(e)}") from e - # Wait for any pending checkpoint save to finish during preemption or final step save - if force_ckpt_save or checkpoint_manager.reached_preemption(actual_step): - checkpoint_manager.wait_until_finished() - - # Raise exception upon preemption - if checkpoint_manager.reached_preemption(actual_step): - raise exceptions.StopTraining("Job is preempted.") + # Wait for any pending checkpoint save to finish during preemption or final + # step save, then raise upon preemption. + _handle_post_checkpoint_preemption(checkpoint_manager, actual_step, force_ckpt_save) def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=None, force=False): @@ -1157,7 +1184,7 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= if ( force or (step % config.checkpoint_period == 0 and not config.enable_continuous_checkpointing) - or (config.enable_emergency_checkpoint and step % config.local_checkpoint_period == 0) + or (_uses_local_checkpoint_period(config) and step % config.local_checkpoint_period == 0) or (config.enable_autocheckpoint and checkpoint_manager.reached_preemption(step)) ): blocking_until_ready_start = time.time() diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py index fdb94046c7..41e7d19769 100644 --- a/tests/unit/train_state_nnx_checkpoint_test.py +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -351,6 +351,22 @@ class TestMaybeSaveCheckpointStepAlignment(unittest.TestCase): def setUp(self): self.tx = optax.adam(1e-3) + def _config(self, **overrides): + """Builds a minimal checkpoint config for maybe_save_checkpoint tests.""" + values = { + "pure_nnx": True, + "checkpoint_period": 10, + "async_checkpointing": False, + "enable_diloco": False, + "enable_continuous_checkpointing": False, + "enable_emergency_checkpoint": False, + "enable_multi_tier_checkpointing": False, + "local_checkpoint_period": 0, + "enable_autocheckpoint": False, + } + values.update(overrides) + return SimpleNamespace(**values) + def _build_nnx_state(self, num_steps): """Build an nnx.State flattened from TrainStateNNX after num_steps gradient applications.""" model = MockModel(rngs=nnx.Rngs(0)) @@ -380,12 +396,7 @@ def _build_linen_state(self, num_steps): def _invoke_maybe_save(self, state, pure_nnx): """Call maybe_save_checkpoint with save_checkpoint patched, return {step, state} captured.""" # checkpoint_period=1 keeps force_ckpt_save False regardless of actual_step. - config = SimpleNamespace( - pure_nnx=pure_nnx, - checkpoint_period=1, - async_checkpointing=False, - enable_diloco=False, - ) + config = self._config(pure_nnx=pure_nnx, checkpoint_period=1) mgr = mock.MagicMock() mgr.reached_preemption.return_value = False @@ -451,12 +462,7 @@ def test_maybe_save_checkpoint_skips_if_already_saved(self): state = self._build_nnx_state(self.N_STEPS) actual_step = self.N_STEPS - 1 - config = SimpleNamespace( - pure_nnx=True, - checkpoint_period=1, - async_checkpointing=False, - enable_diloco=False, - ) + config = self._config(checkpoint_period=1) mgr = mock.MagicMock() mgr.reached_preemption.return_value = False # Mock latest_step to return the same actual_step @@ -475,12 +481,7 @@ def test_maybe_save_checkpoint_saves_if_not_already_saved(self): state = self._build_nnx_state(self.N_STEPS) actual_step = self.N_STEPS - 1 - config = SimpleNamespace( - pure_nnx=True, - checkpoint_period=1, - async_checkpointing=False, - enable_diloco=False, - ) + config = self._config(checkpoint_period=1) mgr = mock.MagicMock() mgr.reached_preemption.return_value = False # Mock latest_step to return a different step (or None) @@ -495,6 +496,101 @@ def test_maybe_save_checkpoint_saves_if_not_already_saved(self): # Assert that save_checkpoint WAS called! save_checkpoint_mock.assert_called_once() + def test_maybe_save_checkpoint_skips_non_checkpoint_step_before_state_work( + self, + ): + """Non-checkpoint steps should not query latest_step or build save args.""" + state = mock.Mock() + config = self._config() + mgr = mock.MagicMock() + mgr.reached_preemption.return_value = False + + with mock.patch.object(checkpointing, "save_checkpoint") as save_checkpoint_mock: + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=3) + + mgr.latest_step.assert_not_called() + mgr.reached_preemption.assert_called_once_with(3) + mgr.wait_until_finished.assert_not_called() + state.to_pure_dict.assert_not_called() + save_checkpoint_mock.assert_not_called() + + def test_maybe_save_checkpoint_handles_preemption_on_non_checkpoint_step( + self, + ): + """Non-checkpoint steps must still honor preemption handling.""" + state = mock.Mock() + config = self._config() + mgr = mock.MagicMock() + mgr.reached_preemption.return_value = True + + with mock.patch.object(checkpointing, "save_checkpoint") as save_checkpoint_mock: + with self.assertRaises(checkpointing.exceptions.StopTraining): + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=3) + + mgr.latest_step.assert_not_called() + mgr.reached_preemption.assert_called_once_with(3) + mgr.wait_until_finished.assert_called_once_with() + state.to_pure_dict.assert_not_called() + save_checkpoint_mock.assert_not_called() + + def test_maybe_save_checkpoint_allows_local_checkpoint_period(self): + """Emergency and multi-tier local checkpoint periods dispatch save work.""" + for checkpoint_flag in ( + "enable_emergency_checkpoint", + "enable_multi_tier_checkpointing", + ): + with self.subTest(checkpoint_flag=checkpoint_flag): + state = mock.Mock() + state.to_pure_dict.return_value = { + "model": {}, + "optimizer": {"step": 5}, + } + config = self._config( + checkpoint_period=100, + local_checkpoint_period=5, + **{checkpoint_flag: True}, + ) + mgr = mock.MagicMock() + mgr.latest_step.return_value = None + mgr.reached_preemption.return_value = False + save_checkpoint_mock = mock.MagicMock(return_value=False) + + with mock.patch.object(checkpointing, "save_checkpoint", save_checkpoint_mock): + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=5) + + mgr.latest_step.assert_called_once_with() + mgr.reached_preemption.assert_called_once_with(5) + mgr.wait_until_finished.assert_not_called() + state.to_pure_dict.assert_called_once_with() + save_checkpoint_mock.assert_called_once() + + def test_maybe_save_checkpoint_allows_mtc_period_with_continuous_policy( + self, + ): + """Continuous checkpointing should not suppress MTC local saves.""" + state = mock.Mock() + state.to_pure_dict.return_value = {"model": {}, "optimizer": {"step": 5}} + config = self._config( + checkpoint_period=100, + enable_continuous_checkpointing=True, + enable_multi_tier_checkpointing=True, + local_checkpoint_period=5, + ) + mgr = mock.MagicMock() + mgr.should_save.return_value = False + mgr.latest_step.return_value = None + mgr.reached_preemption.return_value = False + save_checkpoint_mock = mock.MagicMock(return_value=False) + + with mock.patch.object(checkpointing, "save_checkpoint", save_checkpoint_mock): + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=5) + + mgr.should_save.assert_called_once_with(5) + mgr.latest_step.assert_called_once_with() + mgr.reached_preemption.assert_called_once_with(5) + state.to_pure_dict.assert_called_once_with() + save_checkpoint_mock.assert_called_once() + class TestLinenCheckpointFormatConverters(unittest.TestCase): """to_linen_checkpoint_dict / from_linen_checkpoint_dict (NNX <-> Linen on-disk layout)."""