From 8c5894f1d83ebe42e2e61e74cf7be2e3de8b4c3a Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Fri, 12 Jun 2026 07:48:40 -0700 Subject: [PATCH] Make drjax fully compatible with JAX Explicit Sharding. JAX has introduced Explicit Sharding, which strictly enforces mesh axis checks and disables older implicit fallbacks. This CL updates DrJAX to be compatible with the explicit model under both eager and compiled (JIT) execution, and overhauls the test suite to ensure robust, clean, and idiomatic eager/JIT test coverage. ### Implementation * Explicit Axis Verification in `primitives.py` - We raise a clear `ValueError` early during abstract evaluation if a required placement axis is missing from an `Explicit` mesh. This prevents downstream compilation errors while allowing replication fallback under `Auto` meshes. * Robust Eager Fallback in `impls.py` - JAX eager mode fails with `AttributeError` (on `SingleDeviceSharding`) or `ValueError` (on device mismatch) when broadcasting directly to a mesh via `jnp.broadcast_to`. We resolve this with a robust two-step fallback where we broadcast first to an un-sharded temporary array, and then call `jax.sharding.reshard` to move it to the target mesh. ### Testing * Flattened `subTest` Architecture - We extended test targets to run both JIT and eager modes. To avoid test class bloat or complex loops, we flattened tests into sequential `self.subTest("eager")` and `self.subTest("jit")` blocks, using local checker functions to eliminate code duplication. * JIT Mesh Context Fixes - We resolved JIT "mesh mismatch" errors by explicitly binding the active mesh context via `with jax.set_mesh(mesh)` during compiled execution in tests. 
PiperOrigin-RevId: 931148692 --- drjax/_src/api_test.py | 286 +++++++++++++++++++++++------- drjax/_src/impls.py | 51 ++++-- drjax/_src/impls_sharding_test.py | 192 +++++++++++++------- drjax/_src/impls_test.py | 23 ++- drjax/_src/primitives.py | 29 ++- drjax/_src/primitives_test.py | 29 ++- 6 files changed, 439 insertions(+), 171 deletions(-) diff --git a/drjax/_src/api_test.py b/drjax/_src/api_test.py index 3a9298a..bac0d35 100644 --- a/drjax/_src/api_test.py +++ b/drjax/_src/api_test.py @@ -29,21 +29,14 @@ def drjax_program(*, placements): return api.drjax_program(placements=placements, self_module=api) -@parameterized.product( - placement_name=["clients", "XY"], - axes_type=[ - jax.sharding.AxisType.Auto, - jax.sharding.AxisType.Explicit, - ], -) -class ApiTest(absltest.TestCase): +class ApiTest(parameterized.TestCase): def assertShardingEqual(self, arr, sharding): + # Canonicalize with trailing `None`s to the rank of the input array. + # This canonicalizes across Auto and Explicit axis types, the former + # which may not include trailing `None`s. canonical_array_sharding = jax.sharding.NamedSharding( arr.sharding.mesh, - # Canonicalize with trailing `None`s to the rank of the input array. - # This canonicalizes across Auto and Explicit axis types, the former - # which may not include trailing `None`s. 
jax.sharding.PartitionSpec(*( axis for axis, _ in itertools.zip_longest(arr.sharding.spec, arr.shape) @@ -51,21 +44,30 @@ def assertShardingEqual(self, arr, sharding): ) self.assertEqual(canonical_array_sharding, sharding) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_broadcast_with_placement_in_mesh(self, placement_name, axes_type): + @self.variant @drjax_program(placements={placement_name: 100}) def broadcast_val(val): return api.broadcast(val) mesh = jax.sharding.Mesh( - np.array(jax.devices()), - axis_names=("some_axis",), - axis_types=(axes_type,), + np.array(jax.devices()).reshape([4, 2]), + axis_names=(placement_name, "some_axis"), + axis_types=(axes_type, axes_type), ) arg_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("some_axis") ) - with mesh: + with jax.set_mesh(mesh): result = broadcast_val( jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding) ) @@ -74,11 +76,21 @@ def broadcast_val(val): # No clients dimension in the mesh, we don't lay out the clients along that # nonexistent dimension, but rather replicate them. Notice that we don't # need to specify the sharding to DrJAX; it should be inferred by GSPMD. 
- expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) + expected_result_pspec = jax.sharding.PartitionSpec( + placement_name, "some_axis", None + ) self.assertShardingEqual( result, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_broadcast_mesh_arg_without_placement( self, placement_name, axes_type ): @@ -88,6 +100,7 @@ def test_broadcast_mesh_arg_without_placement( axis_types=(axes_type,), ) + @self.variant @drjax_program(placements={placement_name: 100}) def broadcast_val(val): return api.broadcast(val, mesh=mesh) @@ -95,17 +108,42 @@ def broadcast_val(val): arg_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("some_axis") ) - result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) - chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8])) - # No clients dimension in the mesh, we don't lay out the clients along that - # nonexistent dimension, but rather replicate them. Notice that we don't - # need to specify the sharding to DrJAX; it should be inferred by GSPMD. 
- expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) - self.assertShardingEqual( - result, jax.sharding.NamedSharding(mesh, expected_result_pspec) - ) + if ( + self.variant.type == chex.ChexVariantType.WITH_JIT + and axes_type == jax.sharding.AxisType.Explicit + ): + with jax.set_mesh(mesh): + with self.assertRaisesRegex( + ValueError, "not found in mesh with explicit axes" + ): + broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) + else: + with jax.set_mesh(mesh): + result = broadcast_val( + jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding) + ) + + chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8])) + # No clients dimension in the mesh, we don't lay out the clients along + # that nonexistent dimension, but rather replicate them. Notice that we + # don't need to specify the sharding to DrJAX; it should be inferred by + # GSPMD. + expected_result_pspec = jax.sharding.PartitionSpec( + None, "some_axis", None + ) + self.assertShardingEqual( + result, jax.sharding.NamedSharding(mesh, expected_result_pspec) + ) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type): mesh = jax.sharding.Mesh( np.array(jax.devices()).reshape([4, 2]), @@ -113,6 +151,7 @@ def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type): axis_types=(axes_type, axes_type), ) + @self.variant @drjax_program(placements={placement_name: 8}) def broadcast_val(val): return api.broadcast(val, mesh=mesh) @@ -121,7 +160,10 @@ def broadcast_val(val): mesh, jax.sharding.PartitionSpec("some_axis") ) - result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) + with jax.set_mesh(mesh): + result = broadcast_val( + jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding) + ) chex.assert_trees_all_close(result, 
jnp.ones(shape=[8, 8, 8])) # The result should be sharded across the placement_name axis. @@ -132,6 +174,14 @@ def broadcast_val(val): result, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_temperature_sensors_example(self, placement_name, axes_type): def one_if_over(threshold, value): return jax.lax.cond( @@ -146,27 +196,37 @@ def one_if_over(threshold, value): axis_names=(placement_name, "some_axis"), axis_types=(axes_type, axes_type), ) - jax.set_mesh(mesh) - - @drjax_program(placements={placement_name: placement_dim}) - def temperature_sensors_example(threshold, values): - threshold_at_clients = api.broadcast(threshold) - values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) - return api.reduce_mean(values_over) - - measurements = jax.device_put( - jnp.arange(placement_dim), - jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(placement_name) - ), - ) + with jax.set_mesh(mesh): + + @self.variant + @drjax_program(placements={placement_name: placement_dim}) + def temperature_sensors_example(threshold, values): + threshold_at_clients = api.broadcast(threshold) + values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) + return api.reduce_mean(values_over) + + measurements = jax.device_put( + jnp.arange(placement_dim, dtype=jnp.float32), + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(placement_name) + ), + ) - self.assertEqual(temperature_sensors_example(24, measurements), 0.75) + self.assertEqual(temperature_sensors_example(24, measurements), 0.75) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def 
test_temperature_sensors_example_multiple_placement_values( self, placement_name, axes_type ): + @self.variant def one_if_over(threshold, value): return jax.lax.cond( value > threshold, @@ -179,39 +239,135 @@ def one_if_over(threshold, value): axis_names=(placement_name, "some_axis"), axis_types=(axes_type, axes_type), ) - jax.set_mesh(mesh) + with jax.set_mesh(mesh): + + @self.variant + @drjax_program(placements={placement_name: 100}) + def temperature_sensors_example_100_clients(threshold, values): + threshold_at_clients = api.broadcast(threshold) + values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) + return api.reduce_mean(values_over) + + @self.variant + @drjax_program(placements={placement_name: 20}) + def temperature_sensors_example_20_clients(threshold, values): + threshold_at_clients = api.broadcast(threshold) + values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) + return api.reduce_mean(values_over) + + placement_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(placement_name) + ) + measurements_100 = jax.device_put( + jnp.arange(100, dtype=jnp.float32), placement_sharding + ) + measurements_20 = jax.device_put( + jnp.arange(20, dtype=jnp.float32), placement_sharding + ) + + self.assertEqual( + temperature_sensors_example_100_clients(24, measurements_100), 0.75 + ) + self.assertEqual( + temperature_sensors_example_20_clients(3, measurements_20), + 0.8, + ) + # We should be able to recover the original result flipping back to the + # original function. + self.assertEqual( + temperature_sensors_example_100_clients(24, measurements_100), 0.75 + ) + + # NOTE: this test only fails when in a jax.jit and with an Explicit mesh, so + # we skip the without_jit variant and the Auto mesh. 
+ @chex.variants(with_jit=True, without_jit=False) + @parameterized.product(placement_name=["clients", "XY"]) + def test_broadcast_raises_error_on_missing_axis_explicit_mesh( + self, placement_name + ): + mesh = jax.sharding.Mesh( + np.array(jax.devices()), + axis_names=("some_axis",), + axis_types=(jax.sharding.AxisType.Explicit,), + ) + @self.variant @drjax_program(placements={placement_name: 100}) - def temperature_sensors_example_100_clients(threshold, values): - threshold_at_clients = api.broadcast(threshold) - values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) - return api.reduce_mean(values_over) - - @drjax_program(placements={placement_name: 20}) - def temperature_sensors_example_20_clients(threshold, values): - threshold_at_clients = api.broadcast(threshold) - values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) - return api.reduce_mean(values_over) - - placement_sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(placement_name) + def broadcast_val(val): + return api.broadcast(val) + + arg_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("some_axis") ) - measurements_100 = jax.device_put(jnp.arange(100), placement_sharding) - measurements_20 = jax.device_put(jnp.arange(20), placement_sharding) - self.assertEqual( - temperature_sensors_example_100_clients(24, measurements_100), 0.75 + with jax.set_mesh(mesh): + with self.assertRaisesRegex( + ValueError, "not found in mesh with explicit axes" + ): + broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) + + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) + def test_broadcast_with_placement_mesh_api(self, placement_name, axes_type): + """Verifies that api.broadcast shards along placement when axis in mesh.""" + if axes_type != jax.sharding.AxisType.Explicit: + self.skipTest("Only for Explicit mesh") + + 
@drjax_program(placements={placement_name: 100}) + def broadcast_val(val): + return api.broadcast(val) + + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape([4, 2]), + axis_names=(placement_name, "some_axis"), + axis_types=(axes_type, axes_type), ) + + with jax.set_mesh(mesh): + aval = jax.eval_shape(broadcast_val, jnp.zeros(shape=[10])) + + self.assertEqual(aval.shape, (100, 10)) + self.assertIsInstance(aval.sharding, jax.sharding.NamedSharding) self.assertEqual( - temperature_sensors_example_20_clients(3, measurements_20), - 0.8, + aval.sharding.spec, jax.sharding.PartitionSpec(placement_name, None) ) - # We should be able to recover the original result flipping back to the - # original function. - self.assertEqual( - temperature_sensors_example_100_clients(24, measurements_100), 0.75 + + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) + def test_broadcast_with_auto_mesh_missing_axis_api( + self, placement_name, axes_type + ): + """Verifies that api.broadcast replicates when axis is missing in auto mesh.""" + if axes_type != jax.sharding.AxisType.Auto: + self.skipTest("Only for Auto mesh") + + @drjax_program(placements={placement_name: 100}) + def broadcast_val(val): + return api.broadcast(val) + + # Create a mesh without the placement axis + mesh = jax.sharding.Mesh( + np.array(jax.devices()), + axis_names=("some_other_axis",), + axis_types=(axes_type,), ) + with jax.set_mesh(mesh): + aval = jax.eval_shape(broadcast_val, jnp.zeros(shape=[10])) + + self.assertEqual(aval.shape, (100, 10)) + self.assertIsNone(aval.sharding) + class ApiErrorsTest(absltest.TestCase): diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py index 4ad412a..3d3bf42 100644 --- a/drjax/_src/impls.py +++ b/drjax/_src/impls.py @@ -186,19 +186,38 @@ def single_arg_broadcast(x): ) elif mesh.are_all_axes_explicit: input_sharding = jax.typeof(x).sharding + requires_reshard = False + if 
isinstance(input_sharding, jax.sharding.NamedSharding): + in_spec = input_sharding.spec + else: + # Fallback for non-NamedSharding inputs (e.g., + # SingleDeviceSharding). + # jnp.broadcast_to with out_sharding requires the input to have a + # .spec attribute to perform compatibility checks, and fails with + # AttributeError on SingleDeviceSharding. + # jax.lax.broadcast_in_dim with out_sharding fails in eager mode + # with ValueError regarding incompatible devices when moving from + # single device to mesh. + # Thus, the two-step approach (broadcast then reshard) is necessary + # for robustness in eager mode. + in_spec = P() + requires_reshard = True + if _placement_axis_in_mesh(mesh, placement): out_sharding = jax.sharding.NamedSharding( - input_sharding.mesh, P(placement, *input_sharding.spec) + mesh, P(placement, *in_spec) ) else: # With explicit axes, when a placement axis is not in the mesh, # we must ask for replication (`None` sharding). - out_sharding = jax.sharding.NamedSharding( - input_sharding.mesh, P(None, *input_sharding.spec) + out_sharding = jax.sharding.NamedSharding(mesh, P(None, *in_spec)) + if requires_reshard: + result = jnp.broadcast_to(x, (n_elements,) + x.shape) + return jax.sharding.reshard(result, out_sharding) + else: + return jnp.broadcast_to( + x, (n_elements,) + x.shape, out_sharding=out_sharding ) - return jnp.broadcast_to( - x, (n_elements,) + x.shape, out_sharding=out_sharding - ) else: raise ValueError( 'Mesh axis types must all be either auto or manual, but got' @@ -317,14 +336,22 @@ def _constrain_at_placement_with_slices_like(x): result = call_jaxpr(mapped_fn, arg) # Ensure the result is sharded along the placement axis when using # explicit axes. 
+ def _get_target_sharding( + arr: jnp.ndarray, + ) -> jax.sharding.NamedSharding: + arr_sharding = jax.typeof(arr).sharding + if isinstance(arr_sharding, jax.sharding.NamedSharding): + spec = P(placement, *arr_sharding.spec[1:]) + else: + raise NotImplementedError( + f'Unsupported input sharding type: {type(arr_sharding)}. DrJax' + ' requires NamedSharding when using explicit mesh axes.' + ) + return jax.sharding.NamedSharding(mesh, spec) + return jax.sharding.reshard( result, - jax.tree.map( - lambda arr: jax.sharding.NamedSharding( - mesh, spec=P(placement, *jax.typeof(arr).sharding.spec[1:]) - ), - result, - ), + jax.tree.map(_get_target_sharding, result), ) else: raise ValueError( diff --git a/drjax/_src/impls_sharding_test.py b/drjax/_src/impls_sharding_test.py index 1437e81..67b1bd3 100644 --- a/drjax/_src/impls_sharding_test.py +++ b/drjax/_src/impls_sharding_test.py @@ -33,11 +33,11 @@ # and exporting a constant to help make mesh construction easier and guarantees # cleaner. Other constants defined here are intended to help make the operations # of the assertions in the tests below more transparent. 
-_CLIENTS_AXIS = 'clients' +_CLIENTS_AXIS = "clients" _CLIENTS_AXIS_SIZE = 2 _NUM_CLIENTS = 100 -_DATA_AXIS = 'data' +_DATA_AXIS = "data" _DATA_AXIS_SIZE = 2 _DATA_SIZE = 10 @@ -47,7 +47,7 @@ def create_mesh(mesh_shape, axis_names, axis_types): size = math.prod(mesh_shape) if len(jax.devices()) < size: - raise unittest.SkipTest(f'Test requires {size} global devices.') + raise unittest.SkipTest(f"Test requires {size} global devices.") devices = sorted(jax.devices(), key=lambda d: d.id) mesh_devices = np.array(devices[:size]).reshape(mesh_shape) return jax.sharding.Mesh( @@ -90,6 +90,7 @@ def setUp(self): placements_to_n_elements=self._placements, ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -103,16 +104,20 @@ def test_broadcast_with_1x1_fully_replicates( axis_types=(mesh_axes_type, mesh_axes_type), ) arg = jnp.zeros(shape=[_DATA_SIZE]) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + with mesh_context(global_mesh, mesh_as_context) as mesh: - result = self._comp_factory.broadcast_to_placement( - arg, _CLIENTS_AXIS, mesh - ) + result = _run(arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # There is only one chip we talk to, so this sharding 'looks' fully # replicated. 
self.assertTrue(sharding.is_fully_replicated) + @chex.variants(with_jit=True, without_jit=True) def test_broadcast_clients_with_jax_use_mesh(self): global_mesh = create_mesh( [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], @@ -120,11 +125,16 @@ def test_broadcast_clients_with_jax_use_mesh(self): axis_types=(AxisType.Auto, AxisType.Auto), ) arg = jnp.zeros(shape=[_DATA_SIZE]) - with jax.set_mesh(global_mesh): - result = self._comp_factory.broadcast_to_placement( + + @self.variant + def _run(arg): + return self._comp_factory.broadcast_to_placement( arg, _CLIENTS_AXIS, ) + + with jax.set_mesh(global_mesh): + result = _run(arg) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # If this sharding were fully replicated, we would be *replicating* the data @@ -138,6 +148,7 @@ def test_broadcast_clients_with_jax_use_mesh(self): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -151,10 +162,13 @@ def test_broadcast_clients_shards_along_clients( axis_types=(mesh_axes_type, mesh_axes_type), ) arg = jnp.zeros(shape=[_DATA_SIZE]) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + with mesh_context(global_mesh, mesh_as_context) as mesh: - result = self._comp_factory.broadcast_to_placement( - arg, _CLIENTS_AXIS, mesh - ) + result = _run(arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # If this sharding were fully replicated, we would be *replicating* the data @@ -168,6 +182,7 @@ def test_broadcast_clients_shards_along_clients( (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -175,13 +190,13 @@ def 
test_broadcast_clients_shards_along_clients( def test_broadcast_preserves_sharding_with_no_clients_mesh( self, mesh_as_context, mesh_axes_type ): + # Replicating a situation in which the caller's mesh has no clients axis; in + # this case, we should preserve the sharding of any broadcast tensors, but + # not shard along the (nonexistent) clients axis. global_mesh = create_mesh( [_DATA_AXIS_SIZE], axis_names=[_DATA_AXIS], axis_types=(mesh_axes_type,) ) arg = jnp.zeros(shape=[_DATA_SIZE]) - # Replicating a situation in which the caller's mesh has no clients axis; in - # this case, we should preserve the sharding of any broadcast tensors, but - # not shard along the (nonexistent) clients axis. no_mesh_comp_factory = impls.PlacedComputations( placements_to_n_elements=self._placements, ) @@ -189,10 +204,15 @@ def test_broadcast_preserves_sharding_with_no_clients_mesh( sharded_arg = jax.device_put( arg, device=jax.sharding.NamedSharding(global_mesh, arg_spec) ) - with mesh_context(global_mesh, mesh_as_context) as mesh: - result = no_mesh_comp_factory.broadcast_to_placement( - sharded_arg, _CLIENTS_AXIS, mesh + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return no_mesh_comp_factory.broadcast_to_placement( + arg, _CLIENTS_AXIS, mesh ) + + with mesh_context(global_mesh, mesh_as_context) as mesh: + result = _run(sharded_arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) @@ -201,23 +221,21 @@ def test_broadcast_preserves_sharding_with_no_clients_mesh( # The resulting broadcast array should have the same sharding as its input # for the non-injected dimensions, and replication on the clients dimension. 
self.assertEqual(sharding.spec, PSpec(None, _DATA_AXIS)) - # Here, the clients axis should be replicated on each set of chips; - # however,the data making up the broadcasted array should be split along - # the 'data' dimension; thus only the second dimension of the tensor should - # be split. + # Here, the clients axis should be replicated on each set of chips; however, + # the data making up the broadcasted array should be split along the 'data' + # dimension; thus only the second dimension of the tensor should be split. self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], ) def test_broadcast_preserves_arg_sharding_with_clients_mesh( - self, - mesh_as_context, - mesh_axes_type, + self, mesh_as_context, mesh_axes_type ): global_mesh = create_mesh( [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], @@ -229,10 +247,13 @@ def test_broadcast_preserves_arg_sharding_with_clients_mesh( sharded_arg = jax.device_put( arg, device=jax.sharding.NamedSharding(global_mesh, arg_spec) ) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + with mesh_context(global_mesh, mesh_as_context) as mesh: - result = self._comp_factory.broadcast_to_placement( - sharded_arg, _CLIENTS_AXIS, mesh - ) + result = _run(sharded_arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) @@ -245,15 +266,36 @@ def test_broadcast_preserves_arg_sharding_with_clients_mesh( # two sets of chips making up the clients axis of the mesh. 
However, in this # case the data making up the broadcasted array should *also* be split along # the 'data' dimension; thus each dimension of the global array's shape - # should be cut in half, with sub-arrays of shape (_NUM_CLIENTS // - # _CLIENTS_AXIS_SIZE, - # _DATA_SIZE // _DATA_AXIS_SIZE) living on each of - # the 4 chips. + # should be cut in half, with sub-arrays of shape + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # living on each of the 4 chips. self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) + def test_broadcast_with_single_device_sharding(self): + mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + arg = jnp.zeros(shape=[_DATA_SIZE]) + single_device = jax.devices()[0] + sharded_arg = jax.device_put(arg, single_device) + target_sharding = jax.sharding.NamedSharding(mesh, PSpec()) + sharded_arg = jax.device_put(sharded_arg, target_sharding) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, True) as m: + result = _run(sharded_arg, m) + self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) + self.assertIsInstance(result.sharding, jax.sharding.NamedSharding) + @jax.jit def add(x, y): @@ -278,6 +320,7 @@ def setUp(self): placements_to_n_elements=self._placements, ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -293,10 +336,13 @@ def test_map_respects_clients_sharding(self, mesh_as_context, mesh_axes_type): axis_names=[_CLIENTS_AXIS, _DATA_AXIS], axis_types=(mesh_axes_type, mesh_axes_type), ) - with mesh_context(mesh, mesh_as_context) as mesh: - result = 
self._comp_factory.map_to_placement( - add, (arg1_at_c, arg2_at_c), _CLIENTS_AXIS, mesh - ) + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement(add, args, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((arg1_at_c, arg2_at_c), mesh_val) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # The data should be partitioned across chips. @@ -308,6 +354,7 @@ def test_map_respects_clients_sharding(self, mesh_as_context, mesh_axes_type): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -324,10 +371,15 @@ def test_map_zeros_like_respects_clients_sharding( axis_names=[_CLIENTS_AXIS, _DATA_AXIS], axis_types=(mesh_axes_type, mesh_axes_type), ) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - jnp.zeros_like, arg_at_c, _CLIENTS_AXIS, mesh + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement( + jnp.zeros_like, args, _CLIENTS_AXIS, mesh ) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run(arg_at_c, mesh_val) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # The data should be partitioned across chips. 
@@ -339,6 +391,7 @@ def test_map_zeros_like_respects_clients_sharding( (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -365,28 +418,32 @@ def test_map_respects_non_clients_sharding( sharded_arg2, comp_factory=self._comp_factory, ) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - add, (arg1_at_c, arg2_at_c), _CLIENTS_AXIS, mesh - ) + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement(add, args, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((arg1_at_c, arg2_at_c), mesh_val) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) - # The data should be partitioned across chips. sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) + # The data should be partitioned across chips. self.assertFalse(sharding.is_fully_replicated) # The resulting array here should be fully sharded, just like the argument, # by computation-follows-data semantics. self.assertEqual(sharding.spec, PSpec(_CLIENTS_AXIS, _DATA_AXIS)) # Since the argument was fully split across the data and clients axes, the # result should be too: each of the 4 chips hosts a sub-array slice of data, - # of shape (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // - # _DATA_AXIS_SIZE), so that the entire (global) shape is (_NUM_CLIENTS, - # _DATA_SIZE). + # of shape + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # so that the entire (global) shape is (_NUM_CLIENTS, _DATA_SIZE). 
self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -410,10 +467,13 @@ def test_map_forces_clients_sharding_with_model_parallelism( ) sharded_arg1 = jnp.tile(sharded_arg1, reps=[_NUM_CLIENTS, 1]) sharded_arg2 = jnp.tile(sharded_arg2, reps=[_NUM_CLIENTS, 1]) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - add, (sharded_arg1, sharded_arg2), _CLIENTS_AXIS, mesh - ) + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement(add, args, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((sharded_arg1, sharded_arg2), mesh_val) # Our arguments should _not_ be sharded across the clients axis. self.assertEqual( @@ -424,10 +484,10 @@ def test_map_forces_clients_sharding_with_model_parallelism( sharded_arg2.sharding.shard_shape(sharded_arg2.shape), (_NUM_CLIENTS, _DATA_SIZE // _DATA_AXIS_SIZE), ) - # But the result should be fully partitioned across chips. self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) + # But the result should be fully partitioned across chips. 
self.assertFalse(sharding.is_fully_replicated) # The resulting array here should be fully sharded, _even though the # argument was not_, because our vmap impl inserts sharding constraints on @@ -435,14 +495,15 @@ def test_map_forces_clients_sharding_with_model_parallelism( self.assertEqual(sharding.spec, PSpec(_CLIENTS_AXIS, _DATA_AXIS)) # Since the argument was fully split across the data and clients axes, the # result should be too: each of the 4 chips hosts a sub-array slice of data, - # of shape (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // - # _DATA_AXIS_SIZE), so that the entire (global) shape is (_NUM_CLIENTS, - # _DATA_SIZE). + # of shape: + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # so that the entire (global) shape is (_NUM_CLIENTS, _DATA_SIZE). self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -476,16 +537,21 @@ def shard_map_add(x, y): ) sharded_arg1 = jnp.tile(sharded_arg1, reps=[_NUM_CLIENTS, 1]) sharded_arg2 = jnp.tile(sharded_arg2, reps=[_NUM_CLIENTS, 1]) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - shard_map_add, (sharded_arg1, sharded_arg2), _CLIENTS_AXIS, mesh + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement( + shard_map_add, args, _CLIENTS_AXIS, mesh ) - # The result should be fully partitioned across chips, regardless of input - # sharding. 
+ with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((sharded_arg1, sharded_arg2), mesh_val) + self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) + # The result should be fully partitioned across chips, regardless of input + # sharding. self.assertFalse(sharding.is_fully_replicated) # The resulting array here should be fully sharded, since the fed_map # implementation respects the _CLIENTS_AXIS sharding and the shard_map @@ -493,9 +559,9 @@ def shard_map_add(x, y): self.assertEqual(sharding.spec, PSpec(_CLIENTS_AXIS, _DATA_AXIS)) # Since the argument was fully split across the data and clients axes, the # result should be too: each of the 4 chips hosts a sub-array slice of data, - # of shape (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // - # _DATA_AXIS_SIZE), so that the entire (global) shape is (_NUM_CLIENTS, - # _DATA_SIZE). + # of shape + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # so that the entire (global) shape is (_NUM_CLIENTS, _DATA_SIZE). 
self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), @@ -507,5 +573,5 @@ def setUpModule(): chex.set_n_cpu_devices(8) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/drjax/_src/impls_test.py b/drjax/_src/impls_test.py index 30c19b0..5b5318d 100644 --- a/drjax/_src/impls_test.py +++ b/drjax/_src/impls_test.py @@ -39,14 +39,21 @@ def setUp(self): self._placements = {'clients': 100} self._sequence_length = 10 + @chex.variants(with_jit=True, without_jit=True) def test_broadcast_on_float(self): comp_factory = impls.PlacedComputations( placements_to_n_elements=self._placements, ) - actual_output = comp_factory.broadcast_to_placement(0.0, 'clients') + + @self.variant + def _run(): + return comp_factory.broadcast_to_placement(0.0, 'clients') + + actual_output = _run() expected_output = jnp.zeros(shape=[100]) chex.assert_trees_all_equal(actual_output, expected_output) + @chex.variants(with_jit=True, without_jit=True) def test_runs_temp_sens_example(self): comp_factory = impls.PlacedComputations( placements_to_n_elements=self._placements, @@ -54,6 +61,7 @@ def test_runs_temp_sens_example(self): def _one_if_over(x, y): return jax.lax.cond(x > y, lambda: 1.0, lambda: 0.0) + @self.variant def temp_sens_example(m, t): t_at_c = comp_factory.broadcast_to_placement(t, 'clients') total_over = comp_factory.map_to_placement( @@ -62,11 +70,11 @@ def temp_sens_example(m, t): return comp_factory.mean_from_placement(total_over) measurements = jnp.arange(self._placements['clients']) - self.assertEqual( temp_sens_example(measurements, jnp.median(measurements)), 0.5 ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_axes_type=[AxisType.Auto, AxisType.Explicit], ) @@ -82,8 +90,11 @@ def update(model, x): model, x ) + @self.variant def test_training(model, data): - model_at_clients = comp_factory.broadcast_to_placement(model, 'clients') + model_at_clients = 
comp_factory.broadcast_to_placement( + model, 'clients' + ) grads, _ = comp_factory.map_to_placement( update, (model_at_clients, data), 'clients' ) @@ -93,8 +104,10 @@ def test_training(model, data): jnp.ones(shape=(self._placements['clients'],), dtype=jnp.float32), device=NamedSharding(mesh, PartitionSpec('clients')), ) - model = jax.device_put([0.0], device=NamedSharding(mesh, PartitionSpec())) - self.assertEqual(jax.jit(test_training)(model, clients_data), 0.0) + model = jax.device_put( + [0.0], device=NamedSharding(mesh, PartitionSpec()) + ) + self.assertEqual(test_training(model, clients_data), 0.0) # This allows us to test sharding behavior across multiple devices. diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py index b4bab3a..fb7c25f 100644 --- a/drjax/_src/primitives.py +++ b/drjax/_src/primitives.py @@ -91,15 +91,26 @@ def broadcast_abstract_eval( abstract_mesh = ( mesh.abstract_mesh if isinstance(mesh, jax.sharding.Mesh) else mesh ) - sharding_axis = ( - placement_str - if impls._placement_axis_in_mesh(abstract_mesh, placement_str) # pylint: disable=protected-access - else None - ) - new_sharding = xs.sharding.update( - mesh=abstract_mesh, - spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec), - ) + if impls._placement_axis_in_mesh(abstract_mesh, placement_str): # pylint: disable=protected-access + sharding_axis = placement_str + else: + if abstract_mesh is not None and any( + t == jax.sharding.AxisType.Explicit for t in abstract_mesh.axis_types + ): + raise ValueError( + f"Placement axis '{placement_str}' not found in mesh with explicit" + ' axes.' 
+ ) + sharding_axis = None + if xs.sharding is not None: + new_sharding = xs.sharding.update( + mesh=abstract_mesh, + spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec), + ) + else: + new_sharding = jax.sharding.NamedSharding( + abstract_mesh, jax.sharding.PartitionSpec(sharding_axis) + ) return core.ShapedArray( shape=(n_elements,) + xs.shape, dtype=xs.dtype, diff --git a/drjax/_src/primitives_test.py b/drjax/_src/primitives_test.py index 569a4c1..7716173 100644 --- a/drjax/_src/primitives_test.py +++ b/drjax/_src/primitives_test.py @@ -86,22 +86,18 @@ def setUp(self): {'clients': self._n_clients}, ) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_clients_evaluation(self): - fn = self._primdefs['broadcast_clients'] + fn = self.variant(self._primdefs['broadcast_clients']) # Check that this function is callable. chex.assert_trees_all_close( fn(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) - # Check that it's jittable. - chex.assert_trees_all_close( - jax.jit(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) - ) # Check that its forward-diffable. chex.assert_trees_all_close( jax.jacfwd(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) - # Also that it's reverse-diffable. chex.assert_trees_all_close( jax.jacrev(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) @@ -117,18 +113,15 @@ def test_broadcast_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(jnp.array(1.0)) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'sum_from_clients')) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_sum_from_clients_evaluation(self): - fn = self._primdefs['sum_from_clients'] + fn = self.variant(self._primdefs['sum_from_clients']) clients_ones = self._primdefs['broadcast_clients'](jnp.ones(shape=[1])) # Check that this function is callable. 
chex.assert_trees_all_close( fn(clients_ones), jnp.array([1.0 * self._n_clients]) ) - # Check that it's jittable. - chex.assert_trees_all_close( - jax.jit(fn)(clients_ones), jnp.array([1.0 * self._n_clients]) - ) # Check that its forward-diffable. chex.assert_trees_all_close( jax.jacfwd(fn)(clients_ones), jnp.ones(shape=[1, 100, 1]) @@ -138,10 +131,12 @@ def test_sum_from_clients_evaluation(self): jax.jacrev(fn)(clients_ones), jnp.ones(shape=[1, self._n_clients, 1]) ) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_and_sum_from_clients_eval(self): fn = self._primdefs['sum_from_clients'] + @self.variant def _broadcast_then_sum(x): broadcasted_x = self._primdefs['broadcast_clients'](x) return fn(broadcasted_x) @@ -151,7 +146,6 @@ def _broadcast_then_sum(x): jax.jacfwd(_broadcast_then_sum)(jnp.array([1.0])), jnp.array([[1.0 * self._n_clients]]), ) - # And here's reverse-ad. chex.assert_trees_all_close( jax.jacrev(_broadcast_then_sum)(jnp.array([1.0])), @@ -169,24 +163,25 @@ def test_sum_from_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(clients_ones) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'broadcast_clients')) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_mean_from_clients_eval(self): - fn = self._primdefs['mean_from_clients'] + fn = self.variant(self._primdefs['mean_from_clients']) clients_ones = jnp.ones(shape=[self._n_clients, 1]) # Check that this function is callable. chex.assert_trees_all_close(fn(clients_ones), jnp.array([1.0])) - # Check that it's jittable. - chex.assert_trees_all_close(jax.jit(fn)(clients_ones), jnp.array([1.0])) # Check that its forward-diffable. 
chex.assert_trees_all_close( jax.jacfwd(fn)(clients_ones), 1 / self._n_clients * jnp.ones(shape=[1, self._n_clients, 1]), ) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_then_mean_from_clients_eval(self): fn = self._primdefs['mean_from_clients'] + @self.variant def _broadcast_then_sum(x): broadcasted_x = self._primdefs['broadcast_clients'](x) return fn(broadcasted_x) @@ -243,8 +238,8 @@ def duplicate_prim_result(x): @jax.jit def ignore_prim_result(x): - # Ignoring one result from this tuple-returning function triggers - # reverse evaluation with a symbolic zero cotangent argument. + # Ignoring one result from this tuple-returning function triggers reverse + # evaluation with a symbolic zero cotangent argument. y, _ = duplicate_prim_result(x) return y