Skip to content

Commit 42d31a3

Browse files
Nikhil BansalOrbax Authors
authored andcommitted
Snapshot feature to save/load pathways checkpoint in host memory
PiperOrigin-RevId: 895163261
1 parent 5961a66 commit 42d31a3

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

  • checkpoint/orbax/checkpoint/experimental/v1/_src/testing

checkpoint/orbax/checkpoint/experimental/v1/_src/testing/array_utils.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,29 @@ def create_numpy_pytree(*, add: int = 0, include_scalars: bool = True):
4343

4444

4545
def create_sharded_pytree(
46-
*, add: int = 0, reverse_devices: bool = False, include_scalars: bool = True
46+
*,
47+
add: int = 0,
48+
reverse_devices: bool = False,
49+
include_scalars: bool = True,
50+
replicated_arrays: bool = False,
51+
devices: list[jax.Device] | None = None,
52+
slices: int = 2,
4753
) -> tuple[tree_types.PyTree, tree_types.PyTree]:
4854
"""Creates a sharded PyTree from `create_numpy_pytree`.
4955
5056
Args:
5157
add: The value to add to leaf arrays.
5258
reverse_devices: Whether to reverse the devices in the mesh.
5359
include_scalars: Whether to include scalars in the pytree.
60+
replicated_arrays: Whether to replicate arrays across devices.
61+
devices: The devices to use for the mesh.
62+
slices: The number of slices to use for the mesh.
5463
5564
Returns:
5665
A tuple of (pytree, abstract_pytree).
5766
"""
58-
devices = jax.devices()
67+
if devices is None:
68+
devices = jax.devices()
5969
num_devices = len(devices)
6070
devices = (
6171
np.asarray(list(reversed(devices)))
@@ -64,13 +74,17 @@ def create_sharded_pytree(
6474
)
6575

6676
mesh_2d = jax.sharding.Mesh(
67-
devices.reshape((2, num_devices // 2)), ('x', 'y')
77+
devices.reshape((slices, num_devices // slices)), ('x', 'y')
6878
)
6979
mesh_axes_2d = jax.sharding.PartitionSpec('x', 'y')
80+
if replicated_arrays:
81+
mesh_axes_2d = jax.sharding.PartitionSpec(None, 'y')
7082
mesh_1d = jax.sharding.Mesh(devices, ('x',))
7183
mesh_axes_1d = jax.sharding.PartitionSpec(
7284
'x',
7385
)
86+
if replicated_arrays:
87+
mesh_axes_1d = jax.sharding.PartitionSpec(None,)
7488
mesh_0d = jax.sharding.Mesh(devices, ('x',))
7589
mesh_axes_0d = jax.sharding.PartitionSpec(
7690
None,

0 commit comments

Comments
 (0)