Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 221 additions & 65 deletions drjax/_src/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,43 +29,45 @@ def drjax_program(*, placements):
return api.drjax_program(placements=placements, self_module=api)


@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
class ApiTest(absltest.TestCase):
class ApiTest(parameterized.TestCase):

def assertShardingEqual(self, arr, sharding):
# Canonicalize with trailing `None`s to the rank of the input array.
# This canonicalizes across Auto and Explicit axis types, the former
# of which may not include trailing `None`s.
canonical_array_sharding = jax.sharding.NamedSharding(
arr.sharding.mesh,
# Canonicalize with trailing `None`s to the rank of the input array.
# This canonicalizes across Auto and Explicit axis types, the former
# of which may not include trailing `None`s.
jax.sharding.PartitionSpec(*(
axis
for axis, _ in itertools.zip_longest(arr.sharding.spec, arr.shape)
)),
)
self.assertEqual(canonical_array_sharding, sharding)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
def test_broadcast_with_placement_in_mesh(self, placement_name, axes_type):

@self.variant
@drjax_program(placements={placement_name: 100})
def broadcast_val(val):
return api.broadcast(val)

mesh = jax.sharding.Mesh(
np.array(jax.devices()),
axis_names=("some_axis",),
axis_types=(axes_type,),
np.array(jax.devices()).reshape([4, 2]),
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)
arg_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("some_axis")
)
with mesh:
with jax.set_mesh(mesh):
result = broadcast_val(
jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)
)
Expand All @@ -74,11 +76,21 @@ def broadcast_val(val):
# The placement axis is present in the mesh, so the broadcast result is laid
# out along that axis rather than replicated. Notice that we don't need to
# specify the sharding to DrJAX; it should be inferred by GSPMD.
expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None)
expected_result_pspec = jax.sharding.PartitionSpec(
placement_name, "some_axis", None
)
self.assertShardingEqual(
result, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
def test_broadcast_mesh_arg_without_placement(
self, placement_name, axes_type
):
Expand All @@ -88,31 +100,58 @@ def test_broadcast_mesh_arg_without_placement(
axis_types=(axes_type,),
)

@self.variant
@drjax_program(placements={placement_name: 100})
def broadcast_val(val):
return api.broadcast(val, mesh=mesh)

arg_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("some_axis")
)
result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding))

chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8]))
# No clients dimension in the mesh, we don't lay out the clients along that
# nonexistent dimension, but rather replicate them. Notice that we don't
# need to specify the sharding to DrJAX; it should be inferred by GSPMD.
expected_result_pspec = jax.sharding.PartitionSpec(None, "some_axis", None)
self.assertShardingEqual(
result, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)
if (
self.variant.type == chex.ChexVariantType.WITH_JIT
and axes_type == jax.sharding.AxisType.Explicit
):
with jax.set_mesh(mesh):
with self.assertRaisesRegex(
ValueError, "not found in mesh with explicit axes"
):
broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding))
else:
with jax.set_mesh(mesh):
result = broadcast_val(
jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)
)

chex.assert_trees_all_close(result, jnp.ones(shape=[100, 8, 8]))
# No clients dimension in the mesh, we don't lay out the clients along
# that nonexistent dimension, but rather replicate them. Notice that we
# don't need to specify the sharding to DrJAX; it should be inferred by
# GSPMD.
expected_result_pspec = jax.sharding.PartitionSpec(
None, "some_axis", None
)
self.assertShardingEqual(
result, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
def test_fully_sharded_broadcast_mesh_arg(self, placement_name, axes_type):
mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape([4, 2]),
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)

@self.variant
@drjax_program(placements={placement_name: 8})
def broadcast_val(val):
return api.broadcast(val, mesh=mesh)
Expand All @@ -121,7 +160,10 @@ def broadcast_val(val):
mesh, jax.sharding.PartitionSpec("some_axis")
)

result = broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding))
with jax.set_mesh(mesh):
result = broadcast_val(
jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding)
)

chex.assert_trees_all_close(result, jnp.ones(shape=[8, 8, 8]))
# The result should be sharded across the placement_name axis.
Expand All @@ -132,6 +174,14 @@ def broadcast_val(val):
result, jax.sharding.NamedSharding(mesh, expected_result_pspec)
)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
def test_temperature_sensors_example(self, placement_name, axes_type):
def one_if_over(threshold, value):
return jax.lax.cond(
Expand All @@ -146,27 +196,37 @@ def one_if_over(threshold, value):
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)
jax.set_mesh(mesh)

@drjax_program(placements={placement_name: placement_dim})
def temperature_sensors_example(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

measurements = jax.device_put(
jnp.arange(placement_dim),
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(placement_name)
),
)
with jax.set_mesh(mesh):

@self.variant
@drjax_program(placements={placement_name: placement_dim})
def temperature_sensors_example(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

measurements = jax.device_put(
jnp.arange(placement_dim, dtype=jnp.float32),
jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(placement_name)
),
)

self.assertEqual(temperature_sensors_example(24, measurements), 0.75)
self.assertEqual(temperature_sensors_example(24, measurements), 0.75)

@chex.variants(with_jit=True, without_jit=True)
@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
def test_temperature_sensors_example_multiple_placement_values(
self, placement_name, axes_type
):

@self.variant
def one_if_over(threshold, value):
return jax.lax.cond(
value > threshold,
Expand All @@ -179,39 +239,135 @@ def one_if_over(threshold, value):
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)
jax.set_mesh(mesh)
with jax.set_mesh(mesh):

@self.variant
@drjax_program(placements={placement_name: 100})
def temperature_sensors_example_100_clients(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

@self.variant
@drjax_program(placements={placement_name: 20})
def temperature_sensors_example_20_clients(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

placement_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(placement_name)
)
measurements_100 = jax.device_put(
jnp.arange(100, dtype=jnp.float32), placement_sharding
)
measurements_20 = jax.device_put(
jnp.arange(20, dtype=jnp.float32), placement_sharding
)

self.assertEqual(
temperature_sensors_example_100_clients(24, measurements_100), 0.75
)
self.assertEqual(
temperature_sensors_example_20_clients(3, measurements_20),
0.8,
)
# We should be able to recover the original result flipping back to the
# original function.
self.assertEqual(
temperature_sensors_example_100_clients(24, measurements_100), 0.75
)

# NOTE: the error is only raised inside a jax.jit and with an Explicit mesh, so
# we skip the without_jit variant and the Auto mesh.
@chex.variants(with_jit=True, without_jit=False)
@parameterized.product(placement_name=["clients", "XY"])
def test_broadcast_raises_error_on_missing_axis_explicit_mesh(
self, placement_name
):
mesh = jax.sharding.Mesh(
np.array(jax.devices()),
axis_names=("some_axis",),
axis_types=(jax.sharding.AxisType.Explicit,),
)

@self.variant
@drjax_program(placements={placement_name: 100})
def temperature_sensors_example_100_clients(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

@drjax_program(placements={placement_name: 20})
def temperature_sensors_example_20_clients(threshold, values):
threshold_at_clients = api.broadcast(threshold)
values_over = api.map_fn(one_if_over, (threshold_at_clients, values))
return api.reduce_mean(values_over)

placement_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec(placement_name)
def broadcast_val(val):
return api.broadcast(val)

arg_sharding = jax.sharding.NamedSharding(
mesh, jax.sharding.PartitionSpec("some_axis")
)
measurements_100 = jax.device_put(jnp.arange(100), placement_sharding)
measurements_20 = jax.device_put(jnp.arange(20), placement_sharding)

self.assertEqual(
temperature_sensors_example_100_clients(24, measurements_100), 0.75
with jax.set_mesh(mesh):
with self.assertRaisesRegex(
ValueError, "not found in mesh with explicit axes"
):
broadcast_val(jax.device_put(jnp.ones(shape=[8, 8]), arg_sharding))

@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
def test_broadcast_with_placement_mesh_api(self, placement_name, axes_type):
"""Verifies that api.broadcast shards along placement when axis in mesh."""
if axes_type != jax.sharding.AxisType.Explicit:
self.skipTest("Only for Explicit mesh")

@drjax_program(placements={placement_name: 100})
def broadcast_val(val):
return api.broadcast(val)

mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape([4, 2]),
axis_names=(placement_name, "some_axis"),
axis_types=(axes_type, axes_type),
)

with jax.set_mesh(mesh):
aval = jax.eval_shape(broadcast_val, jnp.zeros(shape=[10]))

self.assertEqual(aval.shape, (100, 10))
self.assertIsInstance(aval.sharding, jax.sharding.NamedSharding)
self.assertEqual(
temperature_sensors_example_20_clients(3, measurements_20),
0.8,
aval.sharding.spec, jax.sharding.PartitionSpec(placement_name, None)
)
# We should be able to recover the original result flipping back to the
# original function.
self.assertEqual(
temperature_sensors_example_100_clients(24, measurements_100), 0.75

@parameterized.product(
placement_name=["clients", "XY"],
axes_type=[
jax.sharding.AxisType.Auto,
jax.sharding.AxisType.Explicit,
],
)
def test_broadcast_with_auto_mesh_missing_axis_api(
self, placement_name, axes_type
):
"""Verifies that api.broadcast replicates when axis is missing in auto mesh."""
if axes_type != jax.sharding.AxisType.Auto:
self.skipTest("Only for Auto mesh")

@drjax_program(placements={placement_name: 100})
def broadcast_val(val):
return api.broadcast(val)

# Create a mesh without the placement axis
mesh = jax.sharding.Mesh(
np.array(jax.devices()),
axis_names=("some_other_axis",),
axis_types=(axes_type,),
)

with jax.set_mesh(mesh):
aval = jax.eval_shape(broadcast_val, jnp.zeros(shape=[10]))

self.assertEqual(aval.shape, (100, 10))
self.assertIsNone(aval.sharding)


class ApiErrorsTest(absltest.TestCase):

Expand Down
Loading
Loading