feat(ir): add tile.mscatter op for per-element scatter-store to GM #936
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

A new tile scatter operation is introduced across the IR, the DSL, backend codegen, and tests.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant DSL as "DSL Layer (pl.mscatter)"
    participant IR as "IR Layer (tile.mscatter)"
    participant TS as "Type System\n(DeduceTileMscatterType)"
    participant Backend as "Backend Codegen"
    participant PTO as "PTO Lowering"
    User->>DSL: pl.mscatter(src_tile, idx_tile, output)
    DSL->>IR: _ir_ops.mscatter(src, idx, output)
    IR->>TS: Deduce types / validate args
    TS-->>IR: Validated output tensor type
    IR-->>DSL: Return Tensor(Call)
    DSL-->>User: Result wrapper
    Note over Backend,PTO: Compilation / lowering
    IR->>Backend: tile.mscatter(src, idx, output)
    Backend->>PTO: emit pto.partition_view(output_tensor)
    PTO-->>Backend: partition_view handle
    Backend->>PTO: emit pto.mscatter(src, idx, partition_view)
    PTO-->>Backend: scatter lowered
    Backend-->>IR: register result tensor view
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/st/runtime/test_mscatter.py (1)
249-572: Consider adding INT16 test coverage for completeness. The IR definition in `memory.cpp` lists INT16 as a supported dtype for `tile.mscatter`, but the test matrix only covers FP32, FP16, and INT32. Adding an INT16 test case would complete the dtype coverage. Since these tests are currently skipped pending PTOAS support, this can be addressed later.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/st/runtime/test_mscatter.py` around lines 249 - 572, add INT16 test coverage by creating test classes mirroring the INT32 cases (e.g., MscatterINT16SeqTestCase, MscatterINT16RevTestCase, MscatterINT16RandPermTestCase and a larger MscatterINT16_16x64RandPermTestCase) that follow the pattern in MscatterINT32SeqTestCase and the MscatterFP16/FP32 classes: use DataType.INT16 in define_tensors with the same init helpers (_init_randint_8x32, _init_sequential_8x32, _init_reversed_8x32, _init_random_perm_16x64, etc.), return the corresponding program names (MscatterINT16_8x32Program, MscatterINT16_16x64Program) from get_program, and implement compute_expected to create a torch.zeros(..., dtype=torch.int16) and assign out[tensors["idx_tensor"].flatten().long()] = tensors["src_tensor"].flatten(); keep the tests marked/skipped consistent with existing PTOAS-skipped tests.

tests/ut/ir/operators/test_tile_ops.py (1)
2009-2142: Consider adding INT16/INT32 happy-path cases for `tile.mscatter`. Current positive tests cover FP32/FP16 only, while the op contract also allows INT16/INT32. Adding those two cases would lock in the full supported dtype matrix.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ut/ir/operators/test_tile_ops.py` around lines 2009 - 2142, add two positive tests to TestTileMscatterOps that mirror the existing FP16/FP32 cases but use INT16 and INT32 dtypes: create span, rows, cols, tensor_n ConstInt values; build src_type = ir.TileType([rows, cols], DataType.INT16) and src_type = ir.TileType([rows, cols], DataType.INT32) respectively, idx_type = ir.TileType([rows, cols], DataType.INT32), and tensor_type = ir.TensorType([tensor_n], DataType.INT16) or DataType.INT32; create src_var, idx_var, out_var, call = tile.mscatter(src_var, idx_var, out_var), assert call.op.name == "tile.mscatter", assert isinstance(call.type, ir.TensorType), and assert call.type.dtype equals the corresponding DataType (INT16 or INT32). Ensure test names are distinct (e.g., test_tile_mscatter_int16 and test_tile_mscatter_int32) and follow the same pattern as test_tile_mscatter_fp16.
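Both nitpicks reduce to the same golden-reference pattern: a flat scatter-store where each source element lands at its paired index. A minimal NumPy sketch of that reference semantics, with illustrative shapes and a hypothetical helper name (not the test framework's API):

```python
import numpy as np

def mscatter_reference(out, src, idx):
    """Golden reference for a per-element scatter-store:
    out[idx[k]] = src[k] for every flattened position k.
    Duplicate indices resolve as last-write-wins; positions the
    scatter never writes keep their initial values."""
    result = out.copy()
    result[idx.flatten()] = src.flatten()
    return result

# INT16 case mirroring the reversed-index pattern from the review.
src = np.arange(8 * 32, dtype=np.int16).reshape(8, 32)
idx = np.arange(8 * 32, dtype=np.int32)[::-1].reshape(8, 32)
out = np.zeros(8 * 32, dtype=np.int16)
res = mscatter_reference(out, src, idx)
assert res[0] == 8 * 32 - 1   # element 255 was scattered to slot 0
assert res.dtype == np.int16
```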
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ir/op/tile_ops/memory.cpp`:
- Around line 565-567: The current type check in tile.mscatter only validates
rank equality (using idx_type->shape_.size() vs src_type->shape_.size()) but
allows same-rank different-shape tensors; update the validation in memory.cpp
(the tile.mscatter type-deduction / CHECK around idx_type and src_type) to
assert full shape equality by comparing idx_type->shape_ and src_type->shape_
elementwise (or via direct equality) and emit a clear error referencing op_name
when they differ so mismatched same-rank shapes are rejected at IR validation
time.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bf02517c-aad5-4625-9e0f-2b333a52d563
📒 Files selected for processing (8)
- python/pypto/ir/op/tile_ops.py
- python/pypto/language/__init__.py
- python/pypto/language/op/__init__.py
- python/pypto/language/op/tile_ops.py
- src/backend/common/pto_ops_common.cpp
- src/ir/op/tile_ops/memory.cpp
- tests/st/runtime/test_mscatter.py
- tests/ut/ir/operators/test_tile_ops.py
Code Review
This pull request implements the mscatter operation, enabling scatter-store functionality from tiles to tensors using per-element indices. The changes span the IR definition, Python DSL wrappers, and C++ backend codegen, supported by new unit and runtime tests. Feedback suggests adding a rank check for the output tensor to ensure it is not a scalar and correcting the import source for the DSL wrapper to align with project conventions.
Force-pushed from 432a20e to 9a8577a
Actionable comments posted: 3
♻️ Duplicate comments (1)
src/ir/op/tile_ops/memory.cpp (1)
565-567: ⚠️ Potential issue | 🟠 Major

Reject same-rank `idx` tiles with different extents. This only checks rank, so `src[8, 32]` and `idx[4, 64]` still pass even though `tile.mscatter` is defined elementwise. That can hand codegen two UB tiles whose shapes disagree elementwise. Validate full shape equality here, not just rank.

🔧 Suggested fix

```diff
   CHECK(idx_type->shape_.size() == src_type->shape_.size())
       << "The operator " << op_name << " requires idx rank to match src rank ("
       << src_type->shape_.size() << "), but got " << idx_type->shape_.size();
+  for (size_t i = 0; i < src_type->shape_.size(); ++i) {
+    CHECK(idx_type->shape_[i]->ToString() == src_type->shape_[i]->ToString())
+        << "The operator " << op_name << " requires idx shape to match src shape at dim " << i
+        << ", but got src dim " << src_type->shape_[i]->ToString()
+        << " and idx dim " << idx_type->shape_[i]->ToString();
+  }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/ir/op/tile_ops/memory.cpp` around lines 565 - 567, The CHECK currently only compares ranks and should instead validate full shape equality to prevent mismatched element counts: in the tile `mscatter` validation (the CHECK using idx_type and src_type and op_name), replace the rank-only check with a full shape comparison (e.g., compare idx_type->shape_ to src_type->shape_) and update the error message to report both shapes when they differ so mismatched extents (not just rank) are rejected.
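The rank-vs-shape distinction this comment draws can be sketched in a few lines of Python, a toy model of the check with concrete integer extents standing in for the IR's symbolic dims:

```python
def shapes_match(src_shape, idx_shape):
    """Full elementwise shape equality, mirroring the suggested fix:
    rank must match AND every extent must match."""
    return (len(src_shape) == len(idx_shape)
            and all(s == i for s, i in zip(src_shape, idx_shape)))

# A rank-only check accepts this pair; the stricter check rejects it.
assert len((8, 32)) == len((4, 64))        # same rank: passes old validation
assert not shapes_match((8, 32), (4, 64))  # different extents: now rejected
assert shapes_match((8, 32), (8, 32))      # identical shapes still accepted
```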
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@python/pypto/language/__init__.py`:
- Around line 136-138: The module now imports mscatter from .op.tile_ops but
does not add it to the exported symbol list; update the __all__ list in
pypto.language.__init__ to include "mscatter" alongside the other promoted tile
ops so that from pypto.language import * and attribute exports (e.g.,
pl.mscatter) expose it; locate the __all__ definition in __init__.py and append
"mscatter" to that list.
In `@src/backend/common/pto_ops_common.cpp`:
- Around line 741-772: The partition-view builder currently always uses
tensor_type->shape_ but must account for TensorLayout::DN; update the block that
constructs partition_view/partition_type (around
codegen.GetOrCreateTensorView(output_tensor), partition_line, and
partition_type) to detect if tensor_type->layout_ == TensorLayout::DN and, when
true, use a view_shape that swaps the trailing two dimensions before emitting
offsets, sizes, and the partition_type string (apply the same swap when emitting
constants via codegen.GetIndexConstant or expressions via
codegen.GetExprAsCode); keep tensor_view as returned by
codegen.GetOrCreateTensorView(output_tensor) but ensure the sizes and
partition_type reflect the swapped DN ordering so tile.mscatter partitions match
DN physical layout.
In `@tests/st/runtime/test_mscatter.py`:
- Around line 348-361: The test fails to seed the sparse-output case: in
define_tensors add an explicit zero initializer for the "output" TensorSpec (or
otherwise ensure the output tensor is pre-seeded to zeros) so unwritten
positions are deterministic; alternatively, in compute_expected read the
existing pre-seeded tensors["output"] and only overwrite indices given by
tensors["idx_tensor"] instead of starting from a fresh torch.zeros — update
either define_tensors (TensorSpec "output") or compute_expected to use the
pre-seeded output buffer to build the golden result for
MscatterFP32_8x32_LargeOutputProgram.
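The fix amounts to building the golden result from the pre-seeded output buffer instead of a fresh zero tensor, so slots the scatter never writes are deterministic either way. A NumPy sketch under that assumption (tensor names and sizes illustrative):

```python
import numpy as np

def compute_expected_sparse(output_seed, src, idx):
    """Golden result for a sparse scatter: start from the pre-seeded
    output so positions the scatter never writes stay well-defined."""
    golden = output_seed.copy()
    golden[idx.flatten()] = src.flatten()
    return golden

seed = np.zeros(2048, dtype=np.float32)        # explicit zero initializer
src = np.full((8, 32), 7.0, dtype=np.float32)
idx = (np.arange(8 * 32, dtype=np.int32) * 8).reshape(8, 32)  # strided, sparse
golden = compute_expected_sparse(seed, src, idx)
assert golden[0] == 7.0   # written slot
assert golden[1] == 0.0   # unwritten slot keeps its seed value
```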
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: da49c5ff-47c7-4011-b625-6785f9b7db21
📒 Files selected for processing (8)
- python/pypto/ir/op/tile_ops.py
- python/pypto/language/__init__.py
- python/pypto/language/op/__init__.py
- python/pypto/language/op/tile_ops.py
- src/backend/common/pto_ops_common.cpp
- src/ir/op/tile_ops/memory.cpp
- tests/st/runtime/test_mscatter.py
- tests/ut/ir/operators/test_tile_ops.py
✅ Files skipped from review due to trivial changes (1)
- python/pypto/language/op/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
- python/pypto/ir/op/tile_ops.py
- python/pypto/language/op/tile_ops.py
Force-pushed from bebab0f to efa021e

Force-pushed from 18c1445 to ba4d5b8
- Add shape equality validation between idx and src tiles (not just rank)
- Reject scalar (rank-0) output_tensor in type deduction
- Add init_value=torch.zeros for sparse scatter test output buffer
- Add mscatter to __all__ in language/__init__.py
- Add unit tests for shape mismatch and scalar output errors
Summary
Add a `tile.mscatter` operation that maps to the PTOAS `pto.mscatter` instruction for per-element scatter-store from UB tile to GM tensor.

Changes

- IR (`src/ir/op/tile_ops/memory.cpp`): define `tile.mscatter` with type deduction validating src (FP16/FP32/INT16/INT32), idx (INT32, same shape as src), and output_tensor (TensorType, non-scalar, same dtype as src)
- Backend (`src/backend/common/pto_ops_common.cpp`): emits `pto.partition_view` + `pto.mscatter ins(src, idx) outs(pview)` with row_major layout constraints on inputs
- Python IR binding (`python/pypto/ir/op/tile_ops.py`): `tile.mscatter(src, idx, output_tensor)`
- DSL (`python/pypto/language/op/tile_ops.py`): `pl.mscatter(src_tile, idx_tile, out_tensor)`, exported at the top-level `pl` namespace and added to `__all__`
- Unit tests (`tests/ut/ir/operators/test_tile_ops.py`): 9 tests covering basic usage (FP32/FP16) and error paths (wrong dtype, rank mismatch, shape mismatch, scalar output, arg count)
- Runtime tests (`tests/st/runtime/test_mscatter.py`): 11 tests covering FP32/FP16/INT32, 8x32/16x64 shapes, and sequential/reversed/random/strided index patterns — all 11 passed on NPU device
- CI (`.github/workflows/ci.yml`): bump `--pto-isa-commit` to `ed0b4643` (includes mscatter CPU simulator fix `12f663c3`)

Platform Support

`pto.mscatter` is currently only supported on A5 (Ascend 950). A3 (Ascend 910B) does not yet support this instruction in PTOAS, so ST tests are marked with `@pytest.mark.a5` and are not executed on 910B.

Testing
Fixes #921