feat: add tile scatter_update PTO lowering #1004

Open
Little-oil wants to merge 1 commit into hw-native-sys:main from Little-oil:codegen_for_tile_scatter_update
Conversation

@Little-oil
Contributor

@Little-oil Little-oil commented Apr 13, 2026

Summary

  • add PTO lowering support for tile.scatter_update using scalar pto.tgetval / pto.tsetval loops
  • keep the in-place scatter result explicitly aliased to the input tile SSA instead of relying on shared UB addresses
  • avoid runtime treshape in scatter_update lowering by directly using the original 2D tile buffers
  • add UT coverage for PTO codegen and ST runtime coverage for tile.scatter_update
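The scalar-loop lowering described above can be read against a small reference model. The following is a minimal sketch, assuming the 2D form discussed later in this thread (index of shape [b, s], src of shape [b*s, d], and a destination row selected by each index entry); `scatter_update_2d` is a hypothetical helper for illustration, not a function from this PR.

```python
def scatter_update_2d(dst, index, src):
    """Scalar-loop reference: dst[index[i][j]][k] = src[i * s + j][k], in place."""
    b, s = len(index), len(index[0])
    d = len(dst[0])
    for i in range(b):
        for j in range(s):
            row = index[i][j]      # destination row chosen by the index tile
            src_row = i * s + j    # src is assumed to be laid out as [b * s, d]
            for k in range(d):
                # One scalar read followed by one scalar write,
                # mirroring the tgetval / tsetval pair.
                dst[row][k] = src[src_row][k]
    return dst  # the same object as the input: the result aliases the input tile


dst = [[0, 0] for _ in range(4)]
index = [[3, 1]]
src = [[10, 11], [20, 21]]
out = scatter_update_2d(dst, index, src)
```

Returning the input object itself, rather than a copy, is the Python analogue of keeping the result aliased to the input tile SSA.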

Testing

  • cmake --build build --parallel
  • export PYTHONPATH=$(pwd)/python:$PYTHONPATH && python -m pytest tests/ut/codegen/test_pto_codegen_ops.py -k scatter_update -v
  • task-submit --device auto --run pytest tests/st/runtime/test_scatter_update.py -v --platform=a2a3 --save-kernels --device TASK_DEVICE

Note

  • PTOAS-side sync handling for TLOAD -> SetValue has been updated; tile.scatter_update now passes on a2a3 hardware with the newer PTOAS export.

Fixes #920

@coderabbitai

coderabbitai bot commented Apr 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds PTO codegen for tile.scatter_update, preserves tile metadata and in-place semantics in type deduction, refactors tile allocation sizing, exposes two PTOCodegen indent helpers, and adds unit tests covering lowering and allocation behavior.

Changes

Cohort / File(s) | Summary
PTOCodegen Header Utilities (include/pypto/codegen/pto/pto_codegen.h) | Added two inline public helpers: IncreaseIndent() and DecreaseIndent() which increment/decrement indent_level_.
Backend Tile Scatter Update Implementation (src/backend/common/pto_ops_common.cpp) | Added EmitFlatOffsetSSAFromValues() and MakeScatterUpdateCodegenPTO() to lower tile.scatter_update using pto.treshape, nested scf.for loops, pto.tgetval/pto.tsetval, and flat-offset arithmetic; registered handler in RegisterPTOOps.
IR Type Deduction & Op Registration (src/ir/op/tile_ops/transform.cpp) | Updated DeduceTileScatterUpdateType() to preserve TileView metadata and memory_space from the input; marked output 0 as reusing input 0 for in-place semantics.
Memory Allocation Refactoring (src/ir/transforms/init_memref.cpp) | Extracted GetTileAllocationSizeBytes(); CreateMemRef() now uses it to size tile allocations considering tile_view_->fractal (physical capacity).
Tests: Codegen & InitMemRef (tests/ut/codegen/test_pto_codegen_ops.py, tests/ut/ir/transforms/test_init_memref.py) | Added TestTileScatterUpdateCodegen suite (multiple assertions on generated MLIR: scf.for, pto.tgetval, pto.tsetval, index-cast, loop bound, no pto.tmov) and a test asserting Vec tile allocations use physical tile capacity.

Sequence Diagram

sequenceDiagram
    participant Client as Client Code
    participant Op as tile.scatter_update Op
    participant PTO as PTOCodegen
    participant MLIR as MLIR Builder

    Client->>Op: invoke tile.scatter_update(input, index, src, dim)
    Op->>PTO: MakeScatterUpdateCodegenPTO(...)
    PTO->>MLIR: pto.treshape(index) / pto.treshape(src) / pto.treshape(dest)
    PTO->>MLIR: scf.for i ...
    PTO->>MLIR: scf.for j ...
    PTO->>MLIR: EmitFlatOffsetSSAFromValues(...) → arith.muli/arith.addi
    PTO->>MLIR: pto.tgetval(src_linear, src_off) → load scalar
    PTO->>MLIR: pto.tsetval(dst_linear, dst_off, value) → in-place store
    PTO->>Op: return lowered IR (dest aliases input)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • lyfne123
  • Hzfengsy

Poem

🐰 I hop through loops and flatten the way,

tgetval to borrow, tsetval to stay,
No extra moves, just in-place delight,
Indents increased, then back out of sight,
A tiny rabbit cheers the codegen night. 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 64.71% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

Check name | Status | Explanation
Title check | ✅ Passed | The title 'feat: add tile scatter_update PTO lowering' directly describes the main change: adding PTO codegen support for tile.scatter_update, which is the primary objective of the PR.
Linked Issues check | ✅ Passed | The PR fully addresses issue #920 by adding a PTO codegen mapping for tile.scatter_update in pto_ops_common.cpp with the MakeScatterUpdateCodegenPTO function, enabling proper lowering to PTO instructions.
Out of Scope Changes check | ✅ Passed | The InitMemRef changes and helper method refactoring in init_memref.cpp are justified by the PR description as necessary to fix tile buffer sizing for PTO codegen; all changes are aligned with the stated PR objectives.
Description check | ✅ Passed | The pull request description clearly relates to the changeset, detailing PTO lowering support for tile.scatter_update with specific implementation notes about in-place aliasing and test coverage.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request implements the backend code generation for the tile.scatter_update operation, which performs in-place updates on input tiles. Key changes include the addition of indentation helpers in PTOCodegen, the implementation of MakeScatterUpdateCodegenPTO using nested loops and PTO-specific value accessors, and a refactoring of memory allocation logic in InitMemRef to account for physical fractal capacity. A potential issue was identified in the scatter update codegen where a rank mismatch between indices and shapes could trigger an internal check failure if the input tiles are not 2D; a suggestion was provided to use a consistent 2D shape representation during offset calculation.

Comment thread src/backend/common/pto_ops_common.cpp

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/ir/op/tile_ops/transform.cpp (1)

388-433: ⚠️ Potential issue | 🟠 Major

Validate src against the documented scatter-update shapes.

This deduce path currently accepts any same-rank src, but the lowering only works for [b*s, d] in 2D and [b, s, 1, d] in 4D. For example, input=[16,32], index=[2,4], src=[2,4] passes here, and the backend still iterates k over 32 columns, which reads past the logical src extent. Please enforce the b*s / [b,s,1,d] shape relations here before codegen.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ir/op/tile_ops/transform.cpp` around lines 388 - 433,
DeduceTileScatterUpdateType currently allows any same-rank src but lowering
requires specific shapes; update DeduceTileScatterUpdateType to validate
src_shape exactly: for 2D input (input_type->shape_.size()==2) require
index_type->shape_.size()==2 and enforce src_type->shape_.size()==2,
src_type->shape_[0] == index_type->shape_[0] * index_type->shape_[1] and
src_type->shape_[1] == input_type->shape_[1]; for 4D input require
src_type->shape_.size()==4 and enforce src_type->shape_[0] ==
index_type->shape_[0], src_type->shape_[1] == index_type->shape_[1],
src_type->shape_[2] == 1, and src_type->shape_[3] == input_type->shape_[3]; use
CHECK with clear error messages referencing input_type, index_type, and src_type
to fail when shapes mismatch.
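
The shape relations requested in this comment can be sketched as a standalone check. This is an illustrative Python model, not the C++ in DeduceTileScatterUpdateType; `check_scatter_update_shapes` is a hypothetical name, and plain lists stand in for the IR shape vectors.

```python
def check_scatter_update_shapes(input_shape, index_shape, src_shape):
    """Raise ValueError unless src matches the documented scatter-update forms."""
    if len(index_shape) < 2:
        raise ValueError(f"index must provide [b, s], got {index_shape}")
    b, s = index_shape[0], index_shape[1]

    if len(input_shape) == 2:
        if len(index_shape) != 2:
            raise ValueError(f"2D input requires a 2D index, got {index_shape}")
        expected = [b * s, input_shape[1]]        # src is [b*s, d]
    elif len(input_shape) == 4:
        expected = [b, s, 1, input_shape[3]]      # src is [b, s, 1, d]
    else:
        raise ValueError(f"input must be 2D or 4D, got {input_shape}")

    if list(src_shape) != expected:
        raise ValueError(
            f"src shape {list(src_shape)} does not match expected {expected} "
            f"for input {list(input_shape)} and index {list(index_shape)}"
        )


# The case from the comment: src=[2, 4] is rejected, src=[8, 32] is accepted.
check_scatter_update_shapes([16, 32], [2, 4], [8, 32])
```

With input=[16,32] and index=[2,4], the expected src shape is [8, 32], so the src=[2,4] example from the comment fails the check before any lowering runs.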
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/backend/common/pto_ops_common.cpp`:
- Around line 590-596: The code currently flattens offsets using the original
buffer shapes, causing EmitFlatOffsetSSAFromValues to see mismatched rank (e.g.,
2 vs 4) and abort; to fix, ensure offsets are computed against the reshaped 2D
views produced by make_linear_tile_type: after creating index_linear_type,
src_linear_type, and dst_linear_type and calling emit_treshape (index_linear,
src_linear, dst_linear), use the linear types' shape information when calling
EmitFlatOffsetSSAFromValues (or any offset-flattening helper) instead of
index_type->shape_, src_type->shape_, or input_type->shape_; update all places
that compute flat offsets (including around EmitFlatOffsetSSAFromValues and uses
of index_ssa/src_ssa/dst SSA) to reference
index_linear_type/src_linear_type/dst_linear_type so the indices.size() matches
shape.size().

In `@src/ir/transforms/init_memref.cpp`:
- Around line 70-90: Add a direct include of <algorithm> to this translation
unit to satisfy the use of std::max in GetTileAllocationSizeBytes; locate the
function GetTileAllocationSizeBytes in src/ir/transforms/init_memref.cpp and add
the header include for <algorithm> near the other includes so the call to
std::max<uint64_t>(...) is not relying on transitive includes.

In `@tests/ut/codegen/test_pto_codegen_ops.py`:
- Around line 1002-1025: The test test_tile_scatter_update_loop_bound currently
asserts "8" appears anywhere in the generated MLIR (mlir) which is brittle;
update the assertion to specifically check the scf.for loop upper bound (or the
SSA value used as the loop bound) so it can't be satisfied by the src_t shape.
In the test function (test_tile_scatter_update_loop_bound) locate the mlir
string and replace the loose assert with a targeted match — e.g., search for the
scf.for line pattern like "scf.for .* to 8" or extract the SSA constant that
feeds the loop bound and assert that its definition equals 8 — ensuring the
assertion references the scf.for loop bound or its loop-bound SSA symbol rather
than a raw "8".

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 59d5f13b-8be0-4257-ba97-f88aa3a60845

📥 Commits

Reviewing files that changed from the base of the PR and between 3e680d1 and c2a0ce3.

📒 Files selected for processing (6)
  • include/pypto/codegen/pto/pto_codegen.h
  • src/backend/common/pto_ops_common.cpp
  • src/ir/op/tile_ops/transform.cpp
  • src/ir/transforms/init_memref.cpp
  • tests/ut/codegen/test_pto_codegen_ops.py
  • tests/ut/ir/transforms/test_init_memref.py

Comment thread src/backend/common/pto_ops_common.cpp Outdated
Comment thread src/ir/transforms/init_memref.cpp Outdated
Comment thread tests/ut/codegen/test_pto_codegen_ops.py
@Little-oil Little-oil force-pushed the codegen_for_tile_scatter_update branch from c2a0ce3 to 0bb800a Compare April 13, 2026 11:40

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
src/backend/common/pto_ops_common.cpp (1)

632-634: ⚠️ Potential issue | 🔴 Critical

Flatten src_off/dst_off against synthetic [rows, d] shapes, not the original tile ranks.

The 4D form still aborts here: EmitFlatOffsetSSAFromValues(...) gets two indices (idx, k / row, k) but src_type->shape_ and input_type->shape_ are rank-4. Also, using the current [1, total_elements] reshape type would compute the wrong offset. This needs a [src_total / d_val, d_val] / [dst_total / d_val, d_val] shape (or direct idx * d + k arithmetic) before the 4D path can lower successfully.

Minimal fix sketch
-  std::string src_off = EmitFlatOffsetSSAFromValues({idx, k_var}, src_type->shape_, codegen, "src_off");
-
-  std::string dst_off = EmitFlatOffsetSSAFromValues({row_idx, k_var}, input_type->shape_, codegen, "dst_off");
+  const std::vector<ir::ExprPtr> src_linear_shape = {
+      std::make_shared<ir::ConstInt>(src_total / d_val, DataType::INT64, ir::Span::unknown()),
+      std::make_shared<ir::ConstInt>(d_val, DataType::INT64, ir::Span::unknown()),
+  };
+  const std::vector<ir::ExprPtr> dst_linear_shape = {
+      std::make_shared<ir::ConstInt>(dst_total / d_val, DataType::INT64, ir::Span::unknown()),
+      std::make_shared<ir::ConstInt>(d_val, DataType::INT64, ir::Span::unknown()),
+  };
+  std::string src_off =
+      EmitFlatOffsetSSAFromValues({idx, k_var}, src_linear_shape, codegen, "src_off");
+  std::string dst_off =
+      EmitFlatOffsetSSAFromValues({row_idx, k_var}, dst_linear_shape, codegen, "dst_off");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/backend/common/pto_ops_common.cpp` around lines 632 - 634, src_off and
dst_off are computed with EmitFlatOffsetSSAFromValues using the original rank-4
shapes (src_type->shape_ and input_type->shape_), but the indices provided (idx,
k_var and row_idx, k_var) refer to a synthetic [rows, d] layout; change the
flattening to use a 2D shape [total_rows, d_val] (i.e. compute src_total/d_val
and dst_total/d_val) or replace the call with explicit arithmetic idx * d_val +
k_var so EmitFlatOffsetSSAFromValues receives the correct 2D shape; update the
calls that produce src_off and dst_off (the two lines calling
EmitFlatOffsetSSAFromValues) to either pass the computed [rows, d_val] shape or
compute offsets as idx * d_val + k_var/row_idx * d_val + k_var respectively.
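
The offset arithmetic at issue can be checked with a small model. This is a sketch in Python that assumes EmitFlatOffsetSSAFromValues performs ordinary row-major flattening and requires one index per dimension; `flat_offset` and the example shape [2, 4, 1, 8] are illustrative stand-ins, not code or data from the PR.

```python
def flat_offset(indices, shape):
    """Row-major flattening; requires exactly one index per dimension."""
    if len(indices) != len(shape):
        raise ValueError(f"rank mismatch: {len(indices)} indices vs shape {shape}")
    off = 0
    for idx, extent in zip(indices, shape):
        off = off * extent + idx
    return off


# Hypothetical 4D src tile [b, s, 1, d] = [2, 4, 1, 8].
src_shape = [2, 4, 1, 8]
src_total = 2 * 4 * 1 * 8           # 64 elements
d_val = src_shape[-1]               # 8
idx, k = 5, 3                       # row index and column index

# Synthetic [rows, d] shape: equivalent to idx * d_val + k.
good = flat_offset([idx, k], [src_total // d_val, d_val])

# The [1, total_elements] reshape type: the offset falls outside the tile.
bad = flat_offset([idx, k], [1, src_total])
```

Here `good` is 43 (that is, 5 * 8 + 3), while `bad` is 323, which is past the 64 elements of the tile; passing the rank-4 shape with two indices raises instead of returning an offset.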
tests/ut/codegen/test_pto_codegen_ops.py (1)

1023-1025: ⚠️ Potential issue | 🟡 Minor

Assert the actual scf.for bounds instead of a raw "8".

This doesn't verify the lowering. The codegen emits nested loops with bounds 2 and 4, and "8" is already present in this fixture via src_t's shape, so the test can pass even if the loop bounds regress.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/codegen/test_pto_codegen_ops.py` around lines 1023 - 1025, The test
currently checks for a raw "8" in the MLIR which can be present for other
reasons; instead assert the actual scf.for loop bounds emitted by the lowering.
Replace the assert "8" check with assertions that the generated MLIR (from
self._generate_mlir(Prog)) contains scf.for loop headers with the expected upper
bounds (e.g. occurrences of "scf.for" lines containing "to 2" and "to 4" or the
equivalent "to <value>" text for b and s), so the test verifies the nested loop
bounds (2 and 4) rather than a stray "8".
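
One way to target the loop bounds rather than a stray "8" is to parse the scf.for headers and resolve their upper-bound operands. The sketch below is an assumption-heavy illustration: `scf_for_upper_bounds` is a hypothetical helper, and `SAMPLE_MLIR` is hand-written text in generic scf/arith syntax, not output captured from this PR's codegen.

```python
import re


def scf_for_upper_bounds(mlir):
    """Return the upper bound of each scf.for, resolving arith.constant operands."""
    consts = {
        m.group(1): int(m.group(2))
        for m in re.finditer(
            r"(%[\w.]+)\s*=\s*arith\.constant\s+(-?\d+)\s*:\s*index", mlir
        )
    }
    bounds = []
    for m in re.finditer(
        r"scf\.for\s+%[\w.]+\s*=\s*\S+\s+to\s+(\S+)\s+step", mlir
    ):
        token = m.group(1)
        # Operands that are neither known constants nor literals raise ValueError.
        bounds.append(consts[token] if token in consts else int(token))
    return bounds


# Hand-written sample; an unused constant 8 is present on purpose.
SAMPLE_MLIR = """
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
scf.for %i = %c0 to %c2 step %c1 {
  scf.for %j = %c0 to %c4 step %c1 {
  }
}
"""

bounds = scf_for_upper_bounds(SAMPLE_MLIR)
```

An assertion such as `bounds == [2, 4]` then fails if the loop bounds regress, even though the text "8" still appears in the module.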
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/ir/op/tile_ops/transform.cpp`:
- Around line 425-433: The code currently overwrites any existing logical extent
by setting tile_view.valid_shape = input_type->shape_, which loses partial-tile
information and causes downstream tile.store to treat padded elements as valid;
change the logic in the tile view construction (the tile_view /
input_type->tile_view_ handling before constructing TileType) to preserve an
existing valid_shape when input_type->tile_view_ is present and only default to
input_type->shape_ when no valid_shape was previously set, then pass that
preserved/defaulted tile_view into the TileType constructor.
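
The preserve-or-default rule described here can be modelled in a few lines. This is a Python sketch only: the `TileView` dataclass and `output_tile_view` are hypothetical stand-ins for the C++ types, and `None` is assumed to represent "no valid_shape previously set".

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class TileView:
    # None models "no logical extent recorded" (an assumption of this sketch).
    valid_shape: Optional[Tuple[int, ...]] = None


def output_tile_view(input_shape, input_view):
    """Keep an existing valid_shape; fall back to the full shape otherwise."""
    if input_view is not None and input_view.valid_shape is not None:
        return input_view  # partial-tile extent is preserved unchanged
    base = input_view if input_view is not None else TileView()
    return replace(base, valid_shape=tuple(input_shape))


partial = output_tile_view((16, 32), TileView(valid_shape=(10, 32)))
defaulted = output_tile_view((16, 32), None)
```

In this model a tile with only 10 valid rows keeps valid_shape (10, 32), so a later store would not treat the 6 padded rows as valid; a tile with no view defaults to the full (16, 32).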

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 11ae86ec-cb1d-4ecb-afb2-bd64c0f5bda1

📥 Commits

Reviewing files that changed from the base of the PR and between c2a0ce3 and 0bb800a.

📒 Files selected for processing (6)
  • include/pypto/codegen/pto/pto_codegen.h
  • src/backend/common/pto_ops_common.cpp
  • src/ir/op/tile_ops/transform.cpp
  • src/ir/transforms/init_memref.cpp
  • tests/ut/codegen/test_pto_codegen_ops.py
  • tests/ut/ir/transforms/test_init_memref.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/ut/ir/transforms/test_init_memref.py
  • include/pypto/codegen/pto/pto_codegen.h

Comment thread src/ir/op/tile_ops/transform.cpp
@Little-oil Little-oil force-pushed the codegen_for_tile_scatter_update branch 8 times, most recently from 4ee9633 to 1df5a21 Compare April 15, 2026 07:55
- Add PTO lowering for tile.scatter_update using scalar pto.tgetval /
  pto.tsetval loops
- Keep the in-place scatter result explicitly aliased to the input tile
  SSA via set_output_reuses_input(0)
- Avoid runtime treshape by directly using original 2D tile buffers
- Add UT coverage for PTO codegen
- Add ST runtime coverage with 5 test cases: basic FP32, FP16 dtype,
  duplicate indices, single batch (b=1), and shuffled indices
- Bump PTOAS_VERSION to v0.25 with matching SHA256

Fixes hw-native-sys#920
@Little-oil Little-oil force-pushed the codegen_for_tile_scatter_update branch from fb4eb0c to 4de9293 Compare April 15, 2026 08:22
Development

Successfully merging this pull request may close these issues.

[Bug] tile.scatter_update missing codegen mapping to PTO instruction

1 participant