Skip to content

Parallelize CPU ScatterElements kernel via ThreadPool#28588

Open
Copilot wants to merge 3 commits into
mainfrom
copilot/fix-scatterelements-cpu-performance
Open

Parallelize CPU ScatterElements kernel via ThreadPool#28588
Copilot wants to merge 3 commits into
mainfrom
copilot/fix-scatterelements-cpu-performance

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented May 20, 2026

Description

Parallelizes both GetIndices and ScatterData in the CPU ScatterElements implementation using ThreadPool::TryParallelFor.

Key insight: For ScatterElements with axis=a, work units identified by coordinates orthogonal to the axis (outer_size × inner_size) are guaranteed to write to disjoint output elements—even with reductions. This enables lock-free parallelization without correctness concerns.

Changes:

  • GetIndices: Index validation/normalization parallelized over the flat index array
  • ScatterData: Rewritten to decompose into outer_size * inner_size independent work units, each processing axis_size sequential scatter operations along the axis dimension
  • Thread pool plumbed through ScatterDataDispatchTarget from OpKernelContext::GetOperatorThreadPool()
  • Training GatherElementsGradImpl passes nullptr (sequential fallback preserved)

For the reported workload (axis=0, indices shape [481385, 80]): 80 independent parallel streams, each processing 481385 elements—well-suited for multi-core execution.

Motivation and Context

The CPU ScatterElements kernel was entirely sequential—single-threaded index conversion followed by single-threaded scatter—yielding ~761ms on a 24-core ARM system for a workload that an optimized parallel implementation handles in ~6ms (129× gap). The kernel showed zero intra-op thread utilization in ORT profiling.

Add multi-threaded execution to both GetIndices() and ScatterData()
in the CPU ScatterElements operator implementation.

Key changes:
- GetIndices: parallelize index validation and normalization using
  ThreadPool::TryParallelFor
- ScatterData: decompose work into independent units based on
  outer_size * inner_size (dimensions orthogonal to the scatter axis).
  Each work unit processes axis_size elements along the scatter axis.
  Work units are guaranteed to write to disjoint output elements,
  making parallelization safe even with reductions (add, mul, min, max).

For the reported workload (axis=0, data=[556416,80], indices=[481385,80]):
- inner_size=80 independent work units can run in parallel
- Each processes 481385 sequential scatter operations along axis 0
- This enables effective use of multi-core CPUs (24 threads in the issue)

The approach avoids write conflicts without locks by exploiting the
mathematical property that different (outer, inner) coordinate pairs
always map to different output element addresses regardless of the
index values along the scatter axis.

Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/067e9420-d43f-4f95-ae1e-afba1da9c29a

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Copilot AI changed the title [WIP] Fix CPU performance of ScatterElements with reduction='add' Parallelize CPU ScatterElements kernel via ThreadPool May 20, 2026
Copilot AI requested a review from tianleiwu May 20, 2026 16:55
@tianleiwu tianleiwu requested a review from Copilot May 20, 2026 19:26
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR parallelizes the CPU ScatterElements implementation by leveraging concurrency::ThreadPool::TryParallelFor for both index normalization/validation and the scatter update loop, aiming to significantly improve intra-op CPU utilization and latency for large workloads.

Changes:

  • Parallelized GetIndices over the flattened indices array using ThreadPool::TryParallelFor.
  • Reworked ScatterData to parallelize over outer_size * inner_size independent work units, each processing the axis dimension sequentially.
  • ThreadPool is plumbed from OpKernelContext::GetOperatorThreadPool() through the dispatch path; training GatherElementsGradImpl uses nullptr to preserve sequential fallback.
Comments suppressed due to low confidence (1)

onnxruntime/core/providers/cpu/tensor/scatter.cc:362

  • TryParallelFor uses static_cast<std::ptrdiff_t>(total_work_units). Similar to num_indices, this can overflow/truncate on 32-bit builds. Please use a checked conversion (narrow/SafeInt) before passing sizes to the threadpool APIs.
  concurrency::ThreadPool::TryParallelFor(
      tp, static_cast<std::ptrdiff_t>(total_work_units), static_cast<double>(axis_size),
      [&](std::ptrdiff_t first, std::ptrdiff_t last) {

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread onnxruntime/core/providers/cpu/tensor/scatter.cc
Comment thread onnxruntime/core/providers/cpu/tensor/scatter.cc Outdated
Comment thread onnxruntime/core/providers/cpu/tensor/scatter.cc Outdated
- Add missing #include <atomic> for self-contained includes
- Replace static_cast<std::ptrdiff_t> with narrow<std::ptrdiff_t> for
  checked conversion on 32-bit builds (both GetIndices and ScatterData)
- Add comment clarifying intentional nondeterministic error reporting
  when multiple indices are out-of-bounds
@tianleiwu tianleiwu marked this pull request as ready for review May 20, 2026 23:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance] ScatterElements(reduction="add") is very slow on CPUExecutionProvider due to sequential CPU implementation

3 participants