[Feat] Fuse TopKGatingSoftmax and MoE Sorting kernels #582

Open
amd-wsung102 wants to merge 14 commits into ROCm:main from amd-wsung102:fuse_topk_sorting_updated
Contributor

@amd-wsung102 amd-wsung102 commented May 28, 2026

Motivation

The kernels in topk_gating_softmax_kernel.py and moe_sorting_kernel.py can be fused, which improves performance in eager mode, in graph mode, and in raw kernel time.

Relevant Files

kernels/moe_sorting_kernel.py - added fused topk and sorting
kernels/topk_gating_softmax_kernel.py - topk gating softmax kernel
tests/kernels/test_moe_sorting.py - unit test for the fused kernels

Additional Details

The fusion applies only to the decode path in moe_sorting, and only when the number of tokens T is at most 16 (T <= 16). For T > 16 the fusion does not yet yield an improvement; this is still under investigation and can be addressed in a future PR.

Test Result - DeepSeek-R1: E=256, topk=8, model_dim=7168, bf16

All times are in microseconds (us).

  • Eager: 2-2.4x improvement
  • Graph: 1.12-1.14x improvement
  • Raw kernel: 1.3-1.4x improvement
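
| T | unfused_eager | fused_eager | unfused_graph | fused_graph | unfused_kernel | fused_kernel | eager speedup | graph speedup | kernel speedup |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 32 | 13.1 | 16.2 | 14.3 | 13.5 | 10.3 | 2.44 | 1.13 | 1.32 |
| 2 | 32.5 | 16 | 16.2 | 14.4 | 14 | 10.5 | 2.03 | 1.12 | 1.34 |
| 4 | 32.5 | 16.2 | 16.7 | 14.8 | 14.5 | 11 | 2 | 1.13 | 1.32 |
| 8 | 33.7 | 15.9 | 17.3 | 15.4 | 15 | 11.4 | 2.11 | 1.12 | 1.32 |
| 12 | 34 | 14.9 | 19.3 | 16.9 | 17.1 | 12.2 | 2.28 | 1.14 | 1.4 |
| 16 | 33.9 | 15.4 | 19.7 | 17.4 | 17.7 | 12.9 | 2.2 | 1.13 | 1.37 |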
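
For readers who want to check the speedup columns, here is a minimal sketch of the arithmetic, assuming each speedup is simply the unfused time divided by the fused time. The helper name `speedup` and the dictionaries are illustrative and not part of the PR. The inputs are the T=1 times from the DeepSeek-R1 table above; because those times are already rounded to one decimal, a recomputed value may differ from the reported one in the last digit.

```python
# Times in microseconds for T=1, copied from the DeepSeek-R1 table above.
unfused = {"eager": 32.0, "graph": 16.2, "kernel": 13.5}
fused = {"eager": 13.1, "graph": 14.3, "kernel": 10.3}


def speedup(unfused_us: float, fused_us: float) -> float:
    """Assumed definition: unfused time divided by fused time."""
    return unfused_us / fused_us


# Round to two decimals, matching the precision used in the table.
speedups = {
    mode: round(speedup(unfused[mode], fused[mode]), 2) for mode in unfused
}
print(speedups)
```

The eager and graph values reproduce the reported 2.44 and 1.13. The kernel value lands within 0.01 of the reported 1.32, which is consistent with the table having been computed from unrounded timings.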

Test Result - GPT-OSS 120B: E=128, topk=4, model_dim=2880, bf16

All times are in microseconds (us).

  • Eager: 2.13-2.24x improvement
  • Graph: 1.21-1.27x improvement
  • Raw kernel: 1.38-1.44x improvement

| T | unfused_eager | fused_eager | unfused_graph | fused_graph | unfused_kernel | fused_kernel | eager speedup | graph speedup | kernel speedup |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 31.2 | 14.4 | 13.3 | 10.9 | 9.9 | 7.1 | 2.16 | 1.22 | 1.39 |
| 2 | 31.4 | 14.6 | 13.5 | 11.1 | 10.2 | 7.3 | 2.15 | 1.21 | 1.39 |
| 4 | 31.7 | 14.8 | 13.8 | 11.3 | 10.4 | 7.5 | 2.14 | 1.21 | 1.38 |
| 8 | 31.6 | 14.8 | 13.9 | 11.4 | 10.5 | 7.6 | 2.13 | 1.21 | 1.38 |
| 12 | 32.8 | 14.6 | 15.6 | 12.3 | 12.2 | 8.5 | 2.24 | 1.27 | 1.44 |
| 16 | 33.4 | 14.9 | 15.8 | 12.5 | 12.3 | 8.7 | 2.24 | 1.27 | 1.42 |

Submission Checklist

Comment thread kernels/moe_sorting_kernel.py Outdated


@contextmanager
def _if_then(if_op):
Collaborator

Why is this scf.if needed?

Contributor Author

The fused kernel needs the explicit scf.IfOp and the _if_then helper because the kernel JIT fails when a plain Python if is used:

File ".../kernels/moe_sorting_kernel.py", line 1242, in __then_1
for _z in range(_zs, _ze, _z1):
TypeError: 'ArithValue' object cannot be interpreted as an integer
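
The message in the traceback is the one Python's built-in range() raises when an argument cannot be converted to an integer. Below is a minimal sketch of that mechanism only. `FakeArithValue` is a hypothetical stand-in, not the real ArithValue class, and the assumption that the real class lacks `__index__` is inferred from the error message rather than confirmed from the library source.

```python
class FakeArithValue:
    """Hypothetical stand-in for a traced value (not the real ArithValue)."""

    def __init__(self, name):
        self.name = name

    # No __index__ method is defined, so Python cannot treat
    # an instance of this class as an integer.


def plain_python_loop(start, stop, step):
    # range() converts each argument to an integer when it is called,
    # before the loop body ever runs.
    return [z for z in range(start, stop, step)]


try:
    plain_python_loop(
        FakeArithValue("_zs"), FakeArithValue("_ze"), FakeArithValue("_z1")
    )
    message = None
except TypeError as exc:
    message = str(exc)

print(message)
```

With ordinary integers the same function runs normally, which suggests the failure comes from a Python-level loop receiving a traced value instead of a concrete bound.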

Collaborator

I've fixed a similar bug before, but I'm not sure whether this is the same issue. Could you reduce the failing test case to a minimal example so that I can reproduce it?

Contributor Author

@amd-wsung102 amd-wsung102 May 29, 2026

Hi @xudoyuan, you can use commit 5627ef2, which uses a regular Python if instead of scf.if. Then run pytest tests/kernels/test_moe_sorting.py::test_moe_softmax_sort_fused_oneshot -k "1-256-8-bf16".

It fails with:

FAILED tests/kernels/test_moe_sorting.py::test_moe_softmax_sort_fused_oneshot[1-256-8-bf16] - TypeError: 'ArithValue' object cannot be interpreted as an integer

After the regular Python if is replaced with scf.if, as in commit 07ea93d, pytest no longer reports the error.

Collaborator

Okay, let me check.

Collaborator

Hi,

Please take a look at the bugfix PR (#601). It had not been merged earlier for some reason, but once it is merged you will no longer need scf.if.
Also, kernel-internal 'device' functions such as _emit_topk_gating_softmax_body, which are defined outside of @flyc.kernel and contain control flow (if/for, etc.), still need to be decorated with @flyc.jit.
After merging that PR and adding @flyc.jit, you can try dropping scf.if; I've verified locally that it works.

Thanks.

Contributor Author

Got it, thank you Xudong! I will keep an eye on the status of PR #601 and make the changes you suggested.

Contributor Author

Hi @xudoyuan @coderfeli, I have addressed the issues above in commit 2538368. The commit replaces scf.if with a plain Python if and adds @flyc.jit to device functions such as _emit_topk_gating_softmax_body.

Both the unit tests and CI pass, and the performance improvement matches the data and tables in the PR description.
