Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/attention/benchmark_attention_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def bench_forward(self, warmup, iters, timings_dir):
self.dropout_rng_sharding,
],
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource):
for _ in range(warmup):
customcall_fused_dpa_jit(*customcall_args)

Expand Down Expand Up @@ -227,7 +227,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
),
out_shardings=(None, grad_shardings),
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource):
for _ in range(warmup):
jitted_primitive(*customcall_args)

Expand Down
6 changes: 4 additions & 2 deletions tests/jax/test_distributed_dense.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -127,7 +129,7 @@ def test_distributed_gemm(

contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension

with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource):
# TE GEMM result
te_result = _jitted_gemm(
x_sharded,
Expand Down Expand Up @@ -209,7 +211,7 @@ def test_te_distributed_dense_grad(

contracting_dims = ((2,), (0,))

with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource):
# Test gradients w.r.t. all inputs
te_grad_func = jax.jit(
jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)),
Expand Down
6 changes: 4 additions & 2 deletions tests/jax/test_distributed_layernorm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -135,7 +137,7 @@ def ref_func(x, gamma, beta):
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_named_sharding = NamedSharding(mesh, x_pspec)
g_named_sharding = NamedSharding(mesh, g_pspec)
b_named_sharding = NamedSharding(mesh, b_pspec)
Expand Down Expand Up @@ -217,7 +219,7 @@ def ref_func(x, gamma):
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_named_sharding = NamedSharding(mesh, x_pspec)
g_named_sharding = NamedSharding(mesh, g_pspec)
x_ = jax.device_put(x, x_named_sharding)
Expand Down
4 changes: 2 additions & 2 deletions tests/jax/test_distributed_layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def _test_layernorm_mlp_grad(
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(
with jax.set_mesh(mesh), autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
mesh_resource=mesh_resource,
Expand Down Expand Up @@ -452,7 +452,7 @@ def _test_layernorm_mlp(
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(
with jax.set_mesh(mesh), autocast(
enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
Expand Down
4 changes: 3 additions & 1 deletion tests/jax/test_distributed_softmax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -109,7 +111,7 @@ def impl_test_softmax(
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(mesh_resource=mesh_resource):
x_named_sharding = NamedSharding(mesh, x_pspec)
mask_named_sharding = NamedSharding(mesh, mask_pspec)
x_ = jax.device_put(x, x_named_sharding)
Expand Down
8 changes: 4 additions & 4 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def test_forward(self):
],
)

with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
primitive_out = customcall_fused_dpa_jit(*customcall_args)
primitive_out = self.cp_inverse_reorder_fn(primitive_out)

Expand All @@ -924,7 +924,7 @@ def test_forward(self):
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

if self.coll_count_ref is not None:
with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
)
)

with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

reference_out, reference_dgrad = jitted_reference(*args)
Expand Down Expand Up @@ -1126,7 +1126,7 @@ def check_dqkv(primitive, reference, pad, idx):
)

if self.coll_count_ref is not None:
with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, self.coll_count_ref)

Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def name_of_wrapper_p():


for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA")
ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension() else "CUDA")


def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False):
Expand Down
42 changes: 34 additions & 8 deletions transformer_engine/jax/sharding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -20,7 +22,7 @@
from jax.sharding import PartitionSpec, get_abstract_mesh
import numpy as np

_PXLA_THREAD_RESOURCES = pxla.thread_resources
#ROCm: disable deprecated pxla.thread_resources import
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stub comment + dead import: with _PXLA_THREAD_RESOURCES removed, this file has no remaining reference to pxla — the from jax.interpreters import pxla at line 21 is now unused and should be deleted along with this placeholder comment (grep -n pxla returns only this line and the now-dead import). Also nit: the rest of the PR uses # ROCm: with a space; this is #ROCm: without one.

Suggested change
#ROCm: disable deprecated pxla.thread_resources import

(and remove from jax.interpreters import pxla above).


# Axis Names
BATCH_AXES = "nvte_batch"
Expand All @@ -38,10 +40,7 @@


def _get_mesh():
# Handle Mesh's set via `with mesh:`
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh is not None and not mesh.empty:
return mesh
# ROCm remove deprecated handling of Mesh's set via `with mesh`
# Handle Mesh's set via `jax.set_mesh(mesh)`
return jax.sharding.get_abstract_mesh()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backward-incompatibility risk on the upstream/CUDA path. This removes the with mesh: discovery branch entirely, so any caller that still uses the (still-supported in JAX) with mesh: pattern — instead of jax.set_mesh() — will now silently see an empty AbstractMesh here, and TE will treat it as "no mesh" rather than raising. That changes behaviour for non-ROCm users on JAX versions where both patterns still work, not just JAX 0.9.

