
Ipanfilo/jax0.9 support #604

Open
ipanfilo wants to merge 4 commits into dev from ipanfilo/jax0.9_support



Collaborator

@ipanfilo ipanfilo commented Jun 1, 2026

Description

Add JAX 0.9 support
https://github.com/ROCm/frameworks-internal/issues/16494

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Stop using the deprecated pxla.thread_resources meshes and update the TE logic accordingly

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ipanfilo ipanfilo added the ci-level 3 CI test level 3 label Jun 1, 2026
Collaborator Author

ipanfilo commented Jun 1, 2026

Rerunning the JAX mGPU tests on one host. The failure is most likely unrelated to this PR; similar failures are seen on dev.

@ipanfilo ipanfilo marked this pull request as ready for review June 1, 2026 15:25
import numpy as np

_PXLA_THREAD_RESOURCES = pxla.thread_resources
#ROCm: disable deprecated pxla.thread_resources import

Stub comment + dead import: with _PXLA_THREAD_RESOURCES removed, this file has no remaining reference to pxla — the from jax.interpreters import pxla at line 21 is now unused and should be deleted along with this placeholder comment (grep -n pxla returns only this line and the now-dead import). Also nit: the rest of the PR uses # ROCm: with a space; this is #ROCm: without one.

Suggested change
#ROCm: disable deprecated pxla.thread_resources import

(and remove from jax.interpreters import pxla above).

return mesh
# ROCm remove deprecated handling of Mesh's set via `with mesh`
# Handle Mesh's set via `jax.set_mesh(mesh)`
return jax.sharding.get_abstract_mesh()

Backward-incompatibility risk on the upstream/CUDA path. This removes the with mesh: discovery branch entirely, so any caller that still uses the (still-supported in JAX) with mesh: pattern — instead of jax.set_mesh() — will now silently see an empty AbstractMesh here, and TE will treat it as "no mesh" rather than raising. That changes behaviour for non-ROCm users on JAX versions where both patterns still work, not just JAX 0.9.

This file is shared with upstream NVIDIA TE; per the fork's review guidance, behavioural changes to CUDA-reachable code need either (a) a runtime guard so CUDA stays byte-identical, or (b) explicit classification as a generic JAX-0.9 compat fix worth upstreaming. The "ROCm remove deprecated…" comment reads as ROCm-specific, but the change applies unconditionally. Consider falling back to pxla.thread_resources.env.physical_mesh when the abstract mesh is empty (keeping both code paths) or flagging this in the PR description as a deliberate JAX-0.9 cutover.
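
The fallback suggested above could look roughly like the sketch below. To keep it runnable without a JAX install, the two mesh lookups are passed in as callables and `StubMesh` is a hypothetical stand-in that models only `axis_names` and `empty`. In TE the callables would presumably wrap `jax.sharding.get_abstract_mesh()` and the deprecated `pxla.thread_resources.env.physical_mesh`; the name `resolve_mesh` and the `AttributeError` handling are assumptions, not code from this PR.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass(frozen=True)
class StubMesh:
    """Hypothetical stand-in for a JAX mesh; only axis_names and empty are modelled."""
    axis_names: Tuple[str, ...] = ()

    @property
    def empty(self) -> bool:
        return len(self.axis_names) == 0


def resolve_mesh(
    get_abstract_mesh: Callable[[], StubMesh],
    get_legacy_mesh: Optional[Callable[[], StubMesh]] = None,
) -> StubMesh:
    """Prefer the mesh set via jax.set_mesh(); fall back to the legacy `with mesh:` one."""
    mesh = get_abstract_mesh()
    if not mesh.empty:
        return mesh
    if get_legacy_mesh is None:
        # No legacy lookup available: report "no mesh" as the PR does today.
        return mesh
    try:
        legacy = get_legacy_mesh()
    except AttributeError:
        # Assumption: a removed legacy attribute surfaces as AttributeError.
        return mesh
    # Only use the legacy mesh if it actually carries axes.
    return mesh if legacy.empty else legacy


# Caller that still uses the old pattern: abstract mesh is empty, legacy mesh is set.
old_style = resolve_mesh(lambda: StubMesh(), lambda: StubMesh(("dp",)))
```

With this shape, callers on `jax.set_mesh()` are unaffected, while `with mesh:` callers keep working on JAX versions that still expose the legacy lookup.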

# (e.g. SingleDeviceSharding), return x unchanged.
# Inside jax.jit, traced arrays raise AttributeError for .sharding, so hasattr()
# returns False and we fall through to the normal constraint path.
if hasattr(x, 'sharding') and not isinstance(x.sharding, jax.sharding.NamedSharding):

Silent fallback to no-op concerns me. This returns x unchanged whenever the input has a .sharding that isn't a NamedSharding — which includes any concrete array on a SingleDeviceSharding outside a mesh context. Previously such a call would have raised (visible failure); now the constraint is silently dropped. In practice users calling with_sharding_constraint(x, pspec) in eager mode tend to expect either an effect or an error, not a no-op, so subtle "my model isn't sharded" bugs become possible.

Two specific concerns:

  1. The "Inside jax.jit, traced arrays raise AttributeError for .sharding" assumption is fragile. In current JAX, jax.core.Tracer does expose .sharding in many contexts; if/when it does inside jit, this branch will start silently swallowing the constraint there too. A more robust gate is something like not isinstance(x, jax.core.Tracer) (or check concreteness) so the JIT path is explicit, not incidental.
  2. Scope mismatch with the comment label. Comment is tagged ROCm: JAX 0.9 support but the change is unguarded and runs on the CUDA path as well. If this is genuinely a JAX-0.9-only requirement (not ROCm-specific), please reclassify per the fork's upstream-compat rules: either declare it as a generic JAX-0.9 fix worth upstreaming, or guard with an is_hip_extension()/JAX-version check so CUDA behaviour is unchanged.

At minimum, consider emitting a warnings.warn(...) on the silent-skip branch so users know the constraint didn't apply.
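
A minimal sketch of the two suggestions combined (an explicit tracer gate plus a warning on the skip branch) is shown below. All `Stub*` classes are hypothetical stand-ins so the example runs without JAX; in TE they would presumably correspond to `jax.core.Tracer`, `jax.sharding.NamedSharding`, and a concrete `jax.Array`. The helper name `should_apply_constraint` and the warning text are assumptions for illustration only.

```python
import warnings


class StubTracer:
    """Hypothetical stand-in for a JAX tracer (a value seen inside jax.jit)."""


class StubNamedSharding:
    """Hypothetical stand-in for jax.sharding.NamedSharding."""


class StubArray:
    """Hypothetical stand-in for a concrete array carrying a .sharding attribute."""

    def __init__(self, sharding):
        self.sharding = sharding


def should_apply_constraint(x, tracer_type=StubTracer,
                            named_sharding_type=StubNamedSharding) -> bool:
    """Return True if the sharding constraint should be applied to x."""
    if isinstance(x, tracer_type):
        # Explicit JIT path: always constrain, whether or not .sharding exists.
        return True
    sharding = getattr(x, "sharding", None)
    if sharding is not None and not isinstance(sharding, named_sharding_type):
        # Eager input on a non-named sharding: skip, but tell the user.
        warnings.warn(
            "with_sharding_constraint skipped: input is not on a NamedSharding",
            RuntimeWarning,
            stacklevel=2,
        )
        return False
    return True
```

Because the tracer check comes first, the JIT path no longer depends on whether tracers happen to expose `.sharding`, and the eager skip becomes visible instead of silent.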


github-actions Bot commented Jun 1, 2026

Claude review summary

Scope: JAX 0.9 migration — drop pxla.thread_resources mesh discovery, switch tests/benchmarks from with mesh: to with jax.set_mesh(mesh):, and add ROCm-tagged workarounds in transformer_engine/jax/sharding.py for get_abstract_mesh() returning empty inside custom_partitioning.sharded_impl.

Verdict: Test/benchmark migrations and the is_hip_extension → is_hip_extension() fix in cpp_extensions/base.py look correct (the latter is a real pre-existing bug: without the call, the truthy function object always forced platform="ROCM"). Main concerns are in transformer_engine/jax/sharding.py, posted inline:

  • _get_mesh() drops the with mesh: discovery branch entirely — silent behaviour change on shared (CUDA-reachable) code, not just ROCm.
  • New eager-mode short-circuit in with_sharding_constraint silently no-ops on non-NamedSharding inputs (was previously a visible error); tagged "ROCm:" but runs unconditionally.
  • Dead from jax.interpreters import pxla import + stub placeholder comment left behind by the cleanup.
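
For readers unfamiliar with the is_hip_extension bug mentioned in the verdict, the sketch below reproduces the pattern in plain Python. The `is_hip_extension` here is a hypothetical stand-in that always reports a non-ROCm build, the `pick_platform_*` names are invented for illustration, and "CUDA" as the alternative platform string is an assumption (only "ROCM" appears in the review).

```python
def is_hip_extension() -> bool:
    """Hypothetical stand-in for TE's helper; here it reports a non-ROCm build."""
    return False


def pick_platform_buggy() -> str:
    # Bug: the function object itself is tested, and function objects are always truthy.
    return "ROCM" if is_hip_extension else "CUDA"


def pick_platform_fixed() -> str:
    # Fix: call the function and test its return value.
    return "ROCM" if is_hip_extension() else "CUDA"
```

On this simulated non-ROCm build the buggy variant still selects "ROCM", while the fixed variant selects the other platform.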

Behavioural changes in global_shard_guard/global_mesh_resource (validation moved to context-entry) and all_reduce_max_along_all_axes_except_PP (using passed mesh.axis_names instead of the global) look reasonable for the stated JAX-0.9 issue but inherit the same "tagged ROCm but applies to CUDA path" caveat — worth a one-line note in the PR description classifying them as JAX-0.9 compat (upstreamable) vs. ROCm-only.

Copyright headers: OK. All eight changed files have correct, current-year AMD headers; the four NVIDIA-only files that ROCm is now modifying (tests/jax/test_distributed_dense.py, test_distributed_layernorm.py, test_distributed_softmax.py, transformer_engine/jax/sharding.py) correctly gained AMD headers above the preserved NVIDIA lines; no NVIDIA copyright years were altered.

Contributor

@Micky774 Micky774 left a comment


LGTM pending the AI review comments being addressed -- particularly the non-ROCm compatibility for with mesh

