@@ -43,19 +43,29 @@ def create_numpy_pytree(*, add: int = 0, include_scalars: bool = True):
4343
4444
4545def create_sharded_pytree (
46- * , add : int = 0 , reverse_devices : bool = False , include_scalars : bool = True
46+ * ,
47+ add : int = 0 ,
48+ reverse_devices : bool = False ,
49+ include_scalars : bool = True ,
50+ replicated_arrays : bool = False ,
51+ devices : list [jax .Device ] | None = None ,
52+ slices : int = 2 ,
4753) -> tuple [tree_types .PyTree , tree_types .PyTree ]:
4854 """Creates a sharded PyTree from `create_numpy_pytree`.
4955
5056 Args:
5157 add: The value to add to leaf arrays.
5258 reverse_devices: Whether to reverse the devices in the mesh.
5359 include_scalars: Whether to include scalars in the pytree.
60+ replicated_arrays: Whether to replicate arrays across devices.
61+ devices: The devices to use for the mesh.
62+ slices: The number of slices to use for the mesh.
5463
5564 Returns:
5665 A tuple of (pytree, abstract_pytree).
5766 """
58- devices = jax .devices ()
67+ if devices is None :
68+ devices = jax .devices ()
5969 num_devices = len (devices )
6070 devices = (
6171 np .asarray (list (reversed (devices )))
@@ -64,13 +74,17 @@ def create_sharded_pytree(
6474 )
6575
6676 mesh_2d = jax .sharding .Mesh (
67- devices .reshape ((2 , num_devices // 2 )), ('x' , 'y' )
77+ devices .reshape ((slices , num_devices // slices )), ('x' , 'y' )
6878 )
6979 mesh_axes_2d = jax .sharding .PartitionSpec ('x' , 'y' )
80+ if replicated_arrays :
81+ mesh_axes_2d = jax .sharding .PartitionSpec (None , 'y' )
7082 mesh_1d = jax .sharding .Mesh (devices , ('x' ,))
7183 mesh_axes_1d = jax .sharding .PartitionSpec (
7284 'x' ,
7385 )
86+ if replicated_arrays :
87+ mesh_axes_1d = jax .sharding .PartitionSpec (None ,)
7488 mesh_0d = jax .sharding .Mesh (devices , ('x' ,))
7589 mesh_axes_0d = jax .sharding .PartitionSpec (
7690 None ,
0 commit comments