Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions src/maxtext/utils/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import re
import subprocess
import time
from types import SimpleNamespace
from typing import Any

from packaging.version import Version
Expand Down Expand Up @@ -231,6 +232,37 @@ def add_text_to_summary_writer(key, value, summary_writer):
summary_writer.add_text(key, value)


def _single_controller_mtc_init_kwargs(raw_keys):
"""Returns topology kwargs for single-controller MTC initialization."""
kwargs = {
"data_parallelism": raw_keys["mtc_data_parallelism"],
"num_slices": raw_keys["num_slices"],
}
if not raw_keys.get("elastic_enabled", False):
return kwargs

config = SimpleNamespace(**raw_keys)
if not elastic_utils.should_use_elastic(config):
return kwargs

active_devices = tuple(elastic_utils.live_devices(config))
active_slice_indices = {getattr(device, "slice_index", 0) for device in active_devices if device is not None}
if not active_devices or not active_slice_indices:
raise ValueError("Elastic single-controller MTC initialization found no active devices.")

kwargs["devices"] = active_devices
kwargs["num_slices"] = len(active_slice_indices)
if not kwargs["data_parallelism"]:
kwargs["data_parallelism"] = kwargs["num_slices"]
max_logging.log(
"Using active elastic devices for single-controller MTC initialization: "
f"active_num_slices={kwargs['num_slices']}, "
f"active_device_count={len(active_devices)}, "
f"configured_num_slices={raw_keys['num_slices']}."
)
return kwargs


def maybe_initialize_jax_distributed_system(raw_keys):
"""The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of
indirection in MaxText to avoid breaking the call sites unnecessarily.
Expand All @@ -248,14 +280,14 @@ def maybe_initialize_jax_distributed_system(raw_keys):
max_logging.log("Skipping jax distributed system since its not needed for single controller.")
if raw_keys["enable_multi_tier_checkpointing"]:
max_logging.log("Initializing multi-tier checkpointing for single controller...")
mtc_init_kwargs = _single_controller_mtc_init_kwargs(raw_keys)
initialize_multi_tier_checkpointing(
local_checkpoint_directory=raw_keys["local_checkpoint_directory"],
backup_interval_minutes=raw_keys["multi_tier_checkpointing_backup_interval_minutes"],
run_name=raw_keys["run_name"],
jax_initialization_timeout_seconds=raw_keys["jax_distributed_initialization_timeout"],
data_parallelism=raw_keys["mtc_data_parallelism"],
num_slices=raw_keys["num_slices"],
use_colocated_python=True,
**mtc_init_kwargs,
)
return
if jax.distributed.is_initialized():
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/max_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,45 @@ def test_single_controller_multi_tier_checkpointing_uses_colocated_python(self,
use_colocated_python=True,
)

@mock.patch("maxtext.utils.max_utils.elastic_utils.live_devices")
@mock.patch("maxtext.utils.max_utils.elastic_utils.should_use_elastic", return_value=True)
@mock.patch("maxtext.utils.max_utils.initialize_multi_tier_checkpointing")
@mock.patch("jax.distributed.initialize")
def test_single_controller_multi_tier_checkpointing_uses_active_elastic_devices(
    self, mock_init, mock_mtc, mock_should_use_elastic, mock_live_devices
):
  """Elastic single-controller MTC init is driven by the live devices.

  Two live devices, both on slice 0, are reported while the config asks for two
  slices and leaves `mtc_data_parallelism` at 0. MTC must then be initialized with
  one slice, data parallelism 1, and the live devices passed through.
  """
  # Patches are applied bottom-up, so the mock arguments arrive in reverse decorator order.
  live = (mock.Mock(slice_index=0), mock.Mock(slice_index=0))
  mock_live_devices.return_value = live

  defaults = self._base_keys()
  overrides = {
      "enable_single_controller": True,
      "enable_multi_tier_checkpointing": True,
      "elastic_enabled": True,
      "mtc_data_parallelism": 0,
      "num_slices": 2,
  }
  raw_keys = self._base_keys(**overrides)

  max_utils.maybe_initialize_jax_distributed_system(raw_keys)

  # Single controller never goes through jax.distributed.initialize.
  mock_init.assert_not_called()
  mock_should_use_elastic.assert_called_once()
  mock_live_devices.assert_called_once()

  expected_kwargs = {
      "local_checkpoint_directory": defaults["local_checkpoint_directory"],
      "backup_interval_minutes": defaults["multi_tier_checkpointing_backup_interval_minutes"],
      "run_name": defaults["run_name"],
      "jax_initialization_timeout_seconds": defaults["jax_distributed_initialization_timeout"],
      "data_parallelism": 1,
      "num_slices": 1,
      "use_colocated_python": True,
      "devices": live,
  }
  mock_mtc.assert_called_once_with(**expected_kwargs)

@mock.patch("jax.distributed.initialize")
def test_tpu_checkpointing_no_emergency_calls_jax_init(self, mock_init):
raw_keys = self._base_keys(enable_checkpointing=True, compile_topology_num_slices=-1)
Expand Down
Loading