Add Training support to MORI #335

Open
sudhu2k wants to merge 13 commits into main from sudhu/mori_training

@sudhu2k sudhu2k commented May 25, 2026

Motivation

MoE training needs dispatch/combine behavior that inference-only paths do not provide:

  1. Cached routing (mode-1 / mode-2) — Forward dispatch records where each top-k slot lands; backward must replay the same layout without redoing atomics, dedup, and cross-rank slot assignment. That matches DeepEP’s “first call” vs “cached replay” split used in training stacks.

  2. -1 sentinel experts — Routing tensors can carry DeepEP-style sentinels for dropped or invalid top-k slots. Dispatch must skip those slots; combine must not count them toward PE fan-in.

  3. Multi-layer reuse of one op — A single EpDispatchCombineOp may serve many layers if each layer keeps its own routing snapshot and combine uses the correct handle, without stale symmetric-buffer state bleeding across layers.

This PR adds that training-oriented surface on IntraNode and InterNodeV1 only. Low-latency (LL) internode kernels are intentionally unchanged so inference/LL paths stay aligned with main.

Technical Details

DeepEP-style -1 sentinels

On the IntraNode path (intranode.hpp), any top-k slot with a negative expert ID is treated as a dropped route: the kernel does not send payload for that slot, writes the combine null sentinel into dispDestTokIdMap (encoded as PE == worldSize), and excludes negative peers when deduplicating sends to the same destination rank.
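
The sentinel rule above can be sketched in plain Python. This is a minimal illustration, not the kernel: `plan_sends` and its return layout are hypothetical names, and the only details taken from the description are that negative expert IDs send nothing, that the null entry is encoded as PE == worldSize, and that sentinel slots take no part in per-destination dedup.

```python
from typing import List, Tuple


def plan_sends(expert_ids: List[int], num_experts_per_rank: int,
               world_size: int) -> Tuple[List[int], List[int]]:
    """Return (dest_pe_per_slot, unique_dest_pes) for one token's top-k slots."""
    null_pe = world_size  # assumption: mirrors "PE == worldSize" from the text above
    dest_pe_per_slot: List[int] = []
    unique_dest_pes: List[int] = []
    for expert in expert_ids:
        if expert < 0:
            # Dropped route: record the null PE and send no payload.
            dest_pe_per_slot.append(null_pe)
            continue
        pe = expert // num_experts_per_rank
        dest_pe_per_slot.append(pe)
        if pe not in unique_dest_pes:
            # A token is sent at most once to each destination rank.
            unique_dest_pes.append(pe)
    return dest_pe_per_slot, unique_dest_pes


slots, sends = plan_sends([5, -1, 6, 1], num_experts_per_rank=4, world_size=2)
print(slots, sends)
```

In this toy input, experts 5 and 6 both live on rank 1, so only one send to rank 1 is planned, and the `-1` slot maps to the null PE.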

InterNodeV1 (internode_v1.cpp) applies the same sentinel rules on the internode send path. In replay mode, it reuses the maps recorded by an earlier mode-1 dispatch instead of rebuilding layout.

Coverage lives in test_dispatch_combine_minus_one_sentinel (IntraNode only). Test data generation accepts a configurable sentinel_pattern in dispatch_combine_test_utils.py, and the combine reference only counts experts with idx >= 0 when computing unique PE fan-in.
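
As a rough illustration of the reference rule, unique PE fan-in with sentinels ignored could be computed as below. The helper name `unique_pe_fanin` is hypothetical; the actual reference code in dispatch_combine_test_utils.py is not reproduced here.

```python
from typing import Sequence


def unique_pe_fanin(expert_ids: Sequence[int], num_experts_per_rank: int) -> int:
    """Count distinct destination PEs for one token, ignoring sentinel slots."""
    # Only experts with idx >= 0 contribute a destination PE.
    pes = {e // num_experts_per_rank for e in expert_ids if e >= 0}
    return len(pes)


print(unique_pe_fanin([0, 3, -1, 9], num_experts_per_rank=4))
```

A token whose top-k slots are all `-1` has a fan-in of zero under this rule.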

DeepEP-style routing handle

This PR introduces EpDispatchRoutingHandle in Python and EpDispatchCombineRoutingPtrs in C++. Both wrap caller-owned tensors—disp_dest_tok_id_map, internode routing maps, total_recv_token_num, and disp_tok_id_to_src_tok_id_local—that capture where each top-k slot was placed for a given forward pass.

In mode-1, the caller passes return_routing=True on dispatch(). Kernels run with replayMode=false, perform slot assignment and map population, and then snapshot_disp_tok_id_to_src_tok_id_local copies the symmetric inverse map into the handle so later no-P2P combine can read a stable local view.

In mode-2, the caller passes routing=R on a subsequent dispatch(). Kernels run with replayMode=true, reuse the frozen maps from mode-1, and may ship a new activation payload while expert indices and layout stay unchanged. The same handle is passed to combine(..., routing=R), which reads those snapshot maps rather than relying solely on op-owned symmetric state.

The Python API exposes routing and return_routing as keyword-only arguments on dispatch() and combine(), and rejects passing both at once. This path is implemented for IntraNode and InterNodeV1 only. Pybind adds build_args_with_routing and snapshot_disp_tok_id_to_src_tok_id_local, and threads routing pointers through GetEpDispatchCombineArgsRaw.
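
The record-then-replay flow can be modeled with a small single-process toy. Everything here (`ToyDispatcher`, `ToyRoutingHandle`, the dict-of-rows layout) is a hypothetical stand-in and not the MORI API; the sketch only mirrors the keyword-only `routing` / `return_routing` arguments and the rule that they cannot be combined. Unlike the real kernels, it does not deduplicate sends per destination rank.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass
class ToyRoutingHandle:
    """Hypothetical routing snapshot; not EpDispatchRoutingHandle."""
    slot_to_dest: List[Tuple[int, int]]  # per slot: (dest_pe, dest_row), (-1, -1) if dropped
    num_tokens: int


class ToyDispatcher:
    """Single-process model: 'ranks' are dict keys, no communication happens."""

    def __init__(self, world_size: int, num_experts_per_rank: int) -> None:
        self.world_size = world_size
        self.num_experts_per_rank = num_experts_per_rank

    def _assign_slots(self, expert_ids: Sequence[Sequence[int]]) -> ToyRoutingHandle:
        # First call only: decide which row on which destination each slot gets.
        next_row = [0] * self.world_size
        slot_to_dest: List[Tuple[int, int]] = []
        for token_experts in expert_ids:
            for expert in token_experts:
                if expert < 0:  # sentinel slot, nothing is sent
                    slot_to_dest.append((-1, -1))
                    continue
                pe = expert // self.num_experts_per_rank
                slot_to_dest.append((pe, next_row[pe]))
                next_row[pe] += 1
        return ToyRoutingHandle(slot_to_dest, len(expert_ids))

    def dispatch(self, payload: Sequence[float],
                 expert_ids: Optional[Sequence[Sequence[int]]] = None, *,
                 routing: Optional[ToyRoutingHandle] = None,
                 return_routing: bool = False):
        if routing is not None and return_routing:
            raise ValueError("pass either routing= (replay) or return_routing=True, not both")
        # Replay reuses the frozen layout; otherwise a new layout is computed.
        handle = routing if routing is not None else self._assign_slots(expert_ids)
        top_k = len(handle.slot_to_dest) // handle.num_tokens
        received: Dict[int, Dict[int, float]] = {pe: {} for pe in range(self.world_size)}
        for slot, (pe, row) in enumerate(handle.slot_to_dest):
            if pe >= 0:
                received[pe][row] = payload[slot // top_k]
        return (received, handle) if return_routing else received


op = ToyDispatcher(world_size=2, num_experts_per_rank=2)
fwd, handle = op.dispatch([1.0, 2.0], [[0, 3], [2, -1]], return_routing=True)
bwd = op.dispatch([10.0, 20.0], routing=handle)  # new payload, same layout
```

The second call ships different values but lands them in exactly the rows chosen by the first call, which is the property the backward pass relies on.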

Submission Checklist

sudhu2k added 6 commits May 23, 2026 23:59
…ions

- Implement logic to skip processing for negative expert IDs in both intra-node and inter-node dispatch functions.
- Update the test suite to support various patterns of sentinel injection, ensuring that dispatched and combined outputs correctly handle these cases.
- Introduce new parameters in test data generation to control sentinel behavior, improving test coverage for edge cases.
- Introduce `EpDispatchRoutingHandle` class to manage per-call routing snapshots for mode-1 and mode-2 dispatch.
- Enhance `EpDispatchCombineArgsRaw` to support routing-related pointers and replay mode.
- Update dispatch and combine operations to utilize the new routing handle, allowing for improved routing management and replay correctness.
- Add tests to validate the functionality of the routing handle, ensuring compatibility with existing dispatch mechanisms and correctness in multi-layer scenarios.
- Simplify test data generation documentation in `dispatch_combine_test_utils.py`.
- Remove deprecated test for TP-replicated routing with sentinels in `test_dispatch_combine_intranode.py`.
- Rename functions for clarity in `test_dispatch_combine_routing_handle.py`, ensuring consistency in naming conventions.
- Update assertions to reflect changes in routing handle comparisons, enhancing test reliability.
@sudhu2k sudhu2k marked this pull request as ready for review May 25, 2026 05:48
@jhchouuu jhchouuu requested a review from Copilot May 25, 2026 09:43
@jhchouuu jhchouuu requested a review from isytwu May 25, 2026 09:43

Copilot AI left a comment


Pull request overview

This PR extends MORI’s MoE dispatch/combine to support training-oriented cached routing (“mode-1/mode-2 replay”) and DeepEP-style -1 sentinel experts for the IntraNode and InterNodeV1 kernel types. It introduces a caller-owned routing handle so later dispatches/combines can deterministically reuse a previously-recorded layout without re-running slot assignment.

Changes:

  • Add Python/C++ routing-handle plumbing (EpDispatchRoutingHandle / EpDispatchCombineRoutingPtrs) and a replay-mode args builder for cached dispatch/combine.
  • Update IntraNode + InterNodeV1 kernels to (a) skip -1 experts and (b) replay cached routing maps in mode-2, plus avoid resetting caller-owned routing tensors.
  • Add new Python tests for routing-handle parity/replay and add -1 sentinel generation + reference updates (IntraNode sentinel coverage).

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| tests/python/ops/test_dispatch_combine_routing_handle.py | New routing-handle parity/replay/multi-layer tests (currently configured as single-node InterNodeV1). |
| tests/python/ops/test_dispatch_combine_intranode.py | Adds IntraNode-only -1 sentinel dispatch/combine test coverage. |
| tests/python/ops/dispatch_combine_test_utils.py | Adds sentinel pattern generation and updates combine reference to ignore -1 experts. |
| src/pybind/pybind_ops.cpp | Adds build_args_with_routing and snapshot helper for routing-handle support. |
| src/ops/dispatch_combine/intranode.hpp | Adds replay mode + -1 sentinel behavior and routing-handle local-map selection/reset changes. |
| src/ops/dispatch_combine/internode_v1.cpp | Adds replay mode + -1 sentinel behavior and avoids resetting caller-owned routing tensors. |
| src/ops/dispatch_combine/dispatch_combine.cpp | Adds routing-aware overload of GetEpDispatchCombineArgsRaw. |
| python/mori/ops/dispatch_combine.py | Adds EpDispatchRoutingHandle, dispatch(..., routing/return_routing), and combine(..., routing=...). |
| include/mori/ops/dispatch_combine/dispatch_combine.hpp | Adds EpDispatchCombineRoutingPtrs and replay-mode fields to args structs. |
Comments suppressed due to low confidence (1)

python/mori/ops/dispatch_combine.py:605

  • dispatch(..., routing=R) (mode-2 replay) does not validate that the replay call matches the mode-1 snapshot (e.g., input.size(0) equals R.cur_rank_num_token, and routing tensors are on the same device). A mismatch can lead to incorrect reads/writes since kernels iterate based on num_tokens=input.size(0) but interpret routing maps produced for a different token count. Consider adding a fast Python-side check for is_replay to ensure the token count matches the handle and that routing tensors are CUDA tensors on the current device.
    ):
        if routing is not None and return_routing:
            raise ValueError(
                "pass either `routing=` (replay) or `return_routing=True` "
                "(new layout), not both"
            )
        use_routing_handle = routing is not None or return_routing
        is_replay = routing is not None
        if use_routing_handle and not self._supports_routing_handle():
            raise NotImplementedError(
                f"routing handle path not supported for kernel_type="
                f"{self.config.kernel_type}; only IntraNode and InterNodeV1 "
                "expose mode-1/mode-2 dispatch."
            )
        if return_routing:
            routing = self._alloc_routing_handle()

        hidden_dim = input.size(1)
        weight_ptr = weights.data_ptr() if weights is not None else 0
        has_scales = scales is not None and self.config.scale_dim > 0
        scale_ptr = scales.data_ptr() if has_scales else 0
        actual_bn, actual_rbn, actual_wpb = self._resolve_launch_params(
            block_num,
            rdma_block_num,
            warp_per_block,
            num_tokens=input.size(0),
            hidden_dim=hidden_dim,
            dtype=input.dtype,
            tuning_rules=self._dispatch_rules,
        )
        self._cached_dispatch_launch = (actual_bn, actual_rbn, actual_wpb)
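
The kind of fast Python-side check suggested in this comment might look like the sketch below. It is illustrative only: `SnapshotInfo` and `check_replay_inputs` are hypothetical names, devices are modeled as plain strings to keep the example free of torch, and only the `cur_rank_num_token` field name comes from the comment itself.

```python
from dataclasses import dataclass


@dataclass
class SnapshotInfo:
    """Hypothetical summary of a routing snapshot; field layout is illustrative."""
    cur_rank_num_token: int
    device: str  # e.g. "cuda:0"


def check_replay_inputs(num_input_tokens: int, input_device: str,
                        snap: SnapshotInfo) -> None:
    """Raise before launch if a replay call does not match the recorded snapshot."""
    if num_input_tokens != snap.cur_rank_num_token:
        raise ValueError(
            f"replay token count {num_input_tokens} != snapshot token count "
            f"{snap.cur_rank_num_token}"
        )
    if input_device != snap.device:
        raise ValueError(
            f"replay input is on {input_device}, snapshot tensors are on {snap.device}"
        )


snap = SnapshotInfo(cur_rank_num_token=32, device="cuda:0")
check_replay_inputs(32, "cuda:0", snap)  # matching call passes silently
```

Failing early in Python keeps a mismatched replay from reaching kernels that would otherwise index routing maps built for a different token count.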


index_t* totalRecvTokenNum{nullptr};
index_t* dispTokIdToSrcTokIdLocal{nullptr};

bool IsValid() const { return dispDestTokIdMap != nullptr; }
Comment on lines 846 to 857
          is_zero_copy = not actual_use_ext
+         cur_n = (
+             routing.cur_rank_num_token
+             if routing is not None
+             else self._get_cur_rank_num_token(self._handle)
+         )
          actual_bn, actual_rbn, actual_wpb = self._resolve_launch_params(
              block_num,
              rdma_block_num,
              warp_per_block,
-             num_tokens=self._get_cur_rank_num_token(self._handle),
+             num_tokens=cur_n,
              hidden_dim=hidden_dim,
Comment on lines +51 to +68
def _make_internode_v1_config(rank, world_size):
    return mori.ops.EpDispatchCombineConfig(
        data_type=torch.bfloat16,
        rank=rank,
        world_size=world_size,
        hidden_dim=4096,
        scale_dim=0,
        scale_type_size=1,
        max_num_inp_token_per_rank=32,
        num_experts_per_rank=4,
        num_experts_per_token=4,
        max_token_type_size=2,
        block_num=96,
        rdma_block_num=64,
        warp_num_per_block=8,
        kernel_type=mori.ops.EpDispatchCombineKernelType.InterNodeV1,
        gpu_per_node=world_size,
    )
Comment on lines 111 to 130
  for (int i = warpId; i < (endTokenIdx - startTokenIdx) * config.numExpertPerToken; i += warpNum) {
    index_t tokenId = i / config.numExpertPerToken + startTokenIdx;
-   index_t destPe =
-       args.tokenIndices[startTokenIdx * config.numExpertPerToken + i] / config.numExpertPerRank;
+   index_t expertOffset = startTokenIdx * config.numExpertPerToken + i;
+   index_t destExpert = args.tokenIndices[expertOffset];
+   if (destExpert < 0) {
+     if (!args.replayMode && laneId == 0)
+       args.dispDestTokIdMap[expertOffset] = NullFlatTokenIndex(config);
+     continue;
+   }
+   index_t destPe = destExpert / config.numExpertPerRank;
    int destNode = destPe / config.gpuPerNode;

    int lanePe = -1, laneNode = -1;
    if (laneId < numExpertPerToken) {
-     lanePe = (args.tokenIndices[tokenId * numExpertPerToken + laneId] / config.numExpertPerRank);
+     index_t laneExpert = args.tokenIndices[tokenId * numExpertPerToken + laneId];
+     // Sentinel lanes get a unique impossible destPe so dedup cannot false-match.
+     lanePe = (laneExpert < 0) ? (-1 - static_cast<int>(laneId))
+                               : (laneExpert / config.numExpertPerRank);
      laneNode = lanePe / config.gpuPerNode;
    };
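
The sentinel-lane comment in the hunk above can be illustrated with a small Python model. This is a sketch under an assumed dedup rule (a lane counts as a duplicate when an earlier lane carries the same destination PE); the kernel's actual warp-level comparison is not shown in the hunk, and `lane_pes` and `is_duplicate` are hypothetical helpers.

```python
from typing import List


def lane_pes(expert_ids: List[int], num_experts_per_rank: int) -> List[int]:
    """Per-lane destination PE; each sentinel lane gets its own negative value."""
    return [(-1 - lane) if e < 0 else e // num_experts_per_rank
            for lane, e in enumerate(expert_ids)]


def is_duplicate(pes: List[int], lane: int) -> bool:
    """Assumed dedup rule: an earlier lane already targets the same PE."""
    return pes[lane] in pes[:lane]


pes = lane_pes([-1, 4, -1, 5], num_experts_per_rank=4)
dups = [is_duplicate(pes, lane) for lane in range(len(pes))]
print(pes, dups)
```

Because lane 0 and lane 2 receive different negative values, the two sentinel lanes never compare equal to each other, while the genuine duplicate (experts 4 and 5 on the same rank) is still detected.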
sudhu2k added 5 commits May 25, 2026 23:07
- Implement internode `test_sentinel_dispatch_combine` to validate the handling of sentinels in dispatch operations.
- Update command options to include `test_sentinel` for executing sentinel tests.
- Enhance routing pointer validation in `EpDispatchCombineRoutingPtrs` to ensure all required pointers are non-null.
- Add tests for DeepEP-style -1 sentinel handling in `test_dispatch_combine_internode_v1.py`, ensuring robustness in multi-node scenarios.
Collaborator

@isytwu isytwu left a comment


Thanks for working on this PR. I left a couple of minor comments. Overall this looks good from my side once those are addressed.

launch_local_expert_count(
self._cpp_config,
outI_ptr,
total_ptr,

Collaborator

Should this pass routing.total_recv_token_num.data_ptr() when use_routing_handle is true? Otherwise call_local_expert_count still reads handle.totalRecvTokenNum.

Author

You're right, I've fixed it. Thanks!

Comment thread src/ops/dispatch_combine/intranode.hpp Outdated
continue;
}
if (!args.replayMode) {
// Mode-1 (DeepEP-style cached-mode "first call"): decide where this

Collaborator

Could we describe this as cached/replay routing instead of Mode-1 / Mode-2 / DeepEP-style?

Author

Changed it!

@sudhu2k sudhu2k requested a review from isytwu June 1, 2026 15:25
if (laneId == 0) args.dispDestTokIdMap[i] = FlatTokenIndex(config, config.worldSize, 0);
continue;
}
if (!args.replayMode) {

Collaborator

Thanks for addressing the earlier comments. One more performance-related point: to avoid any possible regression on the default non-replay path, would it make sense to specialize this as a static/template variant instead?

Author

replayMode is a plain scalar that is uniform across the entire launch, so it should not cause any warp divergence.
I ran a benchmark using the following command:
python3 -m tests.python.ops.bench_dispatch_combine --cmd bench --world-size 8 --max-tokens 4096 --hidden-dim 7168 --dtype bf16

The two runs were essentially identical; the replayMode change shows no measurable regression.

| Phase | main | curr branch | Δ (run 2) | Δ (run 1) |
| --- | --- | --- | --- | --- |
| Dispatch | 1035.8 µs | 1032.1 µs | −3.7 µs (−0.36%) | +1.5 µs (+0.15%) |
| Combine | 935.0 µs | 932.8 µs | −2.2 µs (−0.23%) | −3.6 µs (−0.39%) |
| E2E | ~1946.1 µs | ~1941.6 µs | −4.5 µs (−0.23%) | ~flat |

replayMode is a launch-uniform, perfectly-predicted branch, so it costs effectively nothing on the hot path. The data confirms there's no need to template-specialize it for performance. Let me know what you think!

Collaborator

Could you also check a small-token case, e.g. --max-tokens 1 or a few low-token settings?

Author

Seems like noise to me.

1 token

| kernel | curr (µs) | main (µs) | Δ | Δ% |
| --- | --- | --- | --- | --- |
| Dispatch | 36.0 | 37.6 | -1.6 | -4.3 |
| Combine | 20.9 | 21.2 | -0.3 | -1.4 |
| E2E | 32.2 | 30.6 | +1.6 | +5.2 |

8 tokens

| kernel | curr (µs) | main (µs) | Δ | Δ% |
| --- | --- | --- | --- | --- |
| Dispatch | 39.1 | 38.4 | +0.7 | +1.8 |
| Combine | 27.9 | 27.5 | +0.4 | +1.5 |
| E2E | 41.7 | 41.4 | +0.3 | +0.7 |

64 tokens

| kernel | curr (µs) | main (µs) | Δ | Δ% |
| --- | --- | --- | --- | --- |
| Dispatch | 44.3 | 43.0 | +1.3 | +3.0 |
| Combine | 35.1 | 35.1 | 0.0 | 0.0 |
| E2E | 55.3 | 54.9 | +0.4 | +0.7 |

512 tokens

| kernel | curr (µs) | main (µs) | Δ | Δ% |
| --- | --- | --- | --- | --- |
| Dispatch | 161.9 | 160.5 | +1.4 | +0.9 |
| Combine | 137.2 | 135.7 | +1.5 | +1.1 |
| E2E | 274.6 | 272.4 | +2.2 | +0.8 |
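
For anyone re-deriving the Δ and Δ% columns, the arithmetic is curr minus main, divided by main for the percentage. A minimal sketch using the 512-token rows reported above (the `delta` helper is a hypothetical name, not part of the benchmark script):

```python
from typing import Tuple


def delta(curr_us: float, main_us: float) -> Tuple[float, float]:
    """Return (absolute delta in µs, relative delta in percent), rounded to 0.1."""
    d = curr_us - main_us
    return round(d, 1), round(100.0 * d / main_us, 1)


rows = {
    "Dispatch": (161.9, 160.5),
    "Combine": (137.2, 135.7),
    "E2E": (274.6, 272.4),
}
for name, (curr, main) in rows.items():
    print(name, delta(curr, main))
```

Positive values mean the current branch is slower than main for that phase.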

@sudhu2k sudhu2k requested a review from isytwu June 4, 2026 17:15

3 participants