Ipanfilo/jax0.9 support #604
Rerunning JAX mGPU on one host. The failure is most likely unrelated to this PR; similar failures are seen on dev.
| import numpy as np | ||
| _PXLA_THREAD_RESOURCES = pxla.thread_resources | ||
| #ROCm: disable deprecated pxla.thread_resources import |
Stub comment + dead import: with `_PXLA_THREAD_RESOURCES` removed, this file has no remaining reference to `pxla`. The `from jax.interpreters import pxla` at line 21 is now unused and should be deleted along with this placeholder comment (`grep -n pxla` returns only this line and the now-dead import). Also nit: the rest of the PR uses `# ROCm:` with a space; this is `#ROCm:` without one.
| #ROCm: disable deprecated pxla.thread_resources import |
(and remove `from jax.interpreters import pxla` above).
| return mesh | ||
| # ROCm remove deprecated handling of Mesh's set via `with mesh` | ||
| # Handle Mesh's set via `jax.set_mesh(mesh)` | ||
| return jax.sharding.get_abstract_mesh() |
Backward-incompatibility risk on the upstream/CUDA path. This removes the `with mesh:` discovery branch entirely, so any caller that still uses the `with mesh:` pattern (still supported in JAX) instead of `jax.set_mesh()` will now silently see an empty `AbstractMesh` here, and TE will treat it as "no mesh" rather than raising. That changes behaviour for non-ROCm users on JAX versions where both patterns still work, not just JAX 0.9.
This file is shared with upstream NVIDIA TE; per the fork's review guidance, behavioural changes to CUDA-reachable code need either (a) a runtime guard so CUDA stays byte-identical, or (b) explicit classification as a generic JAX-0.9 compat fix worth upstreaming. The "ROCm remove deprecated…" comment reads as ROCm-specific, but the change applies unconditionally. Consider falling back to pxla.thread_resources.env.physical_mesh when the abstract mesh is empty (keeping both code paths) or flagging this in the PR description as a deliberate JAX-0.9 cutover.
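A rough sketch of the "keep both code paths" option, for illustration only. `resolve_mesh` and its two getter parameters are hypothetical names, not existing TE helpers, and the sketch assumes both mesh objects expose an `empty` attribute. The getters are injected so the lookup order can be exercised without a JAX install:

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def resolve_mesh(
    get_abstract_mesh: Callable[[], Any],
    get_legacy_mesh: Optional[Callable[[], Any]] = None,
) -> Any:
    """Prefer the mesh set via jax.set_mesh(); fall back to the legacy one.

    Hypothetical helper: both getters are assumed to return an object
    with an `empty` attribute.
    """
    mesh = get_abstract_mesh()
    if not mesh.empty:
        return mesh
    # Only consult the legacy `with mesh:` path when a getter is supplied,
    # e.g. on JAX versions where pxla.thread_resources is still usable.
    if get_legacy_mesh is not None:
        legacy = get_legacy_mesh()
        if not legacy.empty:
            return legacy
    # Nothing was set through either mechanism: return the empty mesh.
    return mesh


# Hypothetical wiring inside the helper under review (not executed here):
#   resolve_mesh(jax.sharding.get_abstract_mesh,
#                lambda: pxla.thread_resources.env.physical_mesh)

# Stand-in objects to show the lookup order.
empty_abstract = SimpleNamespace(empty=True)
legacy_mesh = SimpleNamespace(empty=False)
chosen = resolve_mesh(lambda: empty_abstract, lambda: legacy_mesh)
```

Passing `None` for the legacy getter gives the cutover behaviour of this PR, so a version or platform check could decide which form to call.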
| # (e.g. SingleDeviceSharding), return x unchanged. | ||
| # Inside jax.jit, traced arrays raise AttributeError for .sharding, so hasattr() | ||
| # returns False and we fall through to the normal constraint path. | ||
| if hasattr(x, 'sharding') and not isinstance(x.sharding, jax.sharding.NamedSharding): |
Silent fallback to no-op concerns me. This returns x unchanged whenever the input has a .sharding that isn't a NamedSharding — which includes any concrete array on a SingleDeviceSharding outside a mesh context. Previously such a call would have raised (visible failure); now the constraint is silently dropped. In practice users calling with_sharding_constraint(x, pspec) in eager mode tend to expect either an effect or an error, not a no-op, so subtle "my model isn't sharded" bugs become possible.
Two specific concerns:
- The "Inside `jax.jit`, traced arrays raise AttributeError for `.sharding`" assumption is fragile. In current JAX, `jax.core.Tracer` does expose `.sharding` in many contexts; if/when it does inside `jit`, this branch will start silently swallowing the constraint there too. A more robust gate is something like `not isinstance(x, jax.core.Tracer)` (or check concreteness) so the JIT path is explicit, not incidental.
- Scope mismatch with the comment label. Comment is tagged `ROCm: JAX 0.9 support` but the change is unguarded and runs on the CUDA path as well. If this is genuinely a JAX-0.9-only requirement (not ROCm-specific), please reclassify per the fork's upstream-compat rules: either declare it as a generic JAX-0.9 fix worth upstreaming, or guard with an `is_hip_extension()`/JAX-version check so CUDA behaviour is unchanged.
At minimum, consider emitting a warnings.warn(...) on the silent-skip branch so users know the constraint didn't apply.
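To make the two suggestions above concrete, here is a minimal sketch of an explicit gate plus a warning. `should_skip_constraint` is a hypothetical name, and the tracer and named-sharding types are passed in as parameters (in TE they would presumably be `jax.core.Tracer` and `jax.sharding.NamedSharding`) so the logic can be checked without JAX:

```python
import warnings
from typing import Any


def should_skip_constraint(x: Any, tracer_type: type, named_sharding_type: type) -> bool:
    """Return True only for concrete arrays whose sharding is not a named one.

    Hypothetical helper: the caller supplies the tracer and sharding types.
    """
    # Traced values always take the normal constraint path, decided by type
    # rather than by whether `.sharding` happens to raise.
    if isinstance(x, tracer_type):
        return False
    sharding = getattr(x, "sharding", None)
    if sharding is None or isinstance(sharding, named_sharding_type):
        return False
    # Make the skipped constraint visible instead of dropping it silently.
    warnings.warn(
        "with_sharding_constraint skipped: input sharding is "
        f"{type(sharding).__name__}, not a named sharding",
        stacklevel=2,
    )
    return True


# Stand-in types used only to exercise the three branches.
class FakeTracer:
    sharding = object()


class FakeNamedSharding:
    pass


class FakeArray:
    def __init__(self, sharding: Any) -> None:
        self.sharding = sharding
```

With this shape the eager single-device case still returns `x` unchanged, but the user gets a warning, and the jit case no longer depends on `hasattr()` returning False.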
Claude review summary
Scope: JAX 0.9 migration, drop […]
Verdict: Test/benchmark migrations and the […]
Behavioural changes in […]
Copyright headers: OK. All eight changed files have correct, current-year AMD headers; the four NVIDIA-only files that ROCm is now modifying ([…]
Description
Add JAX 0.9 support
https://github.com/ROCm/frameworks-internal/issues/16494
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: