Parallelize CPU ScatterElements kernel via ThreadPool #28588
Open
Copilot wants to merge 3 commits into
Add multi-threaded execution to both GetIndices() and ScatterData() in the CPU ScatterElements operator implementation.

Key changes:
- GetIndices: parallelize index validation and normalization using ThreadPool::TryParallelFor
- ScatterData: decompose work into independent units based on outer_size * inner_size (dimensions orthogonal to the scatter axis). Each work unit processes axis_size elements along the scatter axis. Work units are guaranteed to write to disjoint output elements, making parallelization safe even with reductions (add, mul, min, max).

For the reported workload (axis=0, data=[556416,80], indices=[481385,80]):
- inner_size=80 independent work units can run in parallel
- Each processes 481385 sequential scatter operations along axis 0
- This enables effective use of multi-core CPUs (24 threads in the issue)

The approach avoids write conflicts without locks by exploiting the mathematical property that different (outer, inner) coordinate pairs always map to different output element addresses regardless of the index values along the scatter axis.

Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/067e9420-d43f-4f95-ae1e-afba1da9c29a
Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
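The disjointness argument in the commit message can be checked with a little index arithmetic. The sketch below is illustrative only: `AxisSplit`, `SplitAroundAxis`, and `FlatOffset` are hypothetical names, not ONNX Runtime functions, and a row-major layout is assumed. Because the flat offset is `(outer * axis_size + k) * inner_size + inner`, the `inner` coordinate can be recovered as `offset % inner_size` and the `outer` coordinate as `offset / (axis_size * inner_size)`, so two work units with different (outer, inner) pairs can never produce the same offset, whatever index `k` they scatter to.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an ONNX Runtime API): sizes of the three index
// ranges obtained by splitting a row-major shape around the scatter axis.
struct AxisSplit {
  std::int64_t outer_size;  // product of the dimensions before the axis
  std::int64_t axis_size;   // extent of the axis itself
  std::int64_t inner_size;  // product of the dimensions after the axis
};

inline AxisSplit SplitAroundAxis(const std::vector<std::int64_t>& shape, std::size_t axis) {
  AxisSplit s{1, shape[axis], 1};
  for (std::size_t d = 0; d < axis; ++d) s.outer_size *= shape[d];
  for (std::size_t d = axis + 1; d < shape.size(); ++d) s.inner_size *= shape[d];
  return s;
}

// Row-major flat offset of element (outer, k, inner) for the given split.
inline std::int64_t FlatOffset(const AxisSplit& s, std::int64_t outer,
                               std::int64_t k, std::int64_t inner) {
  return (outer * s.axis_size + k) * s.inner_size + inner;
}
```

For the reported workload, `SplitAroundAxis({481385, 80}, 0)` gives `outer_size = 1`, `axis_size = 481385`, and `inner_size = 80`, which matches the 80 work units of 481385 sequential operations each.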
Copilot AI changed the title from "[WIP] Fix CPU performance of ScatterElements with reduction='add'" to "Parallelize CPU ScatterElements kernel via ThreadPool" on May 20, 2026
Pull request overview
This PR parallelizes the CPU ScatterElements implementation by leveraging concurrency::ThreadPool::TryParallelFor for both index normalization/validation and the scatter update loop, aiming to significantly improve intra-op CPU utilization and latency for large workloads.
Changes:
- Parallelized `GetIndices` over the flattened indices array using `ThreadPool::TryParallelFor`.
- Reworked `ScatterData` to parallelize over `outer_size * inner_size` independent work units, each processing the axis dimension sequentially.
- ThreadPool is plumbed from `OpKernelContext::GetOperatorThreadPool()` through the dispatch path; training `GatherElementsGradImpl` uses `nullptr` to preserve sequential fallback.
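The work-unit decomposition described in the list above could look roughly like the following. This is a minimal sketch, not the code in this PR: it uses plain `std::thread` in place of `ThreadPool::TryParallelFor`, handles only the `add` reduction on `float`, and assumes the indices are already normalized and that `indices`, `updates`, and the output share the same outer and inner extents. All function and parameter names are hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <thread>
#include <vector>

// Processes work units [first, last). A unit is one (outer, inner) coordinate
// pair; tensors are viewed as [outer, axis, inner] in row-major order.
// idx_axis is the axis extent of indices/updates, out_axis that of the output.
inline void ScatterAddUnits(std::vector<float>& out,
                            const std::vector<std::int64_t>& indices,
                            const std::vector<float>& updates,
                            std::int64_t outer, std::int64_t idx_axis,
                            std::int64_t out_axis, std::int64_t inner,
                            std::int64_t first, std::int64_t last) {
  (void)outer;  // only needed by the caller to count units
  for (std::int64_t unit = first; unit < last; ++unit) {
    const std::int64_t o = unit / inner;
    const std::int64_t i = unit % inner;
    // Walk the scatter axis sequentially so the reduction order within a
    // unit matches the single-threaded order.
    for (std::int64_t k = 0; k < idx_axis; ++k) {
      const std::int64_t src = (o * idx_axis + k) * inner + i;
      const std::int64_t dst = (o * out_axis + indices[src]) * inner + i;
      out[dst] += updates[src];
    }
  }
}

// Splits the outer * inner units into contiguous chunks, one thread per chunk.
// No locks are needed: distinct units never touch the same output element.
inline void ParallelScatterAdd(std::vector<float>& out,
                               const std::vector<std::int64_t>& indices,
                               const std::vector<float>& updates,
                               std::int64_t outer, std::int64_t idx_axis,
                               std::int64_t out_axis, std::int64_t inner,
                               unsigned num_threads) {
  const std::int64_t total_units = outer * inner;
  const std::int64_t n =
      std::max<std::int64_t>(1, std::min<std::int64_t>(num_threads, total_units));
  const std::int64_t chunk = (total_units + n - 1) / n;
  std::vector<std::thread> workers;
  for (std::int64_t first = 0; first < total_units; first += chunk) {
    const std::int64_t last = std::min(first + chunk, total_units);
    workers.emplace_back([&, first, last] {
      ScatterAddUnits(out, indices, updates, outer, idx_axis, out_axis, inner, first, last);
    });
  }
  for (auto& w : workers) w.join();
}
```

Keeping the axis loop sequential inside each unit is what makes duplicate indices safe: repeated writes to the same output element always come from the same thread, in the original order.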
Comments suppressed due to low confidence (1)
onnxruntime/core/providers/cpu/tensor/scatter.cc:362
`TryParallelFor` uses `static_cast<std::ptrdiff_t>(total_work_units)`. Similar to `num_indices`, this can overflow/truncate on 32-bit builds. Please use a checked conversion (`narrow`/`SafeInt`) before passing sizes to the threadpool APIs.
concurrency::ThreadPool::TryParallelFor(
tp, static_cast<std::ptrdiff_t>(total_work_units), static_cast<double>(axis_size),
[&](std::ptrdiff_t first, std::ptrdiff_t last) {
- Add missing #include <atomic> for self-contained includes
- Replace static_cast<std::ptrdiff_t> with narrow<std::ptrdiff_t> for checked conversion on 32-bit builds (both GetIndices and ScatterData)
- Add comment clarifying intentional nondeterministic error reporting when multiple indices are out-of-bounds
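To illustrate the two points in this commit, here is a hedged sketch of a checked size conversion and of range-based index normalization that reports errors through an atomic. It is an assumption about the general shape of such code, not a copy of the PR: `CheckedToPtrdiff` stands in for `narrow<std::ptrdiff_t>`, and `NormalizeRange` and `bad_pos` are hypothetical names.

```cpp
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a checked narrowing cast: throws instead of
// silently truncating when the value does not fit (relevant on 32-bit builds).
inline std::ptrdiff_t CheckedToPtrdiff(std::uint64_t value) {
  if (value > static_cast<std::uint64_t>(std::numeric_limits<std::ptrdiff_t>::max())) {
    throw std::overflow_error("size does not fit in std::ptrdiff_t");
  }
  return static_cast<std::ptrdiff_t>(value);
}

// Normalizes indices[first, last) in place: a negative index wraps by
// axis_size. On the first out-of-range value in this range, its flat position
// is stored in bad_pos and the range stops early. When several ranges run
// concurrently and more than one holds a bad index, whichever store lands
// last wins, so the reported position is not deterministic.
inline void NormalizeRange(std::vector<std::int64_t>& indices,
                           std::ptrdiff_t first, std::ptrdiff_t last,
                           std::int64_t axis_size,
                           std::atomic<std::ptrdiff_t>& bad_pos) {
  for (std::ptrdiff_t i = first; i < last; ++i) {
    const std::size_t pos = static_cast<std::size_t>(i);
    const std::int64_t raw = indices[pos];
    if (raw < -axis_size || raw >= axis_size) {
      bad_pos.store(i, std::memory_order_relaxed);
      return;
    }
    indices[pos] = raw < 0 ? raw + axis_size : raw;
  }
}
```

A caller would initialize `bad_pos` to `-1`, hand disjoint `[first, last)` ranges to the thread pool, and turn any non-negative `bad_pos` into an error status once all ranges have finished.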
tianleiwu approved these changes on May 20, 2026
Description
Parallelizes both `GetIndices` and `ScatterData` in the CPU `ScatterElements` implementation using `ThreadPool::TryParallelFor`.

Key insight: For ScatterElements with `axis=a`, work units identified by coordinates orthogonal to the axis (`outer_size × inner_size`) are guaranteed to write to disjoint output elements, even with reductions. This enables lock-free parallelization without correctness concerns.

Changes:
- `GetIndices`: Index validation/normalization parallelized over the flat index array
- `ScatterData`: Rewritten to decompose into `outer_size * inner_size` independent work units, each processing `axis_size` sequential scatter operations along the axis dimension
- ThreadPool plumbed through `ScatterDataDispatchTarget` from `OpKernelContext::GetOperatorThreadPool()`
- `GatherElementsGradImpl` passes `nullptr` (sequential fallback preserved)

For the reported workload (`axis=0`, indices shape `[481385, 80]`): 80 independent parallel streams, each processing 481385 elements, which is well-suited for multi-core execution.

Motivation and Context
The CPU `ScatterElements` kernel was entirely sequential (single-threaded index conversion followed by single-threaded scatter), yielding ~761ms on a 24-core ARM system for a workload that an optimized parallel implementation handles in ~6ms (129× gap). The kernel showed zero intra-op thread utilization in ORT profiling.