feat: add tile scatter_update PTO lowering #1004
Little-oil wants to merge 1 commit into hw-native-sys:main
📝 Walkthrough
Adds PTO codegen for …
Changes
Sequence Diagram
sequenceDiagram
participant Client as Client Code
participant Op as tile.scatter_update Op
participant PTO as PTOCodegen
participant MLIR as MLIR Builder
Client->>Op: invoke tile.scatter_update(input, index, src, dim)
Op->>PTO: MakeScatterUpdateCodegenPTO(...)
PTO->>MLIR: pto.treshape(index) / pto.treshape(src) / pto.treshape(dest)
PTO->>MLIR: scf.for i ...
PTO->>MLIR: scf.for j ...
PTO->>MLIR: EmitFlatOffsetSSAFromValues(...) → arith.muli/arith.addi
PTO->>MLIR: pto.tgetval(src_linear, src_off) → load scalar
PTO->>MLIR: pto.tsetval(dst_linear, dst_off, value) → in-place store
PTO->>Op: return lowered IR (dest aliases input)
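As a plain-Python reading of the sequence above, the sketch below models the nested-loop lowering on flat lists. It assumes the 2D form, where `index` is `[b, s]`, `src` is `[b*s, d]`, and `dst` is `[rows, d]`, and that each index entry names a destination row. The innermost loop over `k` is taken from the review comments rather than the diagram, and the name `scatter_update_reference` is hypothetical, not part of the PR.

```python
def scatter_update_reference(dst, index, src, d):
    """Scalar model of the lowering: dst is updated in place and returned.

    dst   -- flat list holding a [rows, d] tile (row-major)
    index -- nested list of shape [b, s]; each entry is a destination row
    src   -- flat list holding a [b*s, d] tile (row-major)
    """
    b = len(index)
    s = len(index[0]) if b else 0
    for i in range(b):                  # outer scf.for
        for j in range(s):              # inner scf.for
            row = index[i][j]           # destination row read from the index tile
            src_row = i * s + j         # matching row in src
            for k in range(d):
                src_off = src_row * d + k    # arith.muli / arith.addi
                dst_off = row * d + k
                dst[dst_off] = src[src_off]  # tgetval, then tsetval
    return dst  # same object as the input: the result aliases it


dst = [0] * (4 * 2)        # [4, 2] tile of zeros
index = [[3, 0]]           # b=1, s=2
src = [10, 11, 20, 21]     # [2, 2]
out = scatter_update_reference(dst, index, src, d=2)
```

Returning the same list object mirrors the "dest aliases input" note in the diagram; with duplicate indices, the later write would win in this model.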
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request implements the backend code generation for the tile.scatter_update operation, which performs in-place updates on input tiles. Key changes include the addition of indentation helpers in PTOCodegen, the implementation of MakeScatterUpdateCodegenPTO using nested loops and PTO-specific value accessors, and a refactoring of memory allocation logic in InitMemRef to account for physical fractal capacity. A potential issue was identified in the scatter update codegen where a rank mismatch between indices and shapes could trigger an internal check failure if the input tiles are not 2D; a suggestion was provided to use a consistent 2D shape representation during offset calculation.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/ir/op/tile_ops/transform.cpp (1)
388-433: ⚠️ Potential issue | 🟠 Major
Validate `src` against the documented scatter-update shapes.
This deduce path currently accepts any same-rank `src`, but the lowering only works for `[b*s, d]` in 2D and `[b, s, 1, d]` in 4D. For example, `input=[16,32]`, `index=[2,4]`, `src=[2,4]` passes here, and the backend still iterates `k` over 32 columns, which reads past the logical `src` extent. Please enforce the `b*s` / `[b,s,1,d]` shape relations here before codegen.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/ir/op/tile_ops/transform.cpp` around lines 388 - 433, DeduceTileScatterUpdateType currently allows any same-rank src but lowering requires specific shapes; update DeduceTileScatterUpdateType to validate src_shape exactly: for 2D input (input_type->shape_.size()==2) require index_type->shape_.size()==2 and enforce src_type->shape_.size()==2, src_type->shape_[0] == index_type->shape_[0] * index_type->shape_[1] and src_type->shape_[1] == input_type->shape_[1]; for 4D input require src_type->shape_.size()==4 and enforce src_type->shape_[0] == index_type->shape_[0], src_type->shape_[1] == index_type->shape_[1], src_type->shape_[2] == 1, and src_type->shape_[3] == input_type->shape_[3]; use CHECK with clear error messages referencing input_type, index_type, and src_type to fail when shapes mismatch.
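The shape relations requested above can be sketched in Python as follows. This is only an illustration of the rule, not the C++ `DeduceTileScatterUpdateType` code; the helper name `check_scatter_update_shapes` and the use of `ValueError` in place of `CHECK` are assumptions.

```python
def check_scatter_update_shapes(input_shape, index_shape, src_shape):
    """Raise ValueError unless src matches the shapes the lowering expects."""
    if len(index_shape) != 2:
        raise ValueError(f"index must be 2D [b, s], got {list(index_shape)}")
    b, s = index_shape
    if len(input_shape) == 2:
        expected = [b * s, input_shape[1]]        # [b*s, d]
    elif len(input_shape) == 4:
        expected = [b, s, 1, input_shape[3]]      # [b, s, 1, d]
    else:
        raise ValueError(f"input must be 2D or 4D, got {list(input_shape)}")
    if list(src_shape) != expected:
        raise ValueError(
            f"src shape {list(src_shape)} does not match expected {expected} "
            f"for input {list(input_shape)} and index {list(index_shape)}"
        )


check_scatter_update_shapes([16, 32], [2, 4], [8, 32])       # accepted
try:
    check_scatter_update_shapes([16, 32], [2, 4], [2, 4])    # the review's example
    rejected = False
except ValueError:
    rejected = True
```

With this check in place, the `src=[2,4]` case from the comment fails at type deduction instead of reaching a loop that reads 32 columns.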
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/backend/common/pto_ops_common.cpp`:
- Around line 590-596: The code currently flattens offsets using the original
buffer shapes, causing EmitFlatOffsetSSAFromValues to see mismatched rank (e.g.,
2 vs 4) and abort; to fix, ensure offsets are computed against the reshaped 2D
views produced by make_linear_tile_type: after creating index_linear_type,
src_linear_type, and dst_linear_type and calling emit_treshape (index_linear,
src_linear, dst_linear), use the linear types' shape information when calling
EmitFlatOffsetSSAFromValues (or any offset-flattening helper) instead of
index_type->shape_, src_type->shape_, or input_type->shape_; update all places
that compute flat offsets (including around EmitFlatOffsetSSAFromValues and uses
of index_ssa/src_ssa/dst SSA) to reference
index_linear_type/src_linear_type/dst_linear_type so the indices.size() matches
shape.size().
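To make the rank mismatch concrete, here is a small Python model of row-major offset flattening. It is a sketch under assumptions: `flat_offset` is a hypothetical stand-in for `EmitFlatOffsetSSAFromValues`, and the 2D view is taken to be `[rows, d]`.

```python
def flat_offset(indices, shape):
    """Row-major flat offset; refuses mismatched ranks, like the internal check."""
    if len(indices) != len(shape):
        raise ValueError(
            f"rank mismatch: {len(indices)} indices for shape of rank {len(shape)}"
        )
    offset = 0
    for idx, extent in zip(indices, shape):
        offset = offset * extent + idx
    return offset


src_shape_4d = [2, 4, 1, 32]           # original [b, s, 1, d] tile
total = 2 * 4 * 1 * 32
d = src_shape_4d[-1]
linear_shape = [total // d, d]         # 2D [rows, d] view

try:
    flat_offset([5, 7], src_shape_4d)  # two indices against a rank-4 shape
    mismatch = False
except ValueError:
    mismatch = True

off = flat_offset([5, 7], linear_shape)  # 5 * 32 + 7
```

Passing the 2D shape keeps the number of indices equal to the rank, which is the condition the comment asks the lowering to restore.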
In `@src/ir/transforms/init_memref.cpp`:
- Around line 70-90: Add a direct include of <algorithm> to this translation
unit to satisfy the use of std::max in GetTileAllocationSizeBytes; locate the
function GetTileAllocationSizeBytes in src/ir/transforms/init_memref.cpp and add
the header include for <algorithm> near the other includes so the call to
std::max<uint64_t>(...) is not relying on transitive includes.
In `@tests/ut/codegen/test_pto_codegen_ops.py`:
- Around line 1002-1025: The test test_tile_scatter_update_loop_bound currently
asserts "8" appears anywhere in the generated MLIR (mlir) which is brittle;
update the assertion to specifically check the scf.for loop upper bound (or the
SSA value used as the loop bound) so it can't be satisfied by the src_t shape.
In the test function (test_tile_scatter_update_loop_bound) locate the mlir
string and replace the loose assert with a targeted match — e.g., search for the
scf.for line pattern like "scf.for .* to 8" or extract the SSA constant that
feeds the loop bound and assert that its definition equals 8 — ensuring the
assertion references the scf.for loop bound or its loop-bound SSA symbol rather
than a raw "8".
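One way to target the loop bound rather than a raw substring is sketched below. The MLIR text is illustrative, not output captured from the PR's codegen, and the helper `scf_for_upper_bounds` is hypothetical; it assumes bounds are either integer literals or `arith.constant ... : index` SSA values.

```python
import re


def scf_for_upper_bounds(mlir):
    """Return the integer upper bound of each scf.for, in order of appearance."""
    # Map SSA names to integer constants, e.g. "%c2 = arith.constant 2 : index".
    consts = {
        name: int(value)
        for name, value in re.findall(
            r"(%[\w.$-]+)\s*=\s*arith\.constant\s+(-?\d+)\s*:\s*index", mlir
        )
    }
    bounds = []
    for ub in re.findall(r"scf\.for\s+%[\w.$-]+\s*=\s*\S+\s+to\s+(\S+)", mlir):
        # The bound may be an SSA value or a literal, depending on the printer.
        bounds.append(consts[ub] if ub in consts else int(ub))
    return bounds


# Illustrative MLIR text (an assumption, not the real test fixture).
sample = """
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
scf.for %i = %c0 to %c2 step %c1 {
  scf.for %j = %c0 to %c4 step %c1 {
  }
}
"""
bounds = scf_for_upper_bounds(sample)
```

In a test, asserting on `bounds` means a constant `8` that appears elsewhere in the module cannot satisfy the check by accident.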
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 59d5f13b-8be0-4257-ba97-f88aa3a60845
📒 Files selected for processing (6)
- include/pypto/codegen/pto/pto_codegen.h
- src/backend/common/pto_ops_common.cpp
- src/ir/op/tile_ops/transform.cpp
- src/ir/transforms/init_memref.cpp
- tests/ut/codegen/test_pto_codegen_ops.py
- tests/ut/ir/transforms/test_init_memref.py
c2a0ce3 to 0bb800a
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/backend/common/pto_ops_common.cpp (1)
632-634: ⚠️ Potential issue | 🔴 Critical
Flatten `src_off` / `dst_off` against synthetic `[rows, d]` shapes, not the original tile ranks.
The 4D form still aborts here: `EmitFlatOffsetSSAFromValues(...)` gets two indices (`idx,k` / `row,k`) but `src_type->shape_` and `input_type->shape_` are rank-4. Also, using the current `[1, total_elements]` reshape type would compute the wrong offset. This needs a `[src_total / d_val, d_val]` / `[dst_total / d_val, d_val]` shape (or direct `idx * d + k` arithmetic) before the 4D path can lower successfully.
Minimal fix sketch
- std::string src_off = EmitFlatOffsetSSAFromValues({idx, k_var}, src_type->shape_, codegen, "src_off");
-
- std::string dst_off = EmitFlatOffsetSSAFromValues({row_idx, k_var}, input_type->shape_, codegen, "dst_off");
+ const std::vector<ir::ExprPtr> src_linear_shape = {
+     std::make_shared<ir::ConstInt>(src_total / d_val, DataType::INT64, ir::Span::unknown()),
+     std::make_shared<ir::ConstInt>(d_val, DataType::INT64, ir::Span::unknown()),
+ };
+ const std::vector<ir::ExprPtr> dst_linear_shape = {
+     std::make_shared<ir::ConstInt>(dst_total / d_val, DataType::INT64, ir::Span::unknown()),
+     std::make_shared<ir::ConstInt>(d_val, DataType::INT64, ir::Span::unknown()),
+ };
+ std::string src_off =
+     EmitFlatOffsetSSAFromValues({idx, k_var}, src_linear_shape, codegen, "src_off");
+ std::string dst_off =
+     EmitFlatOffsetSSAFromValues({row_idx, k_var}, dst_linear_shape, codegen, "dst_off");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/backend/common/pto_ops_common.cpp` around lines 632 - 634, src_off and dst_off are computed with EmitFlatOffsetSSAFromValues using the original rank-4 shapes (src_type->shape_ and input_type->shape_), but the indices provided (idx, k_var and row_idx, k_var) refer to a synthetic [rows, d] layout; change the flattening to use a 2D shape [total_rows, d_val] (i.e. compute src_total/d_val and dst_total/d_val) or replace the call with explicit arithmetic idx * d_val + k_var so EmitFlatOffsetSSAFromValues receives the correct 2D shape; update the calls that produce src_off and dst_off (the two lines calling EmitFlatOffsetSSAFromValues) to either pass the computed [rows, d_val] shape or compute offsets as idx * d_val + k_var/row_idx * d_val + k_var respectively.
tests/ut/codegen/test_pto_codegen_ops.py (1)
1023-1025: ⚠️ Potential issue | 🟡 Minor
Assert the actual `scf.for` bounds instead of a raw `"8"`.
This doesn't verify the lowering. The codegen emits nested loops with bounds `2` and `4`, and `"8"` is already present in this fixture via `src_t`'s shape, so the test can pass even if the loop bounds regress.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ut/codegen/test_pto_codegen_ops.py` around lines 1023 - 1025, The test currently checks for a raw "8" in the MLIR which can be present for other reasons; instead assert the actual scf.for loop bounds emitted by the lowering. Replace the assert "8" check with assertions that the generated MLIR (from self._generate_mlir(Prog)) contains scf.for loop headers with the expected upper bounds (e.g. occurrences of "scf.for" lines containing "to 2" and "to 4" or the equivalent "to <value>" text for b and s), so the test verifies the nested loop bounds (2 and 4) rather than a stray "8".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/ir/op/tile_ops/transform.cpp`:
- Around line 425-433: The code currently overwrites any existing logical extent
by setting tile_view.valid_shape = input_type->shape_, which loses partial-tile
information and causes downstream tile.store to treat padded elements as valid;
change the logic in the tile view construction (the tile_view /
input_type->tile_view_ handling before constructing TileType) to preserve an
existing valid_shape when input_type->tile_view_ is present and only default to
input_type->shape_ when no valid_shape was previously set, then pass that
preserved/defaulted tile_view into the TileType constructor.
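The preserve-or-default rule described above can be sketched in Python like this. `TileView` here is a hypothetical stand-in that models only `valid_shape`; the real C++ type and its fields are not reproduced, and treating an empty `valid_shape` as "unset" is an assumption.

```python
from dataclasses import dataclass, replace
from typing import List, Optional


@dataclass(frozen=True)
class TileView:
    """Hypothetical stand-in for the IR tile view; only valid_shape is modelled."""
    valid_shape: Optional[List[int]] = None


def result_tile_view(input_shape, input_view):
    """Keep an existing valid_shape; default to the full shape only when unset."""
    view = input_view if input_view is not None else TileView()
    if view.valid_shape:
        return view  # partial-tile extent survives unchanged
    return replace(view, valid_shape=list(input_shape))


partial = result_tile_view([16, 32], TileView(valid_shape=[10, 32]))
full = result_tile_view([16, 32], None)
```

The point of the branch is that a partial tile such as `[10, 32]` inside a `[16, 32]` buffer keeps its logical extent, so a later store does not treat padded rows as valid.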
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 11ae86ec-cb1d-4ecb-afb2-bd64c0f5bda1
📒 Files selected for processing (6)
- include/pypto/codegen/pto/pto_codegen.h
- src/backend/common/pto_ops_common.cpp
- src/ir/op/tile_ops/transform.cpp
- src/ir/transforms/init_memref.cpp
- tests/ut/codegen/test_pto_codegen_ops.py
- tests/ut/ir/transforms/test_init_memref.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/ut/ir/transforms/test_init_memref.py
- include/pypto/codegen/pto/pto_codegen.h
4ee9633 to 1df5a21
- Add PTO lowering for tile.scatter_update using scalar pto.tgetval / pto.tsetval loops
- Keep the in-place scatter result explicitly aliased to the input tile SSA via set_output_reuses_input(0)
- Avoid runtime treshape by directly using original 2D tile buffers
- Add UT coverage for PTO codegen
- Add ST runtime coverage with 5 test cases: basic FP32, FP16 dtype, duplicate indices, single batch (b=1), and shuffled indices
- Bump PTOAS_VERSION to v0.25 with matching SHA256

Fixes hw-native-sys#920
fb4eb0c to 4de9293
Summary
- Add PTO lowering for `tile.scatter_update` using scalar `pto.tgetval` / `pto.tsetval` loops
- Avoid runtime `treshape` in scatter_update lowering by directly using the original 2D tile buffers
- Add UT and ST coverage for `tile.scatter_update`

Testing
- cmake --build build --parallel
- export PYTHONPATH=$(pwd)/python:$PYTHONPATH && python -m pytest tests/ut/codegen/test_pto_codegen_ops.py -k scatter_update -v
- task-submit --device auto --run pytest tests/st/runtime/test_scatter_update.py -v --platform=a2a3 --save-kernels --device TASK_DEVICE

Note
`TLOAD -> SetValue` has been updated; `tile.scatter_update` now passes on `a2a3` hardware with the newer PTOAS export.

Fixes #920