From d4be9fbec25ad8a8337156b14c4d0660151eb0a5 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Mon, 1 Jun 2026 00:06:46 -0400 Subject: [PATCH 1/4] Remove explicit xla_gpu_enable_nccl_comm_splitting=false because it should be default starting JAX 0.8 --- ci/jax.sh | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ci/jax.sh b/ci/jax.sh index 0e1e356f6..61b4b2372 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -71,8 +71,6 @@ run_test_config_mgpu() { # Mitigate distributed tests hang by adding 5min timeout _timeout_args="--timeout 300 --timeout-method thread" - # Workaround for some distributed tests hang/abortion - export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then _dfa_level=2 @@ -83,8 +81,7 @@ run_test_config_mgpu() { fi run_default_fa 2 test_distributed_dense.py - # RCCL_MSCCL_ENABLE=0 is to avoid hangs in some distributed tests (ROCM-1719) - RCCL_MSCCL_ENABLE=0 run $_dfa_level test_distributed_fused_attn.py $_timeout_args + run $_dfa_level test_distributed_fused_attn.py $_timeout_args run_default_fa 3 test_distributed_layernorm.py run_default_fa 2 test_distributed_layernorm_mlp.py $_timeout_args run_default_fa 3 test_distributed_softmax.py From 38712cb8ba5c8a23d7c7aed20440e74ac06202ea Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Mon, 1 Jun 2026 00:08:35 -0400 Subject: [PATCH 2/4] Do not use deprecated pxla.thread_resources Meshs --- .../attention/benchmark_attention_jax.py | 4 +- tests/jax/test_distributed_dense.py | 6 ++- tests/jax/test_distributed_layernorm.py | 6 ++- tests/jax/test_distributed_layernorm_mlp.py | 4 +- tests/jax/test_distributed_softmax.py | 4 +- tests/jax/test_fused_attn.py | 8 ++-- transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/sharding.py | 42 +++++++++++++++---- 8 files changed, 54 insertions(+), 22 deletions(-) diff --git a/benchmarks/attention/benchmark_attention_jax.py b/benchmarks/attention/benchmark_attention_jax.py index 54dd28505..b0b60cdb8 100644 --- a/benchmarks/attention/benchmark_attention_jax.py +++ b/benchmarks/attention/benchmark_attention_jax.py @@ -136,7 +136,7 @@ def bench_forward(self, warmup, iters, timings_dir): self.dropout_rng_sharding, ], ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): customcall_fused_dpa_jit(*customcall_args) @@ -227,7 +227,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ), out_shardings=(None, grad_shardings), ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): jitted_primitive(*customcall_args) diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py index b8caf188d..818298ed8 100644 --- a/tests/jax/test_distributed_dense.py +++ b/tests/jax/test_distributed_dense.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -127,7 +129,7 @@ def test_distributed_gemm( contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # TE GEMM result te_result = _jitted_gemm( x_sharded, @@ -209,7 +211,7 @@ def test_te_distributed_dense_grad( contracting_dims = ((2,), (0,)) - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # Test gradients w.r.t. all inputs te_grad_func = jax.jit( jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 21359cedf..eb4497a0a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -135,7 +137,7 @@ def ref_func(x, gamma, beta): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) b_named_sharding = NamedSharding(mesh, b_pspec) @@ -217,7 +219,7 @@ def ref_func(x, gamma): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 6a2f395b1..4ed9e3cf5 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -261,7 +261,7 @@ def _test_layernorm_mlp_grad( # Multi GPUs devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( enabled=quantization_recipe is not None, recipe=quantization_recipe, mesh_resource=mesh_resource, @@ -452,7 +452,7 @@ def _test_layernorm_mlp( device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0665baa4e..ff44f249c 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -109,7 +111,7 @@ def impl_test_softmax( collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) mask_named_sharding = NamedSharding(mesh, mask_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 3b3db30bd..6e9cf23cb 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -907,7 +907,7 @@ def test_forward(self): ], ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out = customcall_fused_dpa_jit(*customcall_args) primitive_out = self.cp_inverse_reorder_fn(primitive_out) @@ -924,7 +924,7 @@ def test_forward(self): assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = ( customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() ) @@ -1038,7 +1038,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ) ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) reference_out, reference_dgrad = jitted_reference(*args) @@ -1126,7 +1126,7 @@ def check_dqkv(primitive, reference, pad, idx): ) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index e65215bec..670f59cfe 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -223,7 +223,7 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): - ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA") + ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension() else "CUDA") def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c1..543c1957a 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -20,7 +22,7 @@ from jax.sharding import PartitionSpec, get_abstract_mesh import numpy as np -_PXLA_THREAD_RESOURCES = pxla.thread_resources +#ROCm: disable deprecated pxla.thread_resources import # Axis Names BATCH_AXES = "nvte_batch" @@ -38,10 +40,7 @@ def _get_mesh(): - # Handle Mesh's set via `with mesh:` - mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh - if mesh is not None and not mesh.empty: - return mesh + # ROCm remove deprecated handling of Mesh's set via `with mesh` # Handle Mesh's set via `jax.set_mesh(mesh)` return jax.sharding.get_abstract_mesh() @@ -164,6 +163,18 @@ def filter_manual_axes(name_or_tuple): return x cleaned_pspec = PartitionSpec(*cleaned_axis_names) + + # ROCm: JAX 0.9 support + # In eager mode (x is a concrete jax.Array with an accessible .sharding attribute), + # jax.lax.with_sharding_constraint creates a target sharding using the AbstractMesh. + # JAX then requires the input to already have a NamedSharding. + # During model init or other eager code where inputs are not yet on a mesh + # (e.g. SingleDeviceSharding), return x unchanged. + # Inside jax.jit, traced arrays raise AttributeError for .sharding, so hasattr() + # returns False and we fall through to the normal constraint path. + if hasattr(x, 'sharding') and not isinstance(x.sharding, jax.sharding.NamedSharding): + return x + return jax.lax.with_sharding_constraint(x, cleaned_pspec) @@ -359,6 +370,14 @@ def global_shard_guard(resource: MeshResource): old_resources = _GLOBAL_MESH_RESOURCE try: _GLOBAL_MESH_RESOURCE = resource + # ROCm: JAX 0.9 support + # Validate once at context-setup time, where get_abstract_mesh() correctly + # reflects the physical mesh. Calling _validate_mesh_resource_configuration + # from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9 + # because get_abstract_mesh() returns an empty AbstractMesh when called + # from inside a custom_partitioning sharded_impl during jit(...).lower(). + if resource is not None: + _validate_mesh_resource_configuration(resource) yield finally: _GLOBAL_MESH_RESOURCE = old_resources @@ -375,7 +394,12 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + # ROCm: _validate_mesh_resource_configuration is intentionally NOT called here. + # Validation is done once in global_shard_guard() at context-setup time, where + # get_abstract_mesh() correctly reflects the physical mesh. Calling it here + # would break in JAX 0.9 when global_mesh_resource() is invoked from inside a + # custom_partitioning sharded_impl during jit(...).lower(), at which point + # get_abstract_mesh() returns an empty AbstractMesh. return _GLOBAL_MESH_RESOURCE @@ -418,8 +442,10 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes Returns: Reduced tensor """ - all_axes = get_all_mesh_axes() - for axis in all_axes: + # ROCm: Use mesh.axis_names from the concrete mesh argument rather than calling + # get_all_mesh_axes() → _get_mesh() → get_abstract_mesh(), which returns + # empty in JAX 0.9 when called from inside a custom_partitioning sharded_impl. + for axis in mesh.axis_names: if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x From e93a51ad56f5042590b7aab118492935ae9d14fe Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Mon, 1 Jun 2026 00:11:50 -0400 Subject: [PATCH 3/4] Remove unwanted RCCL flags removal --- ci/jax.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/jax.sh b/ci/jax.sh index 61b4b2372..50427d5c9 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -81,7 +81,8 @@ run_test_config_mgpu() { fi run_default_fa 2 test_distributed_dense.py - run $_dfa_level test_distributed_fused_attn.py $_timeout_args + # RCCL_MSCCL_ENABLE=0 is to avoid hangs in some distributed tests (ROCM-1719) + RCCL_MSCCL_ENABLE=0 run $_dfa_level test_distributed_fused_attn.py $_timeout_args run_default_fa 3 test_distributed_layernorm.py run_default_fa 2 test_distributed_layernorm_mlp.py $_timeout_args run_default_fa 3 test_distributed_softmax.py From 6f65333b24da9977a82206d62299654be6d8c3ff Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Mon, 1 Jun 2026 02:51:28 -0400 Subject: [PATCH 4/4] Revert JAX flag removal because it is not properly defaulted in JAX 0.8.0 --- ci/jax.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci/jax.sh b/ci/jax.sh index 50427d5c9..0e1e356f6 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -71,6 +71,8 @@ run_test_config_mgpu() { # Mitigate distributed tests hang by adding 5min timeout _timeout_args="--timeout 300 --timeout-method thread" + # Workaround for some distributed tests hang/abortion + export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then _dfa_level=2