[JAX] Expert Parallelism: JAX primitives + VJPs #3036
Greptile Summary
This PR lands the JAX bindings for Expert Parallelism: XLA FFI handlers over the
Confidence Score: 4/5
Safe to merge after adding a dtype guard for topk_weights; without it, mixed-precision MoE training silently produces wrong routing weight values with no error or warning.
The topk_weights dtype assumption is an active data-corruption path in mixed-precision MoE training: the C++ FFI hardcodes DType::kFloat32 without checking the buffer element type, and the Python abstract eval deletes the aval before any inspection. Everything else (bootstrap validation, sharding rules, VJP correctness, build wiring) is solid.
transformer_engine/jax/csrc/extensions/ep.cpp (EpDispatchFFI topk_weights wrapper) and transformer_engine/jax/cpp_extensions/ep.py (EpDispatchPrimitive.abstract) need coordinated dtype guards before the DType::kFloat32 assumption.
Important Files Changed
Reviews (7): Last reviewed commit: "jax/ep: introduce per-layer EpHandle, dr..."
| Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts, | ||
| Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) { | ||
| auto topk_dims = topk_idx.dimensions(); | ||
| NVTE_CHECK(topk_dims.size() >= 2, |
nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?
This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.
I would appreciate your help reviewing this PR, @tdophung @jberchtold-nvidia!
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
| assert ret == 0, f"ncclGetUniqueId failed with code {ret}" | ||
| uid_bytes = bytes(uid_arr) |
assert disabled by -O in ctypes UID path
assert ret == 0 is silently elided when Python runs with the -O optimisation flag (sometimes used in production deployments). If ncclGetUniqueId fails, uid_bytes would be all zeros; the all-gather propagates those zeros to every rank in the EP group, causing ncclCommInitRank to either produce mismatched communicators or hang indefinitely with no diagnostic message.
| assert ret == 0, f"ncclGetUniqueId failed with code {ret}" | |
| uid_bytes = bytes(uid_arr) | |
| ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p)) | |
| if ret != 0: | |
| raise RuntimeError(f"ncclGetUniqueId failed with code {ret}") |
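For illustration, a minimal sketch of the explicit check suggested above. The helper name `check_nccl_result` is hypothetical (it is not part of this PR), and the return codes are simulated rather than coming from a real libnccl call. The point is only that an `if`/`raise` survives `python -O`, whereas an `assert` does not.

```python
def check_nccl_result(ret: int, what: str = "ncclGetUniqueId") -> None:
    """Raise on a non-zero NCCL return code.

    Hypothetical helper for illustration only. Unlike `assert`, this check
    is not stripped when Python runs with -O.
    """
    if ret != 0:
        raise RuntimeError(f"{what} failed with code {ret}")


# Simulated return codes; a real caller would pass the value returned by
# libnccl.ncclGetUniqueId(...).
check_nccl_result(0)  # success path: no exception

try:
    check_nccl_result(2)
except RuntimeError as err:
    message = str(err)
```

Routing every ctypes NCCL call through one helper like this would also keep the error text consistent across the bootstrap path.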
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
jberchtold-nvidia left a comment
LGTM pending CI
| ep_size, | ||
| num_experts, | ||
| max_tokens_per_rank, | ||
| recv_capacity_per_rank, |
In the ep.h definition of NVTEEpGroupConfig, this field is named max_recv_tokens_per_rank rather than recv_capacity_per_rank. For consistency, maybe we should use the same name in both places?
| f"ep_bootstrap requires world_size >= 4 (got {world_size}); NCCL EP requires" | ||
| " at least 4 ranks on the node for its HT mode." | ||
| ) | ||
| UID_SIZE = 128 |
nit: after looking into the NCCL headers I figured out what this is, but an inline comment saying so would help. Something like # NCCL_UNIQUE_ID_BYTES from nccl.h, used to store host name, listening port, etc.
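A small sketch of what the commented constant and its buffer could look like. The buffer type `ctypes.c_char * UID_SIZE` is an assumption for illustration; the PR may allocate `uid_arr` differently.

```python
import ctypes

# NCCL_UNIQUE_ID_BYTES from nccl.h: size in bytes of the opaque ncclUniqueId.
UID_SIZE = 128

# Assumed buffer type for illustration. ctypes zero-initialises it, and
# ncclGetUniqueId would fill it in place on success.
uid_arr = (ctypes.c_char * UID_SIZE)()
uid_bytes = bytes(uid_arr)
```

Because the buffer starts out zeroed, a failed ncclGetUniqueId call that goes unchecked would leave `uid_bytes` as 128 zero bytes.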
| def _dispatch_bwd(recv_capacity_per_rank, dispatch_output_per_expert_alignment, res, g_outputs): | ||
| del recv_capacity_per_rank, dispatch_output_per_expert_alignment | ||
| handle, out_leading, top_k = res | ||
| # Re-pin cotangent sharding: XLA transpose can drop the EP axis on a |
I now understand that the sharding for other fwd-output cotangents can be lost when propagated to bwd. But is this a fault of JAX that we should ask to be fixed? Did you write this defensively because you ran into a bug where it was trying to read the entire global tensor?
| f = _sys._getframe(1) | ||
| cache_key = (f.f_code.co_filename, f.f_lineno, top_k, alignment) |
Would the public ep_dispatch wrapper interact badly with this? Every ep_dispatch(...) call in a user's program ultimately routes through
token_counts, handle = tex.ep_prepare(topk_idx, dispatch_output_per_expert_alignment)
at jax/ep.py:191,
so _sys._getframe(1) always sees the same (jax/ep.py, 191). That means a model with multiple MoE layers that all use ep_dispatch ends up sharing one handle_id across layers. Wouldn't that corrupt the cached routing state?
Maybe you don't see it because the test has only 1 layer? I am not entirely sure because I have not tried running a case with multiple layers. This is just from reading the code.
This is valid. Thanks for finding it!
I will expose an option so that the user can host the handle_id per layer.
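The collision described in this thread can be reproduced in plain Python. This is a sketch under assumptions: `call_site_key`, `prepare`, `layer_a`, `layer_b`, and `prepare_with_id` are hypothetical stand-ins, not functions from the PR, and the explicit-id variant is only one possible shape for the per-layer option mentioned above.

```python
import sys


def call_site_key():
    """Key on the immediate caller's file and line, as in the snippet above."""
    frame = sys._getframe(1)
    return (frame.f_code.co_filename, frame.f_lineno)


def prepare():
    # Stand-in for the public wrapper: every layer reaches call_site_key()
    # through this single line, so the key is identical for all of them.
    return call_site_key()


def layer_a():
    return prepare()


def layer_b():
    return prepare()


# Two distinct layers collapse onto one key.
keys_via_wrapper = {layer_a(), layer_b()}


def prepare_with_id(layer_id):
    # Hypothetical alternative: the caller supplies an explicit per-layer id.
    return ("ep_handle", layer_id)


keys_explicit = {prepare_with_id(0), prepare_with_id(1)}
```

Here `keys_via_wrapper` ends up with a single entry even though two layers called in, while `keys_explicit` keeps one entry per layer.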
…ache Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
| Buffer_Type topk_idx, Buffer_Type tokens, Buffer_Type topk_weights, | ||
| Result_Type recv_tokens, Result_Type recv_topk_weights, | ||
| Result_Type workspace, EpDispatchConfig config) { | ||
| (void)ep_state; | ||
| auto token_dims = tokens.dimensions(); |
topk_weights dtype unconditionally assumed float32 — silent data corruption
EpDispatchFFI wraps topk_weights as DType::kFloat32 regardless of the buffer's actual element type. If a caller passes bfloat16 or float16 weights (common in mixed-precision MoE training), the bytes are silently reinterpreted, producing completely wrong routing weights without any error. The Python abstract eval compounds the problem: topk_weights_aval is deleted before any dtype inspection, so JAX tracing also provides no defence. A dtype guard is needed in both the C++ FFI handler and in EpDispatchPrimitive.abstract.
Currently our router is fixed to fp32, so there is no chance topk_weights will be bf16 or fp16. If we ever change the router to output more than one dtype, we can change this too.
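For reference, the kind of silent reinterpretation the review describes can be shown with the standard library alone. This is an illustrative sketch, not code from the PR: the byte layout is simulated with `struct`, and `check_topk_weights_dtype` is a hypothetical guard that takes a dtype name as a string.

```python
import struct

# Four float16 routing weights packed as raw bytes ('e' is IEEE half precision).
weights_fp16 = [0.5, 0.25, 0.125, 0.125]
raw = struct.pack("<4e", *weights_fp16)

# Reading the same 8 bytes as float32 yields two unrelated values, with no error.
misread = struct.unpack("<2f", raw)


def check_topk_weights_dtype(dtype_name: str) -> None:
    """Hypothetical guard: reject anything other than float32 up front."""
    if dtype_name != "float32":
        raise TypeError(f"topk_weights must be float32, got {dtype_name}")


check_topk_weights_dtype("float32")  # passes while the router stays fp32
```

While the router output stays fp32 a guard like this never fires, so it would mainly serve as a cheap tripwire if the router dtype changes later.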
| arg_shardings = tuple(a.sharding for a in arg_infos) | ||
| out_shardings = [ | ||
| NamedSharding(mesh, PartitionSpec(*resolved)), | ||
| NamedSharding(mesh, PartitionSpec(*resolved, None)), |
Given that len(resolved) is always num_leading_dims + 1 (per _resolve_out_partition_spec), wouldn't this PartitionSpec have one more dim than needed (num_leading_dims + 1 + 1)?
maybe we should fix it with:
| NamedSharding(mesh, PartitionSpec(*resolved, None)), | |
| NamedSharding(mesh, PartitionSpec(*resolved[:-1], None)), |
I notice that your test script already notes that XLA drops all trailing None entries in a PartitionSpec, which is why there was no error about a mismatched number of dims. However, why not give it exactly the same number of dimensions as grad_topk_weights_aval?
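The length arithmetic behind this suggestion, sketched with plain tuples standing in for PartitionSpec entries. The axis name "ep", the value of num_leading_dims, and the exact contents of resolved are assumptions for illustration only.

```python
# Plain tuples stand in for PartitionSpec entries; "ep" is an assumed axis name.
num_leading_dims = 1
resolved = ("ep",) + (None,) * num_leading_dims  # len == num_leading_dims + 1

spec_as_written = (*resolved, None)      # one entry more than len(resolved)
spec_suggested = (*resolved[:-1], None)  # same length as len(resolved)
```

With these assumptions the spec as written has num_leading_dims + 2 entries, while the suggested form has num_leading_dims + 1.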
Summary
Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.
Implementation
Public Python API (transformer_engine/jax/ep.py)
ep_dispatch/ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd; routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.
XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)
Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries (EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler), each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.
Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)
Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).
Sharding (transformer_engine/jax/sharding.py, +12 lines)
Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
Build wiring (build_tools/jax.py, +41 lines)
Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.
Tests & example
tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.
Type of change
Checklist: