[Feat] Fuse TopKGatingSoftmax and MoE Sorting kernels #582
amd-wsung102 wants to merge 14 commits
@contextmanager
def _if_then(if_op):
The fused kernel needs the explicit scf.IfOp and the _if_then helper function because, when a plain Python if is used, the kernel JIT fails:
File ".../kernels/moe_sorting_kernel.py", line 1242, in __then_1
for _z in range(_zs, _ze, _z1):
TypeError: 'ArithValue' object cannot be interpreted as an integer
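The error message itself comes from Python's built-in range(), which accepts only objects that can be converted to an integer. The sketch below reproduces the same message outside the JIT. The ArithValue class here is a hypothetical stand-in for a traced value and is not the real class from the project; it only assumes that the real one does not behave like a Python integer.

```python
class ArithValue:
    """Hypothetical stand-in for a traced (symbolic) value; not the project's class."""

    def __init__(self, name):
        self.name = name


def try_range(start, stop, step):
    """Return the expanded range, or the TypeError message if range() rejects the inputs."""
    try:
        return list(range(start, stop, step))
    except TypeError as exc:
        return str(exc)


# Concrete Python ints work as loop bounds.
print(try_range(0, 4, 1))

# Objects that cannot be interpreted as integers are rejected by range().
print(try_range(ArithValue("_zs"), ArithValue("_ze"), ArithValue("_z1")))
```

The second call returns the same "cannot be interpreted as an integer" message seen in the traceback, which suggests that the loop bounds reached range() as traced values rather than as Python integers.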
I've fixed a similar bug before, but I'm not sure whether it's the same issue. Could you please simplify the failing test case so that I can reproduce it?
Hi @xudoyuan, you can use commit 5627ef2, which uses a regular Python if instead of scf if. Then run pytest tests/kernels/test_moe_sorting.py::test_moe_softmax_sort_fused_oneshot -k "1-256-8-bf16".
It fails with:
FAILED tests/kernels/test_moe_sorting.py::test_moe_softmax_sort_fused_oneshot[1-256-8-bf16] - TypeError: 'ArithValue' object cannot be interpreted as an integer
After the regular Python if is switched to scf if, as in commit 07ea93d, pytest no longer reports the error.
okay, let me check
File ".../kernels/moe_sorting_kernel.py", line 1242, in __then_1
for _z in range(_zs, _ze, _z1):
TypeError: 'ArithValue' object cannot be interpreted as an integer
Hi,
Please take a look at this bugfix PR (#601). It wasn't merged earlier for some reason, but once it is merged you won't need to use scf.if anymore.
Also, for kernel-internal 'device' functions like _emit_topk_gating_softmax_body, which appear outside of @flyc.kernel and contain control flow syntax (if/for, etc.), you still need to use @flyc.jit to decorate these functions.
After merging the PR and adding @flyc.jit, you can try not using scf.if; I've verified it locally and it works.
Thanks.
Got it, thank you Xudong! I will keep an eye out for the status of PR 601, and make the appropriate changes that you suggested.
Hi @xudoyuan @coderfeli, I have addressed the above issues in commit 2538368. The commit replaces scf.if with a plain Python if and adds @flyc.jit to device functions such as _emit_topk_gating_softmax_body.
Both the unit tests and CI pass, and the performance improvement matches the data and tables in the PR description.
…o topkgating for better clarity
Motivation
The topk_gating_softmax_kernel.py kernel and moe_sorting_kernel.py kernel can be fused for improved performance across eager mode, graph mode, and raw kernel time.
Relevant Files
kernels/moe_sorting_kernel.py - added fused topk and sorting
kernels/topk_gating_softmax_kernel.py - topk gating softmax kernel
tests/kernels/test_moe_sorting.py - unit test for the fused kernels
Additional Details
The fusion applies to the decode path in moe_sorting, and only when the number of tokens T is at most 16 (T <= 16). For T > 16, the fusion doesn't yield improvements; this is still under investigation and can be addressed in a future PR.
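For orientation, the two stages being fused can be sketched as a plain-Python reference: a softmax over each token's gating logits with top-k selection, followed by grouping the selected (token, expert) pairs by expert id. This is only an illustrative sketch. The function names are hypothetical, the selected weights are assumed not to be renormalized, and the real kernels' output layout (padding, block alignment, dtypes) is not described here and is not modeled.

```python
import math


def topk_gating_softmax(logits, topk):
    """For each token, softmax over experts and keep the topk (expert, weight) pairs."""
    selected = []
    for row in logits:
        row_max = max(row)  # subtract the max for numerical stability
        exps = [math.exp(x - row_max) for x in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Highest probability first; ties go to the lower expert id (an assumption).
        order = sorted(range(len(row)), key=lambda e: (-probs[e], e))[:topk]
        selected.append([(e, probs[e]) for e in order])
    return selected


def sort_by_expert(selected):
    """Flatten to (token, expert, weight) and stable-sort by expert id."""
    flat = [(tok, e, w) for tok, picks in enumerate(selected) for e, w in picks]
    return sorted(flat, key=lambda item: item[1])


def fused_reference(logits, topk):
    """Both stages back to back, as a single call."""
    return sort_by_expert(topk_gating_softmax(logits, topk))


# Two tokens, four experts, topk=2.
logits = [[0.0, 2.0, 1.0, -1.0], [3.0, 0.0, 0.0, 1.0]]
result = fused_reference(logits, topk=2)
for token, expert, weight in result:
    print(token, expert, round(weight, 4))
```

A reference like this can be handy when comparing a fused kernel against the two separate kernels on small inputs such as the T <= 16 decode case.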
Test Result - DeepSeek-R1: E=256, topk=8, model_dim=7168, bf16
All times are in microseconds (us)
Test Result - GPT-OSS 120B: E=128, topk=4, model_dim=2880, bf16
All times are in microseconds (us)