Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,29 @@ def create_numpy_pytree(*, add: int = 0, include_scalars: bool = True):


def create_sharded_pytree(
*, add: int = 0, reverse_devices: bool = False, include_scalars: bool = True
*,
add: int = 0,
reverse_devices: bool = False,
include_scalars: bool = True,
replicated_arrays: bool = False,
devices: list[jax.Device] | None = None,
slices: int = 2,
) -> tuple[tree_types.PyTree, tree_types.PyTree]:
"""Creates a sharded PyTree from `create_numpy_pytree`.

Args:
add: The value to add to leaf arrays.
reverse_devices: Whether to reverse the devices in the mesh.
include_scalars: Whether to include scalars in the pytree.
replicated_arrays: Whether to replicate arrays across devices.
devices: The devices to use for the mesh.
slices: The number of slices to use for the mesh.

Returns:
A tuple of (pytree, abstract_pytree).
"""
devices = jax.devices()
if devices is None:
devices = jax.devices()
num_devices = len(devices)
devices = (
np.asarray(list(reversed(devices)))
Expand All @@ -64,13 +74,17 @@ def create_sharded_pytree(
)

mesh_2d = jax.sharding.Mesh(
devices.reshape((2, num_devices // 2)), ('x', 'y')
devices.reshape((slices, num_devices // slices)), ('x', 'y')
)
mesh_axes_2d = jax.sharding.PartitionSpec('x', 'y')
if replicated_arrays:
mesh_axes_2d = jax.sharding.PartitionSpec(None, 'y')
mesh_1d = jax.sharding.Mesh(devices, ('x',))
mesh_axes_1d = jax.sharding.PartitionSpec(
'x',
)
if replicated_arrays:
mesh_axes_1d = jax.sharding.PartitionSpec(None,)
mesh_0d = jax.sharding.Mesh(devices, ('x',))
mesh_axes_0d = jax.sharding.PartitionSpec(
None,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Manages asynchronous backups of JAX array states to pinned host memory."""

import collections
from typing import Any

from etils import epath
import jax
from orbax.checkpoint.experimental.v1 import training
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


class Snapshotter:
"""Manages asynchronous backups of JAX array states to pinned host memory."""

_snapshots: collections.deque[tuple[tree_types.PyTree, int]]

def __init__(self):
self._snapshots = collections.deque(maxlen=2)

def save_pytree(self, step: int, state: Any) -> None:
"""Move arrays onto CPU worker devices."""
pinned_shardings = jax.tree.map(
lambda x: x.sharding.with_memory_kind("pinned_host"), state
)

pinned_state = jax.device_put(state, pinned_shardings)

self._snapshots.append((pinned_state, step))

def load_pytree(self, abstract_state: Any) -> tree_types.PyTree:
"""Move arrays from workers onto TPU devices.

Uses `abstract_state.sharding` to properly re-partition onto the new mesh.

Args:
abstract_state: An abstract representation of the state, used to provide
the target shardings for the restored arrays on the TPU devices.

Returns:
The restored array state.

Raises:
RuntimeError: If no snapshots are available to restore from.
"""
if not self._snapshots:
raise RuntimeError("No snapshots available to restore from.")

pinned_state, _ = self._snapshots[-1]

# Re-shard on host to the target device mesh
host_target_shardings = jax.tree.map(
lambda x: x.sharding.with_memory_kind("pinned_host"), abstract_state
)
host_target_state = jax.device_put(pinned_state, host_target_shardings)

# Move from host back to device (TPU) memory.
restored_state = jax.device_put(
host_target_state, jax.tree.map(lambda x: x.sharding, abstract_state)
)
jax.block_until_ready(restored_state)

return restored_state

@property
def latest(self) -> training.CheckpointMetadata[None] | None:
"""Returns the training step of the most recently pinned backup."""
if not self._snapshots:
return None
return training.CheckpointMetadata(
step=self._snapshots[-1][1],
path=epath.Path(),
metadata=None,
)
Loading