diff --git a/drjax/_src/api_test.py b/drjax/_src/api_test.py index 3a9298a..bac0d35 100644 --- a/drjax/_src/api_test.py +++ b/drjax/_src/api_test.py @@ -29,21 +29,14 @@ def drjax_program(*, placements): return api.drjax_program(placements=placements, self_module=api) -@parameterized.product( - placement_name=["clients", "XY"], - axes_type=[ - jax.sharding.AxisType.Auto, - jax.sharding.AxisType.Explicit, - ], -) -class ApiTest(absltest.TestCase): +class ApiTest(parameterized.TestCase): def assertShardingEqual(self, arr, sharding): + # Canonicalize with trailing `None`s to the rank of the input array. + # This canonicalizes across Auto and Explicit axis types, the former + # which may not include trailing `None`s. canonical_array_sharding = jax.sharding.NamedSharding( arr.sharding.mesh, - # Canonicalize with trailing `None`s to the rank of the input array. - # This canonicalizes across Auto and Explicit axis types, the former - # which may not include trailing `None`s. jax.sharding.PartitionSpec(*( axis for axis, _ in itertools.zip_longest(arr.sharding.spec, arr.shape) @@ -51,21 +44,30 @@ def assertShardingEqual(self, arr, sharding): ) self.assertEqual(canonical_array_sharding, sharding) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_broadcast_with_placement_in_mesh(self, placement_name, axes_type): + @self.variant @drjax_program(placements={placement_name: 100}) def broadcast_val(val): return api.broadcast(val) mesh = jax.sharding.Mesh( - np.array(jax.devices()), - axis_names=("some_axis",), - axis_types=(axes_type,), + np.array(jax.devices()).reshape([4, 2]), + axis_names=(placement_name, "some_axis"), + axis_types=(axes_type, axes_type), ) arg_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("some_axis") ) - with mesh: + with jax.set_mesh(mesh): result = broadcast_val( jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding) ) @@ -74,11 +76,21 @@ def broadcast_val(val): # No clients dimension in the mesh, we don't lay out the clients along that # nonexistent dimension, but rather replicate them. Notice that we don't # need to specify the sharding to DrJAX; it should be inferred by GSPMD. - expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) + expected_result_pspec = jax.sharding.PartitionSpec( + placement_name, "some_axis", None + ) self.assertShardingEqual( result, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_broadcast_mesh_arg_without_placement( self, placement_name, axes_type ): @@ -88,6 +100,7 @@ def test_broadcast_mesh_arg_without_placement( axis_types=(axes_type,), ) + @self.variant @drjax_program(placements={placement_name: 100}) def broadcast_val(val): return api.broadcast(val, mesh=mesh) @@ -95,17 +108,42 @@ def broadcast_val(val): arg_sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("some_axis") ) - result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) - chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8])) - # No clients dimension in the mesh, we don't lay out the clients along that - # nonexistent dimension, but rather replicate them. Notice that we don't - # need to specify the sharding to DrJAX; it should be inferred by GSPMD. - expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None) - self.assertShardingEqual( - result, jax.sharding.NamedSharding(mesh, expected_result_pspec) - ) + if ( + self.variant.type == chex.ChexVariantType.WITH_JIT + and axes_type == jax.sharding.AxisType.Explicit + ): + with jax.set_mesh(mesh): + with self.assertRaisesRegex( + ValueError, "not found in mesh with explicit axes" + ): + broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) + else: + with jax.set_mesh(mesh): + result = broadcast_val( + jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding) + ) + + chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8])) + # No clients dimension in the mesh, we don't lay out the clients along + # that nonexistent dimension, but rather replicate them. Notice that we + # don't need to specify the sharding to DrJAX; it should be inferred by + # GSPMD. + expected_result_pspec = jax.sharding.PartitionSpec( + None, "some_axis", None + ) + self.assertShardingEqual( + result, jax.sharding.NamedSharding(mesh, expected_result_pspec) + ) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type): mesh = jax.sharding.Mesh( np.array(jax.devices()).reshape([4, 2]), @@ -113,6 +151,7 @@ def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type): axis_types=(axes_type, axes_type), ) + @self.variant @drjax_program(placements={placement_name: 8}) def broadcast_val(val): return api.broadcast(val, mesh=mesh) @@ -121,7 +160,10 @@ def broadcast_val(val): mesh, jax.sharding.PartitionSpec("some_axis") ) - result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) + with jax.set_mesh(mesh): + result = broadcast_val( + jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding) + ) chex.assert_trees_all_close(result, jnp.ones(shape=[8, 8, 8])) # The result should be sharded across the placement_name axis. @@ -132,6 +174,14 @@ def broadcast_val(val): result, jax.sharding.NamedSharding(mesh, expected_result_pspec) ) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_temperature_sensors_example(self, placement_name, axes_type): def one_if_over(threshold, value): return jax.lax.cond( @@ -146,27 +196,37 @@ def one_if_over(threshold, value): axis_names=(placement_name, "some_axis"), axis_types=(axes_type, axes_type), ) - jax.set_mesh(mesh) - - @drjax_program(placements={placement_name: placement_dim}) - def temperature_sensors_example(threshold, values): - threshold_at_clients = api.broadcast(threshold) - values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) - return api.reduce_mean(values_over) - - measurements = jax.device_put( - jnp.arange(placement_dim), - jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(placement_name) - ), - ) + with jax.set_mesh(mesh): + + @self.variant + @drjax_program(placements={placement_name: placement_dim}) + def temperature_sensors_example(threshold, values): + threshold_at_clients = api.broadcast(threshold) + values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) + return api.reduce_mean(values_over) + + measurements = jax.device_put( + jnp.arange(placement_dim, dtype=jnp.float32), + jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(placement_name) + ), + ) - self.assertEqual(temperature_sensors_example(24, measurements), 0.75) + self.assertEqual(temperature_sensors_example(24, measurements), 0.75) + @chex.variants(with_jit=True, without_jit=True) + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) def test_temperature_sensors_example_multiple_placement_values( self, placement_name, axes_type ): + @self.variant def one_if_over(threshold, value): return jax.lax.cond( value > threshold, @@ -179,39 +239,135 @@ def one_if_over(threshold, value): axis_names=(placement_name, "some_axis"), axis_types=(axes_type, axes_type), ) - jax.set_mesh(mesh) + with jax.set_mesh(mesh): + + @self.variant + @drjax_program(placements={placement_name: 100}) + def temperature_sensors_example_100_clients(threshold, values): + threshold_at_clients = api.broadcast(threshold) + values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) + return api.reduce_mean(values_over) + + @self.variant + @drjax_program(placements={placement_name: 20}) + def temperature_sensors_example_20_clients(threshold, values): + threshold_at_clients = api.broadcast(threshold) + values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) + return api.reduce_mean(values_over) + + placement_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec(placement_name) + ) + measurements_100 = jax.device_put( + jnp.arange(100, dtype=jnp.float32), placement_sharding + ) + measurements_20 = jax.device_put( + jnp.arange(20, dtype=jnp.float32), placement_sharding + ) + + self.assertEqual( + temperature_sensors_example_100_clients(24, measurements_100), 0.75 + ) + self.assertEqual( + temperature_sensors_example_20_clients(3, measurements_20), + 0.8, + ) + # We should be able to recover the original result flipping back to the + # original function. + self.assertEqual( + temperature_sensors_example_100_clients(24, measurements_100), 0.75 + ) + + # NOTE: this test only fails when in a jax.jit and with an Explicit mesh, so + # we skip the without_jit variant and the Auto mesh. + @chex.variants(with_jit=True, without_jit=False) + @parameterized.product(placement_name=["clients", "XY"]) + def test_broadcast_raises_error_on_missing_axis_explicit_mesh( + self, placement_name + ): + mesh = jax.sharding.Mesh( + np.array(jax.devices()), + axis_names=("some_axis",), + axis_types=(jax.sharding.AxisType.Explicit,), + ) + @self.variant @drjax_program(placements={placement_name: 100}) - def temperature_sensors_example_100_clients(threshold, values): - threshold_at_clients = api.broadcast(threshold) - values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) - return api.reduce_mean(values_over) - - @drjax_program(placements={placement_name: 20}) - def temperature_sensors_example_20_clients(threshold, values): - threshold_at_clients = api.broadcast(threshold) - values_over = api.map_fn(one_if_over, (threshold_at_clients, values)) - return api.reduce_mean(values_over) - - placement_sharding = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(placement_name) + def broadcast_val(val): + return api.broadcast(val) + + arg_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec("some_axis") ) - measurements_100 = jax.device_put(jnp.arange(100), placement_sharding) - measurements_20 = jax.device_put(jnp.arange(20), placement_sharding) - self.assertEqual( - temperature_sensors_example_100_clients(24, measurements_100), 0.75 + with jax.set_mesh(mesh): + with self.assertRaisesRegex( + ValueError, "not found in mesh with explicit axes" + ): + broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)) + + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) + def test_broadcast_with_placement_mesh_api(self, placement_name, axes_type): + """Verifies that api.broadcast shards along placement when axis in mesh.""" + if axes_type != jax.sharding.AxisType.Explicit: + self.skipTest("Only for Explicit mesh") + + @drjax_program(placements={placement_name: 100}) + def broadcast_val(val): + return api.broadcast(val) + + mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape([4, 2]), + axis_names=(placement_name, "some_axis"), + axis_types=(axes_type, axes_type), ) + + with jax.set_mesh(mesh): + aval = jax.eval_shape(broadcast_val, jnp.zeros(shape=[10])) + + self.assertEqual(aval.shape, (100, 10)) + self.assertIsInstance(aval.sharding, jax.sharding.NamedSharding) self.assertEqual( - temperature_sensors_example_20_clients(3, measurements_20), - 0.8, + aval.sharding.spec, jax.sharding.PartitionSpec(placement_name, None) ) - # We should be able to recover the original result flipping back to the - # original function. - self.assertEqual( - temperature_sensors_example_100_clients(24, measurements_100), 0.75 + + @parameterized.product( + placement_name=["clients", "XY"], + axes_type=[ + jax.sharding.AxisType.Auto, + jax.sharding.AxisType.Explicit, + ], + ) + def test_broadcast_with_auto_mesh_missing_axis_api( + self, placement_name, axes_type + ): + """Verifies that api.broadcast replicates when axis is missing in auto mesh.""" + if axes_type != jax.sharding.AxisType.Auto: + self.skipTest("Only for Auto mesh") + + @drjax_program(placements={placement_name: 100}) + def broadcast_val(val): + return api.broadcast(val) + + # Create a mesh without the placement axis + mesh = jax.sharding.Mesh( + np.array(jax.devices()), + axis_names=("some_other_axis",), + axis_types=(axes_type,), ) + with jax.set_mesh(mesh): + aval = jax.eval_shape(broadcast_val, jnp.zeros(shape=[10])) + + self.assertEqual(aval.shape, (100, 10)) + self.assertIsNone(aval.sharding) + class ApiErrorsTest(absltest.TestCase): diff --git a/drjax/_src/impls.py b/drjax/_src/impls.py index 4ad412a..3d3bf42 100644 --- a/drjax/_src/impls.py +++ b/drjax/_src/impls.py @@ -186,19 +186,38 @@ def single_arg_broadcast(x): ) elif mesh.are_all_axes_explicit: input_sharding = jax.typeof(x).sharding + requires_reshard = False + if isinstance(input_sharding, jax.sharding.NamedSharding): + in_spec = input_sharding.spec + else: + # Fallback for non-NamedSharding inputs (e.g., + # SingleDeviceSharding). + # jnp.broadcast_to with out_sharding requires the input to have a + # .spec attribute to perform compatibility checks, and fails with + # AttributeError on SingleDeviceSharding. + # jax.lax.broadcast_in_dim with out_sharding fails in eager mode + # with ValueError regarding incompatible devices when moving from + # single device to mesh. + # Thus, the two-step approach (broadcast then reshard) is necessary + # for robustness in eager mode. + in_spec = P() + requires_reshard = True + if _placement_axis_in_mesh(mesh, placement): out_sharding = jax.sharding.NamedSharding( - input_sharding.mesh, P(placement, *input_sharding.spec) + mesh, P(placement, *in_spec) ) else: # With explicit axes, when a placement axis is not in the mesh, # we must ask for replication (`None` sharding). - out_sharding = jax.sharding.NamedSharding( - input_sharding.mesh, P(None, *input_sharding.spec) + out_sharding = jax.sharding.NamedSharding(mesh, P(None, *in_spec)) + if requires_reshard: + result = jnp.broadcast_to(x, (n_elements,) + x.shape) + return jax.sharding.reshard(result, out_sharding) + else: + return jnp.broadcast_to( + x, (n_elements,) + x.shape, out_sharding=out_sharding ) - return jnp.broadcast_to( - x, (n_elements,) + x.shape, out_sharding=out_sharding - ) else: raise ValueError( 'Mesh axis types must all be either auto or manual, but got' @@ -317,14 +336,22 @@ def _constrain_at_placement_with_slices_like(x): result = call_jaxpr(mapped_fn, arg) # Ensure the result is sharded along the placement axis when using # explicit axes. + def _get_target_sharding( + arr: jnp.ndarray, + ) -> jax.sharding.NamedSharding: + arr_sharding = jax.typeof(arr).sharding + if isinstance(arr_sharding, jax.sharding.NamedSharding): + spec = P(placement, *arr_sharding.spec[1:]) + else: + raise NotImplementedError( + f'Unsupported input sharding type: {type(arr_sharding)}. DrJax' + ' requires NamedSharding when using explicit mesh axes.' + ) + return jax.sharding.NamedSharding(mesh, spec) + return jax.sharding.reshard( result, - jax.tree.map( - lambda arr: jax.sharding.NamedSharding( - mesh, spec=P(placement, *jax.typeof(arr).sharding.spec[1:]) - ), - result, - ), + jax.tree.map(_get_target_sharding, result), ) else: raise ValueError( diff --git a/drjax/_src/impls_sharding_test.py b/drjax/_src/impls_sharding_test.py index 1437e81..67b1bd3 100644 --- a/drjax/_src/impls_sharding_test.py +++ b/drjax/_src/impls_sharding_test.py @@ -33,11 +33,11 @@ # and exporting a constant to help make mesh construction easier and guarantees # cleaner. Other constants defined here are intended to help make the operations # of the assertions in the tests below more transparent. -_CLIENTS_AXIS = 'clients' +_CLIENTS_AXIS = "clients" _CLIENTS_AXIS_SIZE = 2 _NUM_CLIENTS = 100 -_DATA_AXIS = 'data' +_DATA_AXIS = "data" _DATA_AXIS_SIZE = 2 _DATA_SIZE = 10 @@ -47,7 +47,7 @@ def create_mesh(mesh_shape, axis_names, axis_types): size = math.prod(mesh_shape) if len(jax.devices()) < size: - raise unittest.SkipTest(f'Test requires {size} global devices.') + raise unittest.SkipTest(f"Test requires {size} global devices.") devices = sorted(jax.devices(), key=lambda d: d.id) mesh_devices = np.array(devices[:size]).reshape(mesh_shape) return jax.sharding.Mesh( @@ -90,6 +90,7 @@ def setUp(self): placements_to_n_elements=self._placements, ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -103,16 +104,20 @@ def test_broadcast_with_1x1_fully_replicates( axis_types=(mesh_axes_type, mesh_axes_type), ) arg = jnp.zeros(shape=[_DATA_SIZE]) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + with mesh_context(global_mesh, mesh_as_context) as mesh: - result = self._comp_factory.broadcast_to_placement( - arg, _CLIENTS_AXIS, mesh - ) + result = _run(arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # There is only one chip we talk to, so this sharding 'looks' fully # replicated. self.assertTrue(sharding.is_fully_replicated) + @chex.variants(with_jit=True, without_jit=True) def test_broadcast_clients_with_jax_use_mesh(self): global_mesh = create_mesh( [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], @@ -120,11 +125,16 @@ def test_broadcast_clients_with_jax_use_mesh(self): axis_types=(AxisType.Auto, AxisType.Auto), ) arg = jnp.zeros(shape=[_DATA_SIZE]) - with jax.set_mesh(global_mesh): - result = self._comp_factory.broadcast_to_placement( + + @self.variant + def _run(arg): + return self._comp_factory.broadcast_to_placement( arg, _CLIENTS_AXIS, ) + + with jax.set_mesh(global_mesh): + result = _run(arg) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # If this sharding were fully replicated, we would be *replicating* the data @@ -138,6 +148,7 @@ def test_broadcast_clients_with_jax_use_mesh(self): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -151,10 +162,13 @@ def test_broadcast_clients_shards_along_clients( axis_types=(mesh_axes_type, mesh_axes_type), ) arg = jnp.zeros(shape=[_DATA_SIZE]) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + with mesh_context(global_mesh, mesh_as_context) as mesh: - result = self._comp_factory.broadcast_to_placement( - arg, _CLIENTS_AXIS, mesh - ) + result = _run(arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # If this sharding were fully replicated, we would be *replicating* the data @@ -168,6 +182,7 @@ def test_broadcast_clients_shards_along_clients( (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -175,13 +190,13 @@ def test_broadcast_clients_shards_along_clients( def test_broadcast_preserves_sharding_with_no_clients_mesh( self, mesh_as_context, mesh_axes_type ): + # Replicating a situation in which the caller's mesh has no clients axis; in + # this case, we should preserve the sharding of any broadcast tensors, but + # not shard along the (nonexistent) clients axis. global_mesh = create_mesh( [_DATA_AXIS_SIZE], axis_names=[_DATA_AXIS], axis_types=(mesh_axes_type,) ) arg = jnp.zeros(shape=[_DATA_SIZE]) - # Replicating a situation in which the caller's mesh has no clients axis; in - # this case, we should preserve the sharding of any broadcast tensors, but - # not shard along the (nonexistent) clients axis. no_mesh_comp_factory = impls.PlacedComputations( placements_to_n_elements=self._placements, ) @@ -189,10 +204,15 @@ def test_broadcast_preserves_sharding_with_no_clients_mesh( sharded_arg = jax.device_put( arg, device=jax.sharding.NamedSharding(global_mesh, arg_spec) ) - with mesh_context(global_mesh, mesh_as_context) as mesh: - result = no_mesh_comp_factory.broadcast_to_placement( - sharded_arg, _CLIENTS_AXIS, mesh + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return no_mesh_comp_factory.broadcast_to_placement( + arg, _CLIENTS_AXIS, mesh ) + + with mesh_context(global_mesh, mesh_as_context) as mesh: + result = _run(sharded_arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) @@ -201,23 +221,21 @@ def test_broadcast_preserves_sharding_with_no_clients_mesh( # The resulting broadcast array should have the same sharding as its input # for the non-injected dimensions, and replication on the clients dimension. self.assertEqual(sharding.spec, PSpec(None, _DATA_AXIS)) - # Here, the clients axis should be replicated on each set of chips; - # however,the data making up the broadcasted array should be split along - # the 'data' dimension; thus only the second dimension of the tensor should - # be split. + # Here, the clients axis should be replicated on each set of chips; however, + # the data making up the broadcasted array should be split along the 'data' + # dimension; thus only the second dimension of the tensor should be split. self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], ) def test_broadcast_preserves_arg_sharding_with_clients_mesh( - self, - mesh_as_context, - mesh_axes_type, + self, mesh_as_context, mesh_axes_type ): global_mesh = create_mesh( [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], @@ -229,10 +247,13 @@ def test_broadcast_preserves_arg_sharding_with_clients_mesh( sharded_arg = jax.device_put( arg, device=jax.sharding.NamedSharding(global_mesh, arg_spec) ) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + with mesh_context(global_mesh, mesh_as_context) as mesh: - result = self._comp_factory.broadcast_to_placement( - sharded_arg, _CLIENTS_AXIS, mesh - ) + result = _run(sharded_arg, mesh) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) @@ -245,15 +266,36 @@ def test_broadcast_preserves_arg_sharding_with_clients_mesh( # two sets of chips making up the clients axis of the mesh. However, in this # case the data making up the broadcasted array should *also* be split along # the 'data' dimension; thus each dimension of the global array's shape - # should be cut in half, with sub-arrays of shape (_NUM_CLIENTS // - # _CLIENTS_AXIS_SIZE, - # _DATA_SIZE // _DATA_AXIS_SIZE) living on each of - # the 4 chips. + # should be cut in half, with sub-arrays of shape + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # living on each of the 4 chips. self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) + def test_broadcast_with_single_device_sharding(self): + mesh = create_mesh( + [_CLIENTS_AXIS_SIZE, _DATA_AXIS_SIZE], + axis_names=[_CLIENTS_AXIS, _DATA_AXIS], + axis_types=(AxisType.Explicit, AxisType.Explicit), + ) + arg = jnp.zeros(shape=[_DATA_SIZE]) + single_device = jax.devices()[0] + sharded_arg = jax.device_put(arg, single_device) + target_sharding = jax.sharding.NamedSharding(mesh, PSpec()) + sharded_arg = jax.device_put(sharded_arg, target_sharding) + + @self.variant(static_argnums=(1,)) + def _run(arg, mesh): + return self._comp_factory.broadcast_to_placement(arg, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, True) as m: + result = _run(sharded_arg, m) + self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) + self.assertIsInstance(result.sharding, jax.sharding.NamedSharding) + @jax.jit def add(x, y): @@ -278,6 +320,7 @@ def setUp(self): placements_to_n_elements=self._placements, ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -293,10 +336,13 @@ def test_map_respects_clients_sharding(self, mesh_as_context, mesh_axes_type): axis_names=[_CLIENTS_AXIS, _DATA_AXIS], axis_types=(mesh_axes_type, mesh_axes_type), ) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - add, (arg1_at_c, arg2_at_c), _CLIENTS_AXIS, mesh - ) + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement(add, args, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((arg1_at_c, arg2_at_c), mesh_val) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # The data should be partitioned across chips. @@ -308,6 +354,7 @@ def test_map_respects_clients_sharding(self, mesh_as_context, mesh_axes_type): (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -324,10 +371,15 @@ def test_map_zeros_like_respects_clients_sharding( axis_names=[_CLIENTS_AXIS, _DATA_AXIS], axis_types=(mesh_axes_type, mesh_axes_type), ) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - jnp.zeros_like, arg_at_c, _CLIENTS_AXIS, mesh + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement( + jnp.zeros_like, args, _CLIENTS_AXIS, mesh ) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run(arg_at_c, mesh_val) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding # The data should be partitioned across chips. @@ -339,6 +391,7 @@ def test_map_zeros_like_respects_clients_sharding( (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -365,28 +418,32 @@ def test_map_respects_non_clients_sharding( sharded_arg2, comp_factory=self._comp_factory, ) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - add, (arg1_at_c, arg2_at_c), _CLIENTS_AXIS, mesh - ) + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement(add, args, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((arg1_at_c, arg2_at_c), mesh_val) self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) - # The data should be partitioned across chips. sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) + # The data should be partitioned across chips. self.assertFalse(sharding.is_fully_replicated) # The resulting array here should be fully sharded, just like the argument, # by computation-follows-data semantics. self.assertEqual(sharding.spec, PSpec(_CLIENTS_AXIS, _DATA_AXIS)) # Since the argument was fully split across the data and clients axes, the # result should be too: each of the 4 chips hosts a sub-array slice of data, - # of shape (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // - # _DATA_AXIS_SIZE), so that the entire (global) shape is (_NUM_CLIENTS, - # _DATA_SIZE). + # of shape + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # so that the entire (global) shape is (_NUM_CLIENTS, _DATA_SIZE). self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -410,10 +467,13 @@ def test_map_forces_clients_sharding_with_model_parallelism( ) sharded_arg1 = jnp.tile(sharded_arg1, reps=[_NUM_CLIENTS, 1]) sharded_arg2 = jnp.tile(sharded_arg2, reps=[_NUM_CLIENTS, 1]) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - add, (sharded_arg1, sharded_arg2), _CLIENTS_AXIS, mesh - ) + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement(add, args, _CLIENTS_AXIS, mesh) + + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((sharded_arg1, sharded_arg2), mesh_val) # Our arguments should _not_ be sharded across the clients axis. self.assertEqual( @@ -424,10 +484,10 @@ def test_map_forces_clients_sharding_with_model_parallelism( sharded_arg2.sharding.shard_shape(sharded_arg2.shape), (_NUM_CLIENTS, _DATA_SIZE // _DATA_AXIS_SIZE), ) - # But the result should be fully partitioned across chips. self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) + # But the result should be fully partitioned across chips. self.assertFalse(sharding.is_fully_replicated) # The resulting array here should be fully sharded, _even though the # argument was not_, because our vmap impl inserts sharding constraints on @@ -435,14 +495,15 @@ def test_map_forces_clients_sharding_with_model_parallelism( self.assertEqual(sharding.spec, PSpec(_CLIENTS_AXIS, _DATA_AXIS)) # Since the argument was fully split across the data and clients axes, the # result should be too: each of the 4 chips hosts a sub-array slice of data, - # of shape (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // - # _DATA_AXIS_SIZE), so that the entire (global) shape is (_NUM_CLIENTS, - # _DATA_SIZE). + # of shape: + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # so that the entire (global) shape is (_NUM_CLIENTS, _DATA_SIZE). self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_as_context=[True, False], mesh_axes_type=[AxisType.Auto, AxisType.Explicit], @@ -476,16 +537,21 @@ def shard_map_add(x, y): ) sharded_arg1 = jnp.tile(sharded_arg1, reps=[_NUM_CLIENTS, 1]) sharded_arg2 = jnp.tile(sharded_arg2, reps=[_NUM_CLIENTS, 1]) - with mesh_context(mesh, mesh_as_context) as mesh: - result = self._comp_factory.map_to_placement( - shard_map_add, (sharded_arg1, sharded_arg2), _CLIENTS_AXIS, mesh + + @self.variant(static_argnums=(1,)) + def _run(args, mesh): + return self._comp_factory.map_to_placement( + shard_map_add, args, _CLIENTS_AXIS, mesh ) - # The result should be fully partitioned across chips, regardless of input - # sharding. + with mesh_context(mesh, mesh_as_context) as mesh_val: + result = _run((sharded_arg1, sharded_arg2), mesh_val) + self.assertEqual(result.shape, (_NUM_CLIENTS, _DATA_SIZE)) sharding = result.sharding self.assertIsInstance(sharding, jax.sharding.NamedSharding) + # The result should be fully partitioned across chips, regardless of input + # sharding. self.assertFalse(sharding.is_fully_replicated) # The resulting array here should be fully sharded, since the fed_map # implementation respects the _CLIENTS_AXIS sharding and the shard_map @@ -493,9 +559,9 @@ def shard_map_add(x, y): self.assertEqual(sharding.spec, PSpec(_CLIENTS_AXIS, _DATA_AXIS)) # Since the argument was fully split across the data and clients axes, the # result should be too: each of the 4 chips hosts a sub-array slice of data, - # of shape (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // - # _DATA_AXIS_SIZE), so that the entire (global) shape is (_NUM_CLIENTS, - # _DATA_SIZE). + # of shape + # (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE) + # so that the entire (global) shape is (_NUM_CLIENTS, _DATA_SIZE). self.assertEqual( sharding.shard_shape(result.shape), (_NUM_CLIENTS // _CLIENTS_AXIS_SIZE, _DATA_SIZE // _DATA_AXIS_SIZE), @@ -507,5 +573,5 @@ def setUpModule(): chex.set_n_cpu_devices(8) -if __name__ == '__main__': +if __name__ == "__main__": absltest.main() diff --git a/drjax/_src/impls_test.py b/drjax/_src/impls_test.py index 30c19b0..5b5318d 100644 --- a/drjax/_src/impls_test.py +++ b/drjax/_src/impls_test.py @@ -39,14 +39,21 @@ def setUp(self): self._placements = {'clients': 100} self._sequence_length = 10 + @chex.variants(with_jit=True, without_jit=True) def test_broadcast_on_float(self): comp_factory = impls.PlacedComputations( placements_to_n_elements=self._placements, ) - actual_output = comp_factory.broadcast_to_placement(0.0, 'clients') + + @self.variant + def _run(): + return comp_factory.broadcast_to_placement(0.0, 'clients') + + actual_output = _run() expected_output = jnp.zeros(shape=[100]) chex.assert_trees_all_equal(actual_output, expected_output) + @chex.variants(with_jit=True, without_jit=True) def test_runs_temp_sens_example(self): comp_factory = impls.PlacedComputations( placements_to_n_elements=self._placements, @@ -54,6 +61,7 @@ def test_runs_temp_sens_example(self): def _one_if_over(x, y): return jax.lax.cond(x > y, lambda: 1.0, lambda: 0.0) + @self.variant def temp_sens_example(m, t): t_at_c = comp_factory.broadcast_to_placement(t, 'clients') total_over = comp_factory.map_to_placement( @@ -62,11 +70,11 @@ def temp_sens_example(m, t): return comp_factory.mean_from_placement(total_over) measurements = jnp.arange(self._placements['clients']) - self.assertEqual( temp_sens_example(measurements, jnp.median(measurements)), 0.5 ) + @chex.variants(with_jit=True, without_jit=True) @parameterized.product( mesh_axes_type=[AxisType.Auto, AxisType.Explicit], ) @@ -82,8 +90,11 @@ def update(model, x): model, x ) + @self.variant def test_training(model, data): - model_at_clients = comp_factory.broadcast_to_placement(model, 'clients') + model_at_clients = comp_factory.broadcast_to_placement( + model, 'clients' + ) grads, _ = comp_factory.map_to_placement( update, (model_at_clients, data), 'clients' ) @@ -93,8 +104,10 @@ def test_training(model, data): jnp.ones(shape=(self._placements['clients'],), dtype=jnp.float32), device=NamedSharding(mesh, PartitionSpec('clients')), ) - model = jax.device_put([0.0], device=NamedSharding(mesh, PartitionSpec())) - self.assertEqual(jax.jit(test_training)(model, clients_data), 0.0) + model = jax.device_put( + [0.0], device=NamedSharding(mesh, PartitionSpec()) + ) + self.assertEqual(test_training(model, clients_data), 0.0) # This allows us to test sharding behavior across multiple devices. diff --git a/drjax/_src/primitives.py b/drjax/_src/primitives.py index b4bab3a..fb7c25f 100644 --- a/drjax/_src/primitives.py +++ b/drjax/_src/primitives.py @@ -91,15 +91,26 @@ def broadcast_abstract_eval( abstract_mesh = ( mesh.abstract_mesh if isinstance(mesh, jax.sharding.Mesh) else mesh ) - sharding_axis = ( - placement_str - if impls._placement_axis_in_mesh(abstract_mesh, placement_str) # pylint: disable=protected-access - else None - ) - new_sharding = xs.sharding.update( - mesh=abstract_mesh, - spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec), - ) + if impls._placement_axis_in_mesh(abstract_mesh, placement_str): # pylint: disable=protected-access + sharding_axis = placement_str + else: + if abstract_mesh is not None and any( + t == jax.sharding.AxisType.Explicit for t in abstract_mesh.axis_types + ): + raise ValueError( + f"Placement axis '{placement_str}' not found in mesh with explicit" + ' axes.' + ) + sharding_axis = None + if xs.sharding is not None: + new_sharding = xs.sharding.update( + mesh=abstract_mesh, + spec=jax.sharding.PartitionSpec(sharding_axis, *xs.sharding.spec), + ) + else: + new_sharding = jax.sharding.NamedSharding( + abstract_mesh, jax.sharding.PartitionSpec(sharding_axis) + ) return core.ShapedArray( shape=(n_elements,) + xs.shape, dtype=xs.dtype, diff --git a/drjax/_src/primitives_test.py b/drjax/_src/primitives_test.py index 569a4c1..7716173 100644 --- a/drjax/_src/primitives_test.py +++ b/drjax/_src/primitives_test.py @@ -86,22 +86,18 @@ def setUp(self): {'clients': self._n_clients}, ) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_clients_evaluation(self): - fn = self._primdefs['broadcast_clients'] + fn = self.variant(self._primdefs['broadcast_clients']) # Check that this function is callable. chex.assert_trees_all_close( fn(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) - # Check that it's jittable. - chex.assert_trees_all_close( - jax.jit(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) - ) # Check that its forward-diffable. chex.assert_trees_all_close( jax.jacfwd(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) ) - # Also that it's reverse-diffable. chex.assert_trees_all_close( jax.jacrev(fn)(jnp.array(1.0)), jnp.ones(shape=[self._n_clients]) @@ -117,18 +113,15 @@ def test_broadcast_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(jnp.array(1.0)) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'sum_from_clients')) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_sum_from_clients_evaluation(self): - fn = self._primdefs['sum_from_clients'] + fn = self.variant(self._primdefs['sum_from_clients']) clients_ones = self._primdefs['broadcast_clients'](jnp.ones(shape=[1])) # Check that this function is callable. chex.assert_trees_all_close( fn(clients_ones), jnp.array([1.0 * self._n_clients]) ) - # Check that it's jittable. - chex.assert_trees_all_close( - jax.jit(fn)(clients_ones), jnp.array([1.0 * self._n_clients]) - ) # Check that its forward-diffable. chex.assert_trees_all_close( jax.jacfwd(fn)(clients_ones), jnp.ones(shape=[1, 100, 1]) @@ -138,10 +131,12 @@ def test_sum_from_clients_evaluation(self): jax.jacrev(fn)(clients_ones), jnp.ones(shape=[1, self._n_clients, 1]) ) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_and_sum_from_clients_eval(self): fn = self._primdefs['sum_from_clients'] + @self.variant def _broadcast_then_sum(x): broadcasted_x = self._primdefs['broadcast_clients'](x) return fn(broadcasted_x) @@ -151,7 +146,6 @@ def _broadcast_then_sum(x): jax.jacfwd(_broadcast_then_sum)(jnp.array([1.0])), jnp.array([[1.0 * self._n_clients]]), ) - # And here's reverse-ad. chex.assert_trees_all_close( jax.jacrev(_broadcast_then_sum)(jnp.array([1.0])), @@ -169,24 +163,25 @@ def test_sum_from_clients_closure_under_fad(self): rev_mode_jaxpr = jax.make_jaxpr(jax.jacrev(fn))(clients_ones) self.assertTrue(_jaxpr_has_primitive(rev_mode_jaxpr, 'broadcast_clients')) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_mean_from_clients_eval(self): - fn = self._primdefs['mean_from_clients'] + fn = self.variant(self._primdefs['mean_from_clients']) clients_ones = jnp.ones(shape=[self._n_clients, 1]) # Check that this function is callable. chex.assert_trees_all_close(fn(clients_ones), jnp.array([1.0])) - # Check that it's jittable. - chex.assert_trees_all_close(jax.jit(fn)(clients_ones), jnp.array([1.0])) # Check that its forward-diffable. chex.assert_trees_all_close( jax.jacfwd(fn)(clients_ones), 1 / self._n_clients * jnp.ones(shape=[1, self._n_clients, 1]), ) + @chex.variants(with_jit=True, without_jit=True) @run_in_mesh((AxisType.Auto, AxisType.Explicit)) def test_broadcast_then_mean_from_clients_eval(self): fn = self._primdefs['mean_from_clients'] + @self.variant def _broadcast_then_sum(x): broadcasted_x = self._primdefs['broadcast_clients'](x) return fn(broadcasted_x) @@ -243,8 +238,8 @@ def duplicate_prim_result(x): @jax.jit def ignore_prim_result(x): - # Ignoring one result from this tuple-returning function triggers - # reverse evaluation with a symbolic zero cotangent argument. + # Ignoring one result from this tuple-returning function triggers reverse + # evaluation with a symbolic zero cotangent argument. y, _ = duplicate_prim_result(x) return y