
feat(ir): add tile.mscatter op for per-element scatter-store to GM #936

Merged
lyfne123 merged 2 commits into hw-native-sys:main from Little-oil:issue-921-add-mscatter-op
Apr 17, 2026

Conversation

Contributor

@Little-oil Little-oil commented Apr 9, 2026

Summary

Add tile.mscatter operation that maps to the PTOAS pto.mscatter instruction for per-element scatter-store from UB tile to GM tensor:

output_tensor[idx[i, j]] = src[i, j]
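The semantics can be sketched with NumPy (array names mirror the formula above; the shapes are illustrative, not taken from the PR):

```python
import numpy as np

# Reference semantics of tile.mscatter: each element of src is stored into
# output_tensor at the flat position given by the matching element of idx.
rng = np.random.default_rng(0)
src = np.arange(8 * 32, dtype=np.float32).reshape(8, 32)
idx = rng.permutation(8 * 32).astype(np.int32).reshape(8, 32)
output_tensor = np.zeros(8 * 32, dtype=np.float32)
output_tensor[idx] = src  # scatter-store: output_tensor[idx[i, j]] = src[i, j]
```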

Changes

  • C++ op registration (src/ir/op/tile_ops/memory.cpp): tile.mscatter with type deduction validating src (FP16/FP32/INT16/INT32), idx (INT32, same shape as src), and output_tensor (TensorType, non-scalar, same dtype as src)
  • PTO codegen (src/backend/common/pto_ops_common.cpp): Emits pto.partition_view + pto.mscatter ins(src, idx) outs(pview) with row_major layout constraints on inputs
  • Python IR wrapper (python/pypto/ir/op/tile_ops.py): tile.mscatter(src, idx, output_tensor)
  • Python DSL wrapper (python/pypto/language/op/tile_ops.py): pl.mscatter(src_tile, idx_tile, out_tensor) exported at top-level pl namespace and added to __all__
  • Unit tests (tests/ut/ir/operators/test_tile_ops.py): 9 tests covering basic usage (FP32/FP16), error paths (wrong dtype, rank mismatch, shape mismatch, scalar output, arg count)
  • ST runtime tests (tests/st/runtime/test_mscatter.py): 11 tests covering FP32/FP16/INT32, 8x32/16x64 shapes, sequential/reversed/random/strided index patterns — all 11 passed on NPU device
  • CI updates (.github/workflows/ci.yml):
    • Bumped --pto-isa-commit to ed0b4643 (includes mscatter CPU simulator fix 12f663c3)
    • Bumped PTOAS to v0.26 for both system-tests (aarch64) and a5sim (x86_64) jobs

Platform Support

pto.mscatter is currently only supported on A5 (Ascend 950). A3 (Ascend 910B) does not yet support this instruction in PTOAS, so ST tests are marked with @pytest.mark.a5 and are not executed on 910B.
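The gating looks roughly like this minimal sketch (the `a5` marker name comes from this PR; marker registration and the selection expression are assumptions):

```python
import pytest

# Tests that exercise pto.mscatter are tagged so CI can deselect them on
# 910B, e.g. with `pytest -m "not a5"`. The marker must be registered in
# pytest.ini / pyproject.toml to avoid unknown-marker warnings.
@pytest.mark.a5
def test_mscatter_fp32_8x32():
    ...  # body elided; runs only where A5-marked tests are selected
```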

Testing

  • Unit tests pass (9/9)
  • ST runtime tests pass on NPU device (11/11, A5 only)
  • CI passes (all jobs green)
  • Pre-commit hooks pass (clang-format, cpplint, ruff, pyright)

Fixes #921


coderabbitai bot commented Apr 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

A new tile scatter operation tile.mscatter was added end-to-end: IR op and type deduction, Python IR and DSL wrappers, backend codegen lowering to pto.partition_view + pto.mscatter, module exports, and unit + runtime tests (runtime tests currently skipped).

Changes

| Cohort | File(s) | Summary |
|---|---|---|
| IR Python wrapper | python/pypto/ir/op/tile_ops.py | Added exported `mscatter(src: Expr, idx: Expr, output_tensor: Expr, span: Span, …)` wrapper. |
| DSL layer | python/pypto/language/op/tile_ops.py, python/pypto/language/__init__.py, python/pypto/language/op/__init__.py | Added DSL `mscatter(src: Tile, idx: Tile, output_tensor: Tensor) -> Tensor` and re-exported/promoted `mscatter` into the `language` and `language.op` public APIs (updated `__all__` in op/__init__.py; language/__init__.py re-exports). |
| IR type system & op registration | src/ir/op/tile_ops/memory.cpp | Registered tile.mscatter and added DeduceTileMscatterType enforcing arg count (3), src dtype ∈ {FP16, FP32, INT16, INT32}, idx dtype == INT32 with matching rank, and output_tensor dtype == src; configured memory specs and output reuse. |
| Backend codegen & registration | src/backend/common/pto_ops_common.cpp | Added custom codegen for tile.mscatter that emits a full-tensor pto.partition_view (zero offsets, sizes from output_tensor) then pto.mscatter(src, idx, partition_view); removed tile.mscatter from the simple-ops table and registered it with layout constraints (src/idx row-major). |
| Unit tests (IR) | tests/ut/ir/operators/test_tile_ops.py | Added positive/negative unit tests for tile.mscatter construction and type-deduction error cases (dtype/rank/arg-count checks). |
| Runtime tests (system) | tests/st/runtime/test_mscatter.py | Added comprehensive runtime tests covering FP32/FP16/INT32 and tile shapes 8x32 and 16x64 with many index patterns; tests currently globally skipped pending backend behavior. |
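The validation rules attributed to DeduceTileMscatterType above can be summarized in Python (the real check lives in C++ in src/ir/op/tile_ops/memory.cpp; this sketch uses illustrative names and reflects the full-shape check the review later requested):

```python
from dataclasses import dataclass

@dataclass
class TileInfo:
    shape: tuple
    dtype: str

ALLOWED_SRC_DTYPES = {"FP16", "FP32", "INT16", "INT32"}

def deduce_mscatter_type(src: TileInfo, idx: TileInfo, output_tensor: TileInfo) -> TileInfo:
    # Mirrors the checks described above: src dtype whitelist, INT32 idx with
    # the same shape as src, and a non-scalar output matching src's dtype.
    assert src.dtype in ALLOWED_SRC_DTYPES, f"unsupported src dtype: {src.dtype}"
    assert idx.dtype == "INT32", "idx must be INT32"
    assert idx.shape == src.shape, "idx shape must equal src shape"
    assert len(output_tensor.shape) >= 1, "output_tensor must not be scalar"
    assert output_tensor.dtype == src.dtype, "output dtype must match src"
    return output_tensor  # the op's result reuses the output tensor's type
```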

Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant DSL as "DSL Layer (pl.mscatter)"
    participant IR as "IR Layer (tile.mscatter)"
    participant TS as "Type System\n(DeduceTileMscatterType)"
    participant Backend as "Backend Codegen"
    participant PTO as "PTO Lowering"

    User->>DSL: pl.mscatter(src_tile, idx_tile, output)
    DSL->>IR: _ir_ops.mscatter(src, idx, output)
    IR->>TS: Deduce types / validate args
    TS-->>IR: Validated output tensor type
    IR-->>DSL: Return Tensor(Call)
    DSL-->>User: Result wrapper

    Note over Backend,PTO: Compilation / lowering
    IR->>Backend: tile.mscatter(src, idx, output)
    Backend->>PTO: emit pto.partition_view(output_tensor)
    PTO-->>Backend: partition_view handle
    Backend->>PTO: emit pto.mscatter(src, idx, partition_view)
    PTO-->>Backend: scatter lowered
    Backend-->>IR: register result tensor view
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • lyfne123
  • Hzfengsy

Poem

🐰 I hopped through tiles both wide and small,
I carried indices, one and all,
From DSL to IR I made my flight,
Then PTO stitched each value right,
A tiny scatter — stitched at night. 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 25.22%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'feat(ir): add tile.mscatter op for per-element scatter-store to GM' clearly and concisely describes the primary change: adding a new tile.mscatter operation with scatter-store semantics. |
| Linked Issues check | ✅ Passed | All primary objectives from issue #921 are met: C++ op registration with type deduction, PTO codegen mapping, Python IR/DSL wrappers, and comprehensive unit/ST tests with proper skip annotations. |
| Out of Scope Changes check | ✅ Passed | All changes are directly aligned with implementing tile.mscatter: IR definition, codegen, Python bindings, and tests. No unrelated modifications to other systems detected. |
| Description check | ✅ Passed | The PR description comprehensively documents the new tile.mscatter operation, detailing implementation across C++, Python IR/DSL layers, codegen mapping, and comprehensive test coverage including unit and ST runtime tests. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/st/runtime/test_mscatter.py (1)

249-572: Consider adding INT16 test coverage for completeness.

The IR definition in memory.cpp lists INT16 as a supported dtype for tile.mscatter, but the test matrix only covers FP32, FP16, and INT32. Adding an INT16 test case would complete the dtype coverage.

Since these tests are currently skipped pending PTOAS support, this can be addressed later.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/runtime/test_mscatter.py` around lines 249 - 572, Add INT16 test
coverage by creating test classes mirroring the INT32 cases (e.g.,
MscatterINT16SeqTestCase, MscatterINT16RevTestCase,
MscatterINT16RandPermTestCase and a larger MscatterINT16_16x64RandPermTestCase)
that follow the pattern in MscatterINT32SeqTestCase and MscatterFP16/FP32
classes: use DataType.INT16 in define_tensors with the same init helpers
(_init_randint_8x32, _init_sequential_8x32, _init_reversed_8x32,
_init_random_perm_16x64, etc.), return the corresponding program names
(MscatterINT16_8x32Program, MscatterINT16_16x64Program) from get_program, and
implement compute_expected to create a torch.zeros(..., dtype=torch.int16) and
assign out[tensors["idx_tensor"].flatten().long()] =
tensors["src_tensor"].flatten(); keep the tests marked/skipped consistent with
existing PTOAS-skipped tests.
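The golden reference that prompt describes is essentially the following (sketched with NumPy rather than torch for illustration; variable names follow the prompt):

```python
import numpy as np

# INT16 scatter golden: zero-initialized output, then per-element store of
# src values at the flat positions given by idx.
rng = np.random.default_rng(0)
src = rng.integers(-100, 100, size=(8, 32)).astype(np.int16)
idx = rng.permutation(8 * 32).astype(np.int32).reshape(8, 32)
expected = np.zeros(8 * 32, dtype=np.int16)
expected[idx.flatten()] = src.flatten()
```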
tests/ut/ir/operators/test_tile_ops.py (1)

2009-2142: Consider adding INT16/INT32 happy-path cases for tile.mscatter.

Current positive tests cover FP32/FP16 only, while the op contract also allows INT16/INT32. Adding those two cases would lock in the full supported dtype matrix.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/ir/operators/test_tile_ops.py` around lines 2009 - 2142, Add two
positive tests to TestTileMscatterOps that mirror the existing FP16/FP32 cases
but use INT16 and INT32 dtypes: create span, rows, cols, tensor_n ConstInt
values; build src_type = ir.TileType([rows, cols], DataType.INT16) and src_type
= ir.TileType([rows, cols], DataType.INT32) respectively, idx_type =
ir.TileType([rows, cols], DataType.INT32), tensor_type =
ir.TensorType([tensor_n], DataType.INT16) and DataType.INT32; create src_var,
idx_var, out_var, call = tile.mscatter(src_var, idx_var, out_var), assert
call.op.name == "tile.mscatter" and assert isinstance(call.type, ir.TensorType)
and call.type.dtype equals the corresponding DataType (INT16 or INT32). Ensure
test names are distinct (e.g., test_tile_mscatter_int16 and
test_tile_mscatter_int32) and follow the same pattern as
test_tile_mscatter_fp16.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/ir/op/tile_ops/memory.cpp`:
- Around line 565-567: The current type check in tile.mscatter only validates
rank equality (using idx_type->shape_.size() vs src_type->shape_.size()) but
allows same-rank different-shape tensors; update the validation in memory.cpp
(the tile.mscatter type-deduction / CHECK around idx_type and src_type) to
assert full shape equality by comparing idx_type->shape_ and src_type->shape_
elementwise (or via direct equality) and emit a clear error referencing op_name
when they differ so mismatched same-rank shapes are rejected at IR validation
time.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bf02517c-aad5-4625-9e0f-2b333a52d563

📥 Commits

Reviewing files that changed from the base of the PR and between 523c08f and 432a20e.

📒 Files selected for processing (8)
  • python/pypto/ir/op/tile_ops.py
  • python/pypto/language/__init__.py
  • python/pypto/language/op/__init__.py
  • python/pypto/language/op/tile_ops.py
  • src/backend/common/pto_ops_common.cpp
  • src/ir/op/tile_ops/memory.cpp
  • tests/st/runtime/test_mscatter.py
  • tests/ut/ir/operators/test_tile_ops.py

Comment thread src/ir/op/tile_ops/memory.cpp
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements the mscatter operation, enabling scatter-store functionality from tiles to tensors using per-element indices. The changes span the IR definition, Python DSL wrappers, and C++ backend codegen, supported by new unit and runtime tests. Feedback suggests adding a rank check for the output tensor to ensure it is not a scalar and correcting the import source for the DSL wrapper to align with project conventions.

Comment thread src/ir/op/tile_ops/memory.cpp
Comment thread python/pypto/language/__init__.py
@Little-oil Little-oil force-pushed the issue-921-add-mscatter-op branch from 432a20e to 9a8577a on April 15, 2026 01:54

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (1)
src/ir/op/tile_ops/memory.cpp (1)

565-567: ⚠️ Potential issue | 🟠 Major

Reject same-rank idx tiles with different extents.

This only checks rank, so src[8, 32] and idx[4, 64] still pass even though tile.mscatter is defined elementwise. That can hand codegen two UB tiles with different element counts. Validate full shape equality here, not just rank.

🔧 Suggested fix

```diff
   CHECK(idx_type->shape_.size() == src_type->shape_.size())
       << "The operator " << op_name << " requires idx rank to match src rank (" << src_type->shape_.size()
       << "), but got " << idx_type->shape_.size();
+  for (size_t i = 0; i < src_type->shape_.size(); ++i) {
+    CHECK(idx_type->shape_[i]->ToString() == src_type->shape_[i]->ToString())
+        << "The operator " << op_name << " requires idx shape to match src shape at dim " << i
+        << ", but got src dim " << src_type->shape_[i]->ToString()
+        << " and idx dim " << idx_type->shape_[i]->ToString();
+  }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/op/tile_ops/memory.cpp` around lines 565 - 567, The CHECK currently
only compares ranks and should instead validate full shape equality to prevent
mismatched element counts: in the tile `mscatter` validation (the CHECK using
idx_type and src_type and op_name), replace the rank-only check with a full
shape comparison (e.g., compare idx_type->shape_ to src_type->shape_) and update
the error message to report both shapes when they differ so mismatched extents
(not just rank) are rejected.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/language/__init__.py`:
- Around line 136-138: The module now imports mscatter from .op.tile_ops but
does not add it to the exported symbol list; update the __all__ list in
pypto.language.__init__ to include "mscatter" alongside the other promoted tile
ops so that from pypto.language import * and attribute exports (e.g.,
pl.mscatter) expose it; locate the __all__ definition in __init__.py and append
"mscatter" to that list.

In `@src/backend/common/pto_ops_common.cpp`:
- Around line 741-772: The partition-view builder currently always uses
tensor_type->shape_ but must account for TensorLayout::DN; update the block that
constructs partition_view/partition_type (around
codegen.GetOrCreateTensorView(output_tensor), partition_line, and
partition_type) to detect if tensor_type->layout_ == TensorLayout::DN and, when
true, use a view_shape that swaps the trailing two dimensions before emitting
offsets, sizes, and the partition_type string (apply the same swap when emitting
constants via codegen.GetIndexConstant or expressions via
codegen.GetExprAsCode); keep tensor_view as returned by
codegen.GetOrCreateTensorView(output_tensor) but ensure the sizes and
partition_type reflect the swapped DN ordering so tile.mscatter partitions match
DN physical layout.

In `@tests/st/runtime/test_mscatter.py`:
- Around line 348-361: The test fails to seed the sparse-output case: in
define_tensors add an explicit zero initializer for the "output" TensorSpec (or
otherwise ensure the output tensor is pre-seeded to zeros) so unwritten
positions are deterministic; alternatively, in compute_expected read the
existing pre-seeded tensors["output"] and only overwrite indices given by
tensors["idx_tensor"] instead of starting from a fresh torch.zeros — update
either define_tensors (TensorSpec "output") or compute_expected to use the
pre-seeded output buffer to build the golden result for
MscatterFP32_8x32_LargeOutputProgram.
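The determinism issue can be sketched as follows (a NumPy analogue of the torch code in the prompt; the helper name and shapes are hypothetical):

```python
import numpy as np

def compute_expected(output_seed, src, idx):
    # Start from the pre-seeded output buffer (zeros in the fixed test), so
    # positions the scatter never writes are deterministic rather than
    # whatever the device buffer happened to contain.
    expected = output_seed.copy()
    expected[idx.flatten()] = src.flatten()
    return expected

# Sparse case: output larger than the number of scattered elements.
src = np.ones((8, 32), dtype=np.float32)
idx = np.arange(8 * 32, dtype=np.int32).reshape(8, 32) * 2  # strided indices
out = compute_expected(np.zeros(1024, dtype=np.float32), src, idx)
```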

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: da49c5ff-47c7-4011-b625-6785f9b7db21

📥 Commits

Reviewing files that changed from the base of the PR and between 432a20e and 9a8577a.

📒 Files selected for processing (8)
  • python/pypto/ir/op/tile_ops.py
  • python/pypto/language/__init__.py
  • python/pypto/language/op/__init__.py
  • python/pypto/language/op/tile_ops.py
  • src/backend/common/pto_ops_common.cpp
  • src/ir/op/tile_ops/memory.cpp
  • tests/st/runtime/test_mscatter.py
  • tests/ut/ir/operators/test_tile_ops.py
✅ Files skipped from review due to trivial changes (1)
  • python/pypto/language/op/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • python/pypto/ir/op/tile_ops.py
  • python/pypto/language/op/tile_ops.py

Comment thread python/pypto/language/__init__.py
Comment thread src/backend/common/pto_ops_common.cpp
Comment thread tests/st/runtime/test_mscatter.py
@Little-oil Little-oil force-pushed the issue-921-add-mscatter-op branch 2 times, most recently from bebab0f to efa021e on April 16, 2026 09:25
@Little-oil Little-oil force-pushed the issue-921-add-mscatter-op branch from 18c1445 to ba4d5b8 on April 17, 2026 01:59
- Add shape equality validation between idx and src tiles (not just rank)
- Reject scalar (rank-0) output_tensor in type deduction
- Add init_value=torch.zeros for sparse scatter test output buffer
- Add mscatter to __all__ in language/__init__.py
- Add unit tests for shape mismatch and scalar output errors
@lyfne123 lyfne123 merged commit c6e39b2 into hw-native-sys:main Apr 17, 2026
8 checks passed


Development

Successfully merging this pull request may close these issues.

[New Op] Add op for pto.mscatter instruction

2 participants