-
Notifications
You must be signed in to change notification settings - Fork 29
Ipanfilo/jax0.9 support #604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
d4be9fb
38712cb
e93a51a
6f65333
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| # This file was modified for portability to AMDGPU | ||
| # Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
@@ -20,7 +22,7 @@ | |
| from jax.sharding import PartitionSpec, get_abstract_mesh | ||
| import numpy as np | ||
|
|
||
| _PXLA_THREAD_RESOURCES = pxla.thread_resources | ||
| #ROCm: disable deprecated pxla.thread_resources import | ||
|
|
||
| # Axis Names | ||
| BATCH_AXES = "nvte_batch" | ||
|
|
@@ -38,10 +40,7 @@ | |
|
|
||
|
|
||
| def _get_mesh(): | ||
| # Handle Mesh's set via `with mesh:` | ||
| mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh | ||
| if mesh is not None and not mesh.empty: | ||
| return mesh | ||
| # ROCm remove deprecated handling of Mesh's set via `with mesh` | ||
| # Handle Mesh's set via `jax.set_mesh(mesh)` | ||
| return jax.sharding.get_abstract_mesh() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Backward-incompatibility risk on the upstream/CUDA path. This removes the This file is shared with upstream NVIDIA TE; per the fork's review guidance, behavioural changes to CUDA-reachable code need either (a) a runtime guard so CUDA stays byte-identical, or (b) explicit classification as a generic JAX-0.9 compat fix worth upstreaming. The "ROCm remove deprecated…" comment reads as ROCm-specific, but the change applies unconditionally. Consider falling back to |
||
|
|
||
|
|
@@ -164,6 +163,18 @@ def filter_manual_axes(name_or_tuple): | |
| return x | ||
|
|
||
| cleaned_pspec = PartitionSpec(*cleaned_axis_names) | ||
|
|
||
| # ROCm: JAX 0.9 support | ||
| # In eager mode (x is a concrete jax.Array with an accessible .sharding attribute), | ||
| # jax.lax.with_sharding_constraint creates a target sharding using the AbstractMesh. | ||
| # JAX then requires the input to already have a NamedSharding. | ||
| # During model init or other eager code where inputs are not yet on a mesh | ||
| # (e.g. SingleDeviceSharding), return x unchanged. | ||
| # Inside jax.jit, traced arrays raise AttributeError for .sharding, so hasattr() | ||
| # returns False and we fall through to the normal constraint path. | ||
| if hasattr(x, 'sharding') and not isinstance(x.sharding, jax.sharding.NamedSharding): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Silent fallback to no-op concerns me. This returns Two specific concerns:
At minimum, consider emitting a |
||
| return x | ||
|
|
||
| return jax.lax.with_sharding_constraint(x, cleaned_pspec) | ||
|
|
||
|
|
||
|
|
@@ -359,6 +370,14 @@ def global_shard_guard(resource: MeshResource): | |
| old_resources = _GLOBAL_MESH_RESOURCE | ||
| try: | ||
| _GLOBAL_MESH_RESOURCE = resource | ||
| # ROCm: JAX 0.9 support | ||
| # Validate once at context-setup time, where get_abstract_mesh() correctly | ||
| # reflects the physical mesh. Calling _validate_mesh_resource_configuration | ||
| # from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9 | ||
| # because get_abstract_mesh() returns an empty AbstractMesh when called | ||
| # from inside a custom_partitioning sharded_impl during jit(...).lower(). | ||
| if resource is not None: | ||
| _validate_mesh_resource_configuration(resource) | ||
| yield | ||
| finally: | ||
| _GLOBAL_MESH_RESOURCE = old_resources | ||
|
|
@@ -375,7 +394,12 @@ def global_mesh_resource() -> MeshResource: | |
| " context. If you are not using multiple GPUs, you can use an empty MeshResource by" | ||
| " wrapping your program in 'with global_shard_guard(MeshResource()):'" | ||
| ) | ||
| _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) | ||
| # ROCm: _validate_mesh_resource_configuration is intentionally NOT called here. | ||
| # Validation is done once in global_shard_guard() at context-setup time, where | ||
| # get_abstract_mesh() correctly reflects the physical mesh. Calling it here | ||
| # would break in JAX 0.9 when global_mesh_resource() is invoked from inside a | ||
| # custom_partitioning sharded_impl during jit(...).lower(), at which point | ||
| # get_abstract_mesh() returns an empty AbstractMesh. | ||
| return _GLOBAL_MESH_RESOURCE | ||
|
|
||
|
|
||
|
|
@@ -418,8 +442,10 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes | |
| Returns: | ||
| Reduced tensor | ||
| """ | ||
| all_axes = get_all_mesh_axes() | ||
| for axis in all_axes: | ||
| # ROCm: Use mesh.axis_names from the concrete mesh argument rather than calling | ||
| # get_all_mesh_axes() → _get_mesh() → get_abstract_mesh(), which returns | ||
| # empty in JAX 0.9 when called from inside a custom_partitioning sharded_impl. | ||
| for axis in mesh.axis_names: | ||
| if axis != global_mesh_resource().pp_resource: | ||
| x = lax_paral_op(x, jax.lax.pmax, axis, mesh) | ||
| return x | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stub comment + dead import: with
_PXLA_THREAD_RESOURCESremoved, this file has no remaining reference topxla— thefrom jax.interpreters import pxlaat line 21 is now unused and should be deleted along with this placeholder comment (grep -n pxlareturns only this line and the now-dead import). Also nit: the rest of the PR uses# ROCm:with a space; this is#ROCm:without one.(and remove
from jax.interpreters import pxlaabove).