Add Training support to MORI #335
…ions
- Implement logic to skip processing for negative expert IDs in both intra-node and inter-node dispatch functions.
- Update the test suite to support various patterns of sentinel injection, ensuring that dispatched and combined outputs correctly handle these cases.
- Introduce new parameters in test data generation to control sentinel behavior, improving test coverage for edge cases.

- Introduce `EpDispatchRoutingHandle` class to manage per-call routing snapshots for mode-1 and mode-2 dispatch.
- Enhance `EpDispatchCombineArgsRaw` to support routing-related pointers and replay mode.
- Update dispatch and combine operations to utilize the new routing handle, allowing for improved routing management and replay correctness.
- Add tests to validate the functionality of the routing handle, ensuring compatibility with existing dispatch mechanisms and correctness in multi-layer scenarios.

- Simplify test data generation documentation in `dispatch_combine_test_utils.py`.
- Remove deprecated test for TP-replicated routing with sentinels in `test_dispatch_combine_intranode.py`.
- Rename functions for clarity in `test_dispatch_combine_routing_handle.py`, ensuring consistency in naming conventions.
- Update assertions to reflect changes in routing handle comparisons, enhancing test reliability.
Pull request overview
This PR extends MORI’s MoE dispatch/combine to support training-oriented cached routing (“mode-1/mode-2 replay”) and DeepEP-style -1 sentinel experts for the IntraNode and InterNodeV1 kernel types. It introduces a caller-owned routing handle so later dispatches/combines can deterministically reuse a previously-recorded layout without re-running slot assignment.
Changes:
- Add Python/C++ routing-handle plumbing (`EpDispatchRoutingHandle`/`EpDispatchCombineRoutingPtrs`) and a replay-mode args builder for cached dispatch/combine.
- Update IntraNode + InterNodeV1 kernels to (a) skip `-1` experts and (b) replay cached routing maps in mode-2, plus avoid resetting caller-owned routing tensors.
- Add new Python tests for routing-handle parity/replay and add `-1` sentinel generation + reference updates (IntraNode sentinel coverage).
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/python/ops/test_dispatch_combine_routing_handle.py | New routing-handle parity/replay/multi-layer tests (currently configured as single-node InterNodeV1). |
| tests/python/ops/test_dispatch_combine_intranode.py | Adds IntraNode-only -1 sentinel dispatch/combine test coverage. |
| tests/python/ops/dispatch_combine_test_utils.py | Adds sentinel pattern generation and updates combine reference to ignore -1 experts. |
| src/pybind/pybind_ops.cpp | Adds build_args_with_routing and snapshot helper for routing-handle support. |
| src/ops/dispatch_combine/intranode.hpp | Adds replay mode + -1 sentinel behavior and routing-handle local-map selection/reset changes. |
| src/ops/dispatch_combine/internode_v1.cpp | Adds replay mode + -1 sentinel behavior and avoids resetting caller-owned routing tensors. |
| src/ops/dispatch_combine/dispatch_combine.cpp | Adds routing-aware overload of GetEpDispatchCombineArgsRaw. |
| python/mori/ops/dispatch_combine.py | Adds EpDispatchRoutingHandle, dispatch(..., routing/return_routing), and combine(..., routing=...). |
| include/mori/ops/dispatch_combine/dispatch_combine.hpp | Adds EpDispatchCombineRoutingPtrs and replay-mode fields to args structs. |
Comments suppressed due to low confidence (1)
python/mori/ops/dispatch_combine.py:605
`dispatch(..., routing=R)` (mode-2 replay) does not validate that the replay call matches the mode-1 snapshot (e.g., `input.size(0)` equals `R.cur_rank_num_token`, and routing tensors are on the same device). A mismatch can lead to incorrect reads/writes since kernels iterate based on `num_tokens=input.size(0)` but interpret routing maps produced for a different token count. Consider adding a fast Python-side check for `is_replay` to ensure the token count matches the handle and that routing tensors are CUDA tensors on the current device.
    ):
        if routing is not None and return_routing:
            raise ValueError(
                "pass either `routing=` (replay) or `return_routing=True` "
                "(new layout), not both"
            )
        use_routing_handle = routing is not None or return_routing
        is_replay = routing is not None
        if use_routing_handle and not self._supports_routing_handle():
            raise NotImplementedError(
                f"routing handle path not supported for kernel_type="
                f"{self.config.kernel_type}; only IntraNode and InterNodeV1 "
                "expose mode-1/mode-2 dispatch."
            )
        if return_routing:
            routing = self._alloc_routing_handle()
        hidden_dim = input.size(1)
        weight_ptr = weights.data_ptr() if weights is not None else 0
        has_scales = scales is not None and self.config.scale_dim > 0
        scale_ptr = scales.data_ptr() if has_scales else 0
        actual_bn, actual_rbn, actual_wpb = self._resolve_launch_params(
            block_num,
            rdma_block_num,
            warp_per_block,
            num_tokens=input.size(0),
            hidden_dim=hidden_dim,
            dtype=input.dtype,
            tuning_rules=self._dispatch_rules,
        )
        self._cached_dispatch_launch = (actual_bn, actual_rbn, actual_wpb)
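The Python-side replay check suggested in the suppressed comment could look roughly like the sketch below. It is only an illustration: `RoutingSnapshot` and `check_replay_inputs` are hypothetical names that do not appear in the PR, and devices are modelled as plain strings so the example runs without PyTorch.

```python
from dataclasses import dataclass


@dataclass
class RoutingSnapshot:
    """Hypothetical stand-in for a routing handle; not the MORI class."""

    cur_rank_num_token: int  # token count recorded by the first (record) dispatch
    device: str              # device that holds the recorded routing tensors


def check_replay_inputs(num_input_tokens: int, input_device: str,
                        routing: RoutingSnapshot) -> None:
    """Raise ValueError if a replay call does not match the recorded snapshot."""
    if num_input_tokens != routing.cur_rank_num_token:
        raise ValueError(
            f"replay got {num_input_tokens} tokens but the snapshot recorded "
            f"{routing.cur_rank_num_token}"
        )
    if input_device != routing.device:
        raise ValueError(
            f"input is on {input_device} but the snapshot lives on {routing.device}"
        )
```

In the real method, a check of this kind would presumably run only when `is_replay` is true, before the launch parameters are resolved.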
  index_t* totalRecvTokenNum{nullptr};
  index_t* dispTokIdToSrcTokIdLocal{nullptr};

  bool IsValid() const { return dispDestTokIdMap != nullptr; }
  is_zero_copy = not actual_use_ext
+ cur_n = (
+     routing.cur_rank_num_token
+     if routing is not None
+     else self._get_cur_rank_num_token(self._handle)
+ )
  actual_bn, actual_rbn, actual_wpb = self._resolve_launch_params(
      block_num,
      rdma_block_num,
      warp_per_block,
-     num_tokens=self._get_cur_rank_num_token(self._handle),
+     num_tokens=cur_n,
      hidden_dim=hidden_dim,
def _make_internode_v1_config(rank, world_size):
    return mori.ops.EpDispatchCombineConfig(
        data_type=torch.bfloat16,
        rank=rank,
        world_size=world_size,
        hidden_dim=4096,
        scale_dim=0,
        scale_type_size=1,
        max_num_inp_token_per_rank=32,
        num_experts_per_rank=4,
        num_experts_per_token=4,
        max_token_type_size=2,
        block_num=96,
        rdma_block_num=64,
        warp_num_per_block=8,
        kernel_type=mori.ops.EpDispatchCombineKernelType.InterNodeV1,
        gpu_per_node=world_size,
    )
  for (int i = warpId; i < (endTokenIdx - startTokenIdx) * config.numExpertPerToken; i += warpNum) {
    index_t tokenId = i / config.numExpertPerToken + startTokenIdx;
-   index_t destPe =
-       args.tokenIndices[startTokenIdx * config.numExpertPerToken + i] / config.numExpertPerRank;
+   index_t expertOffset = startTokenIdx * config.numExpertPerToken + i;
+   index_t destExpert = args.tokenIndices[expertOffset];
+   if (destExpert < 0) {
+     if (!args.replayMode && laneId == 0)
+       args.dispDestTokIdMap[expertOffset] = NullFlatTokenIndex(config);
+     continue;
+   }
+   index_t destPe = destExpert / config.numExpertPerRank;
    int destNode = destPe / config.gpuPerNode;

    int lanePe = -1, laneNode = -1;
    if (laneId < numExpertPerToken) {
-     lanePe = (args.tokenIndices[tokenId * numExpertPerToken + laneId] / config.numExpertPerRank);
+     index_t laneExpert = args.tokenIndices[tokenId * numExpertPerToken + laneId];
+     // Sentinel lanes get a unique impossible destPe so dedup cannot false-match.
+     lanePe = (laneExpert < 0) ? (-1 - static_cast<int>(laneId))
+                               : (laneExpert / config.numExpertPerRank);
      laneNode = lanePe / config.gpuPerNode;
    };
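The sentinel handling in the hunk above can be modelled on the host in a few lines of Python. This is a sketch under assumptions, not the kernel's logic: the "first lane targeting a PE sends" rule and the helper names `lane_dest_pes` and `lanes_that_send` are illustrative and do not appear in the PR.

```python
def lane_dest_pes(expert_ids, num_experts_per_rank):
    """Map each top-k lane to a destination PE.

    Sentinel lanes (negative expert id) get the unique negative id -1 - lane,
    mirroring the `-1 - laneId` expression in the diff, so two sentinel lanes
    never compare equal to each other or to a real PE.
    """
    return [
        (-1 - lane) if expert < 0 else expert // num_experts_per_rank
        for lane, expert in enumerate(expert_ids)
    ]


def lanes_that_send(expert_ids, num_experts_per_rank):
    """Return the lanes that are first to target their PE, skipping sentinels."""
    pes = lane_dest_pes(expert_ids, num_experts_per_rank)
    senders = []
    for lane, pe in enumerate(pes):
        if pe < 0:
            continue  # dropped slot: nothing is sent for this lane
        if pe not in pes[:lane]:  # assumed dedup rule: earliest lane wins
            senders.append(lane)
    return senders
```

With `num_experts_per_rank=4`, the top-k row `[5, -1, 6, -1]` maps to PEs `[1, -2, 1, -4]`, so only lane 0 sends. Under the same assumptions, `len(lanes_that_send(...))` is the per-token PE fan-in that a combine reference would count when it ignores negative expert IDs.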
- Implement internode `test_sentinel_dispatch_combine` to validate the handling of sentinels in dispatch operations.
- Update command options to include `test_sentinel` for executing sentinel tests.
- Enhance routing pointer validation in `EpDispatchCombineRoutingPtrs` to ensure all required pointers are non-null.
- Add tests for DeepEP-style -1 sentinel handling in `test_dispatch_combine_internode_v1.py`, ensuring robustness in multi-node scenarios.
isytwu
left a comment
Thanks for working on this PR. I left a couple of minor comments. Overall this looks good from my side once those are addressed.
launch_local_expert_count(
    self._cpp_config,
    outI_ptr,
    total_ptr,
Should this pass routing.total_recv_token_num.data_ptr() when use_routing_handle is true? Otherwise call_local_expert_count still reads handle.totalRecvTokenNum.
You're right, I've fixed it. Thanks!
  continue;
}
if (!args.replayMode) {
  // Mode-1 (DeepEP-style cached-mode "first call"): decide where this
Could we describe this as cached/replay routing instead of Mode-1 / Mode-2 / DeepEP-style?
  if (laneId == 0) args.dispDestTokIdMap[i] = FlatTokenIndex(config, config.worldSize, 0);
  continue;
}
if (!args.replayMode) {
Thanks for addressing the earlier comments. One more performance-related point: to avoid any possible regression on the default non-replay path, would it make sense to specialize this as a static/template variant instead?
`replayMode` is a plain scalar that is uniform across the entire launch, so the branch on it causes no warp divergence.
I ran a benchmark using the following command:
python3 -m tests.python.ops.bench_dispatch_combine --cmd bench --world-size 8 --max-tokens 4096 --hidden-dim 7168 --dtype bf16
The two runs were essentially identical; the `replayMode` change shows no measurable regression:
| Phase | main | curr branch | Δ (run 2) | Δ (run 1) |
|---|---|---|---|---|
| Dispatch | 1035.8 µs | 1032.1 µs | −3.7 µs (−0.36%) | +1.5 µs (+0.15%) |
| Combine | 935.0 µs | 932.8 µs | −2.2 µs (−0.23%) | −3.6 µs (−0.39%) |
| E2E | ~1946.1 µs | ~1941.6 µs | −4.5 µs (−0.23%) | ~flat |
replayMode is a launch-uniform, perfectly-predicted branch, so it costs effectively nothing on the hot path. The data confirms there's no need to template-specialize it for performance. Let me know what you think!
Could you also check a small-token case, e.g. --max-tokens 1 or a few low-token settings?
Seems like noise to me.
1 token
| kernel | curr (µs) | main (µs) | Δ | Δ% |
|---|---|---|---|---|
| Dispatch | 36.0 | 37.6 | -1.6 | -4.3 |
| Combine | 20.9 | 21.2 | -0.3 | -1.4 |
| E2E | 32.2 | 30.6 | +1.6 | +5.2 |
8 tokens
| kernel | curr (µs) | main (µs) | Δ | Δ% |
|---|---|---|---|---|
| Dispatch | 39.1 | 38.4 | +0.7 | +1.8 |
| Combine | 27.9 | 27.5 | +0.4 | +1.5 |
| E2E | 41.7 | 41.4 | +0.3 | +0.7 |
64 tokens
| kernel | curr (µs) | main (µs) | Δ | Δ% |
|---|---|---|---|---|
| Dispatch | 44.3 | 43.0 | +1.3 | +3.0 |
| Combine | 35.1 | 35.1 | 0.0 | 0.0 |
| E2E | 55.3 | 54.9 | +0.4 | +0.7 |
512 tokens
| kernel | curr (µs) | main (µs) | Δ | Δ% |
|---|---|---|---|---|
| Dispatch | 161.9 | 160.5 | +1.4 | +0.9 |
| Combine | 137.2 | 135.7 | +1.5 | +1.1 |
| E2E | 274.6 | 272.4 | +2.2 | +0.8 |
Motivation
MoE training needs dispatch/combine behavior that inference-only paths do not provide:
- Cached routing (mode-1 / mode-2): Forward dispatch records where each top-k slot lands; backward must replay the same layout without redoing atomics, dedup, and cross-rank slot assignment. That matches DeepEP’s “first call” vs “cached replay” split used in training stacks.
- `-1` sentinel experts: Routing tensors can carry DeepEP-style sentinels for dropped or invalid top-k slots. Dispatch must skip those slots; combine must not count them toward PE fan-in.
- Multi-layer reuse of one op: A single `EpDispatchCombineOp` may serve many layers if each layer keeps its own routing snapshot and combine uses the correct handle, without stale symmetric-buffer state bleeding across layers.
This PR adds that training-oriented surface on IntraNode and InterNodeV1 only. Low-latency (LL) internode kernels are intentionally unchanged so inference/LL paths stay aligned with `main`.
Technical Details
DeepEP-style `-1` sentinels
On the IntraNode path (`intranode.hpp`), any top-k slot with a negative expert ID is treated as a dropped route: the kernel does not send payload for that slot, writes the combine null sentinel into `dispDestTokIdMap` (encoded as `PE == worldSize`), and excludes negative peers when deduplicating sends to the same destination rank.
InterNodeV1 (`internode_v1.cpp`) applies the same sentinel rules on the internode send path. In replay mode, it reuses the maps recorded by an earlier mode-1 dispatch instead of rebuilding layout.
Coverage lives in `test_dispatch_combine_minus_one_sentinel` (IntraNode only). Test data generation accepts a configurable `sentinel_pattern` in `dispatch_combine_test_utils.py`, and the combine reference only counts experts with `idx >= 0` when computing unique PE fan-in.
DeepEP-style routing handle
This PR introduces `EpDispatchRoutingHandle` in Python and `EpDispatchCombineRoutingPtrs` in C++. Both wrap caller-owned tensors (`disp_dest_tok_id_map`, internode routing maps, `total_recv_token_num`, and `disp_tok_id_to_src_tok_id_local`) that capture where each top-k slot was placed for a given forward pass.
In mode-1, the caller passes `return_routing=True` on `dispatch()`. Kernels run with `replayMode=false`, perform slot assignment and map population, and then `snapshot_disp_tok_id_to_src_tok_id_local` copies the symmetric inverse map into the handle so later no-P2P combine can read a stable local view.
In mode-2, the caller passes `routing=R` on a subsequent `dispatch()`. Kernels run with `replayMode=true`, reuse the frozen maps from mode-1, and may ship a new activation payload while expert indices and layout stay unchanged. The same handle is passed to `combine(..., routing=R)`, which reads those snapshot maps rather than relying solely on op-owned symmetric state.
The Python API exposes `routing` and `return_routing` as keyword-only arguments on `dispatch()` and `combine()`, and rejects passing both at once. This path is implemented for `IntraNode` and `InterNodeV1` only. Pybind adds `build_args_with_routing` and `snapshot_disp_tok_id_to_src_tok_id_local`, and threads routing pointers through `GetEpDispatchCombineArgsRaw`.
Submission Checklist
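The record/replay split described under "DeepEP-style routing handle" can be illustrated with a toy, single-process model. Everything here is an assumption made for illustration: `ToyDispatcher`, its slot-assignment rule, and its return shapes are hypothetical and are not the MORI API, which operates on GPU tensors across ranks.

```python
class ToyDispatcher:
    """Toy model of record (mode-1) and replay (mode-2) dispatch; not the MORI op."""

    def dispatch(self, tokens, dest_ranks=None, *, routing=None, return_routing=False):
        if routing is not None and return_routing:
            raise ValueError("pass either routing= or return_routing=True, not both")
        if routing is None:
            # Record pass: give each token the next free slot on its destination.
            next_slot = {}
            routing_map = []
            for dest in dest_ranks:
                slot = next_slot.get(dest, 0)
                next_slot[dest] = slot + 1
                routing_map.append((dest, slot))
        else:
            # Replay pass: reuse the recorded placement; no slot assignment runs.
            routing_map = routing
        buffers = {}
        for token, (dest, slot) in zip(tokens, routing_map):
            buffers.setdefault(dest, {})[slot] = token
        if return_routing:
            return buffers, list(routing_map)
        return buffers


op = ToyDispatcher()
# Forward: record where each token lands.
fwd, R = op.dispatch(["a", "b", "c"], [1, 0, 1], return_routing=True)
# Backward: ship a new payload through the same layout.
bwd = op.dispatch(["ga", "gb", "gc"], routing=R)
```

Because the replay call never consults `dest_ranks`, `bwd` places each gradient in the slot its activation occupied in `fwd`, which is the property the backward pass relies on.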