diff --git a/benchmarks/attention/benchmark_attention_jax.py b/benchmarks/attention/benchmark_attention_jax.py index 54dd28505..b0b60cdb8 100644 --- a/benchmarks/attention/benchmark_attention_jax.py +++ b/benchmarks/attention/benchmark_attention_jax.py @@ -136,7 +136,7 @@ def bench_forward(self, warmup, iters, timings_dir): self.dropout_rng_sharding, ], ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): customcall_fused_dpa_jit(*customcall_args) @@ -227,7 +227,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ), out_shardings=(None, grad_shardings), ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): jitted_primitive(*customcall_args) diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py index b8caf188d..818298ed8 100644 --- a/tests/jax/test_distributed_dense.py +++ b/tests/jax/test_distributed_dense.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -127,7 +129,7 @@ def test_distributed_gemm( contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # TE GEMM result te_result = _jitted_gemm( x_sharded, @@ -209,7 +211,7 @@ def test_te_distributed_dense_grad( contracting_dims = ((2,), (0,)) - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # Test gradients w.r.t. all inputs te_grad_func = jax.jit( jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 21359cedf..eb4497a0a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -135,7 +137,7 @@ def ref_func(x, gamma, beta): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) b_named_sharding = NamedSharding(mesh, b_pspec) @@ -217,7 +219,7 @@ def ref_func(x, gamma): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 6a2f395b1..4ed9e3cf5 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -261,7 +261,7 @@ def _test_layernorm_mlp_grad( # Multi GPUs devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( enabled=quantization_recipe is not None, recipe=quantization_recipe, mesh_resource=mesh_resource, @@ -452,7 +452,7 @@ def _test_layernorm_mlp( device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0665baa4e..ff44f249c 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -109,7 +111,7 @@ def impl_test_softmax( collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) mask_named_sharding = NamedSharding(mesh, mask_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 3b3db30bd..6e9cf23cb 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -907,7 +907,7 @@ def test_forward(self): ], ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out = customcall_fused_dpa_jit(*customcall_args) primitive_out = self.cp_inverse_reorder_fn(primitive_out) @@ -924,7 +924,7 @@ def test_forward(self): assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = ( customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() ) @@ -1038,7 +1038,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ) ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) reference_out, reference_dgrad = jitted_reference(*args) @@ -1126,7 +1126,7 @@ def check_dqkv(primitive, reference, pad, idx): ) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index e65215bec..670f59cfe 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -223,7 +223,7 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): - ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA") + ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension() else "CUDA") def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c1..543c1957a 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -20,7 +22,7 @@ from jax.sharding import PartitionSpec, get_abstract_mesh import numpy as np -_PXLA_THREAD_RESOURCES = pxla.thread_resources +#ROCm: disable deprecated pxla.thread_resources import # Axis Names BATCH_AXES = "nvte_batch" @@ -38,10 +40,7 @@ def _get_mesh(): - # Handle Mesh's set via `with mesh:` - mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh - if mesh is not None and not mesh.empty: - return mesh + # ROCm remove deprecated handling of Mesh's set via `with mesh` # Handle Mesh's set via `jax.set_mesh(mesh)` return jax.sharding.get_abstract_mesh() @@ -164,6 +163,18 @@ def filter_manual_axes(name_or_tuple): return x cleaned_pspec = PartitionSpec(*cleaned_axis_names) + + # ROCm: JAX 0.9 support + # In eager mode (x is a concrete jax.Array with an accessible .sharding attribute), + # jax.lax.with_sharding_constraint creates a target sharding using the AbstractMesh. + # JAX then requires the input to already have a NamedSharding. + # During model init or other eager code where inputs are not yet on a mesh + # (e.g. SingleDeviceSharding), return x unchanged. + # Inside jax.jit, traced arrays raise AttributeError for .sharding, so hasattr() + # returns False and we fall through to the normal constraint path. + if hasattr(x, 'sharding') and not isinstance(x.sharding, jax.sharding.NamedSharding): + return x + return jax.lax.with_sharding_constraint(x, cleaned_pspec) @@ -359,6 +370,14 @@ def global_shard_guard(resource: MeshResource): old_resources = _GLOBAL_MESH_RESOURCE try: _GLOBAL_MESH_RESOURCE = resource + # ROCm: JAX 0.9 support + # Validate once at context-setup time, where get_abstract_mesh() correctly + # reflects the physical mesh. Calling _validate_mesh_resource_configuration + # from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9 + # because get_abstract_mesh() returns an empty AbstractMesh when called + # from inside a custom_partitioning sharded_impl during jit(...).lower(). + if resource is not None: + _validate_mesh_resource_configuration(resource) yield finally: _GLOBAL_MESH_RESOURCE = old_resources @@ -375,7 +394,12 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + # ROCm: _validate_mesh_resource_configuration is intentionally NOT called here. + # Validation is done once in global_shard_guard() at context-setup time, where + # get_abstract_mesh() correctly reflects the physical mesh. Calling it here + # would break in JAX 0.9 when global_mesh_resource() is invoked from inside a + # custom_partitioning sharded_impl during jit(...).lower(), at which point + # get_abstract_mesh() returns an empty AbstractMesh. return _GLOBAL_MESH_RESOURCE @@ -418,8 +442,10 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes Returns: Reduced tensor """ - all_axes = get_all_mesh_axes() - for axis in all_axes: + # ROCm: Use mesh.axis_names from the concrete mesh argument rather than calling + # get_all_mesh_axes() → _get_mesh() → get_abstract_mesh(), which returns + # empty in JAX 0.9 when called from inside a custom_partitioning sharded_impl. + for axis in mesh.axis_names: if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x