This file is shared with upstream NVIDIA TE; per the fork's review guidance, behavioural changes to CUDA-reachable code need either (a) a runtime guard so CUDA stays byte-identical, or (b) explicit classification as a generic JAX-0.9 compat fix worth upstreaming. The "ROCm remove deprecated…" comment reads as ROCm-specific, but the change applies unconditionally. Consider falling back to pxla.thread_resources.env.physical_mesh when the abstract mesh is empty (keeping both code paths) or flagging this in the PR description as a deliberate JAX-0.9 cutover.


Expand Down Expand Up @@ -164,6 +163,18 @@ def filter_manual_axes(name_or_tuple):
return x

cleaned_pspec = PartitionSpec(*cleaned_axis_names)

# ROCm: JAX 0.9 support
# In eager mode (x is a concrete jax.Array with an accessible .sharding attribute),
# jax.lax.with_sharding_constraint creates a target sharding using the AbstractMesh.
# JAX then requires the input to already have a NamedSharding.
# During model init or other eager code where inputs are not yet on a mesh
# (e.g. SingleDeviceSharding), return x unchanged.
# Inside jax.jit, traced arrays raise AttributeError for .sharding, so hasattr()
# returns False and we fall through to the normal constraint path.
if hasattr(x, 'sharding') and not isinstance(x.sharding, jax.sharding.NamedSharding):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Silent fallback to no-op concerns me. This returns x unchanged whenever the input has a .sharding that isn't a NamedSharding — which includes any concrete array on a SingleDeviceSharding outside a mesh context. Previously such a call would have raised (visible failure); now the constraint is silently dropped. In practice users calling with_sharding_constraint(x, pspec) in eager mode tend to expect either an effect or an error, not a no-op, so subtle "my model isn't sharded" bugs become possible.

Two specific concerns:

  1. The "Inside jax.jit, traced arrays raise AttributeError for .sharding" assumption is fragile. In current JAX, jax.core.Tracer does expose .sharding in many contexts; if/when it does inside jit, this branch will start silently swallowing the constraint there too. A more robust gate is something like not isinstance(x, jax.core.Tracer) (or check concreteness) so the JIT path is explicit, not incidental.
  2. Scope mismatch with the comment label. Comment is tagged ROCm: JAX 0.9 support but the change is unguarded and runs on the CUDA path as well. If this is genuinely a JAX-0.9-only requirement (not ROCm-specific), please reclassify per the fork's upstream-compat rules: either declare it as a generic JAX-0.9 fix worth upstreaming, or guard with a is_hip_extension()/JAX-version check so CUDA behaviour is unchanged.

At minimum, consider emitting a warnings.warn(...) on the silent-skip branch so users know the constraint didn't apply.

return x

return jax.lax.with_sharding_constraint(x, cleaned_pspec)


Expand Down Expand Up @@ -359,6 +370,14 @@ def global_shard_guard(resource: MeshResource):
old_resources = _GLOBAL_MESH_RESOURCE
try:
_GLOBAL_MESH_RESOURCE = resource
# ROCm: JAX 0.9 support
# Validate once at context-setup time, where get_abstract_mesh() correctly
# reflects the physical mesh. Calling _validate_mesh_resource_configuration
# from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9
# because get_abstract_mesh() returns an empty AbstractMesh when called
# from inside a custom_partitioning sharded_impl during jit(...).lower().
if resource is not None:
_validate_mesh_resource_configuration(resource)
yield
finally:
_GLOBAL_MESH_RESOURCE = old_resources
Expand All @@ -375,7 +394,12 @@ def global_mesh_resource() -> MeshResource:
" context. If you are not using multiple GPUs, you can use an empty MeshResource by"
" wrapping your program in 'with global_shard_guard(MeshResource()):'"
)
_validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE)
# ROCm: _validate_mesh_resource_configuration is intentionally NOT called here.
# Validation is done once in global_shard_guard() at context-setup time, where
# get_abstract_mesh() correctly reflects the physical mesh. Calling it here
# would break in JAX 0.9 when global_mesh_resource() is invoked from inside a
# custom_partitioning sharded_impl during jit(...).lower(), at which point
# get_abstract_mesh() returns an empty AbstractMesh.
return _GLOBAL_MESH_RESOURCE


Expand Down Expand Up @@ -418,8 +442,10 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
Returns:
Reduced tensor
"""
all_axes = get_all_mesh_axes()
for axis in all_axes:
# ROCm: Use mesh.axis_names from the concrete mesh argument rather than calling
# get_all_mesh_axes() → _get_mesh() → get_abstract_mesh(), which returns
# empty in JAX 0.9 when called from inside a custom_partitioning sharded_impl.
for axis in mesh.axis_names:
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
Expand Down
Loading