Add INT8, INT16, and UINT8 type support for CUDA TopK operator#27862

Merged
tianleiwu merged 3 commits into microsoft:main from elwhyjay:add-topk-int8-int16-uint8-cuda-support on Apr 19, 2026

Conversation

@elwhyjay (Contributor) commented Mar 26, 2026

Add type constraints and dispatch cases for int8_t, int16_t, and uint8_t in the CUDA TopK kernel (opset 1-9, 10, 11-23, 24), along with three new .cu template instantiation files. This is the CUDA counterpart to the CPU support added in #27860.

Fixes #27859

Description

Add CUDA kernel type dispatch and template specializations for int8_t, int16_t, and uint8_t types in the CUDA TopK operator

Changed files:

  • onnxruntime/core/providers/cuda/math/topk.cc — type constraints + dispatch cases for int8/int16/uint8
  • onnxruntime/core/providers/cuda/math/topk_impl_i8.cu — new template instantiation for int8_t
  • onnxruntime/core/providers/cuda/math/topk_impl_u8.cu — new template instantiation for uint8_t
  • onnxruntime/core/providers/cuda/math/topk_impl_i16.cu — new template instantiation for int16_t

Motivation and Context

This is the CUDA counterpart to #27860 (CPU TopK INT8/INT16/UINT8 support).

The ONNX specification (opset 11+) lists INT8, INT16, and UINT8 as valid input types for the TopK operator. After #27860 added CPU support, the CUDA execution provider still lacked kernels for these types, causing models to fall back to CPU or fail when using CUDAExecutionProvider.

The existing CUDA TopK implementation uses a split-compilation pattern (one .cu file per type) with ToCudaType<T> mapping. Since the default template maps integer types to themselves and NumericLimits<T> uses std::numeric_limits<T>, no algorithmic changes were needed — only:

  1. Adding type constraints to kernel registrations (all opset versions)
  2. Adding dispatch cases in ComputeInternal
  3. Creating three new .cu files for template instantiation

All 64 TopK tests pass (including 8 tests for the new types, running on both CPU and CUDA providers).

Test Results

```
[==========] Running 64 tests from 1 test suite.
...
[ RUN      ] TopKOperator.TopK_Int8
[       OK ] TopKOperator.TopK_Int8 (21 ms)
[ RUN      ] TopKOperator.TopK_Int8_Negative
[       OK ] TopKOperator.TopK_Int8_Negative (21 ms)
[ RUN      ] TopKOperator.TopK_Int8_Smallest
[       OK ] TopKOperator.TopK_Int8_Smallest (21 ms)
[ RUN      ] TopKOperator.TopK_Int16
[       OK ] TopKOperator.TopK_Int16 (21 ms)
[ RUN      ] TopKOperator.TopK_Uint8
[       OK ] TopKOperator.TopK_Uint8 (21 ms)
[ RUN      ] TopKOperator.TopK_Int8_ExplicitAxis
[       OK ] TopKOperator.TopK_Int8_ExplicitAxis (21 ms)
[ RUN      ] TopKOperator.TopK_Int8_Opset24
[       OK ] TopKOperator.TopK_Int8_Opset24 (21 ms)
[ RUN      ] TopKOperator.TopK_Uint8_Opset24
[       OK ] TopKOperator.TopK_Uint8_Opset24 (21 ms)
...
[  PASSED  ] 64 tests.
```

@tianleiwu (Contributor) commented Mar 26, 2026

Please also update docs/OperatorKernels.md. You may download the updated file in artifact of Windows GPU Doc Gen CI Pipeline job.

@tianleiwu (Contributor)

/azp run Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@tianleiwu (review comment on topk.cc): Remove these types from opset 1–10, as they are only supported starting from opset 11.
@azure-pipelines

Commenter does not have sufficient privileges for PR 27862 in repo microsoft/onnxruntime

@elwhyjay (Contributor, Author)

@tianleiwu Hi, I've addressed the review feedback in the latest commit (restricting int8/int16/uint8 TopK CUDA kernels to opset 11+). Could you please trigger the CI pipeline so I can download the updated OperatorKernels.md from the artifact? Thanks!

@tianleiwu (Contributor)

/azp run Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@elwhyjay (Contributor, Author)

@tianleiwu Done — updated docs/OperatorKernels.md with the artifact from the Windows GPU Doc Gen CI Pipeline. Thanks for the review!

@elwhyjay (Contributor, Author)

@tianleiwu Friendly ping — I've addressed all the feedback and updated OperatorKernels.md. The CPU counterpart (#27860) has already been merged. Is there anything else needed for this to move forward?

@tianleiwu (Contributor)

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Web CI Pipeline, ONNX Runtime WebGPU Builds

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@elwhyjay (Contributor, Author)

@tianleiwu Hi, it looks like the CI failure was due to a self-hosted runner losing connection — not related to code changes. Could you re-trigger the pipeline when you get a chance? Thanks!

@tianleiwu tianleiwu enabled auto-merge (squash) April 19, 2026 04:06
@tianleiwu (Contributor) left a comment
Thanks for tightening the opset gating after the earlier review. One remaining blocker is test coverage: this PR enables new CUDA TopK runtime paths, but currently only compilation is exercised. Please add TopK test cases that run the newly registered int8/int16/uint8 types on the CUDA provider before merging.

@tianleiwu (Contributor) left a comment

I saw some test cases in #2786.

@elwhyjay (Contributor, Author)

@tianleiwu Hi! Thanks for the review and approval!

@tianleiwu tianleiwu merged commit f018066 into microsoft:main Apr 19, 2026
126 of 136 checks passed
Development

Successfully merging this pull request may close these issues:

  • [Feature Request] Add INT8/INT16/UINT8 support for TopK operator (#27859)