feat(spmd): add SPMD launch support with block intrinsics and system tests #992

Closed
lyfne123 wants to merge 3 commits into hw-native-sys:main from lyfne123:main

Conversation

@lyfne123
Collaborator

Add pl.spmd_launch() DSL function for multi-block SPMD dispatch with core_num and sync_start parameters. Implement tile.get_block_idx() and tile.get_block_num() intrinsics for kernels to query block identity. Key changes:

  • Parser: parse spmd_launch() calls with core_num/sync_start kwargs
  • Orchestration codegen: emit launch_spec.set_core_num/set_require_sync_start
  • Kernel codegen: bridge get_block_idx/get_block_num to PTO dialect, add arith.index_cast for i64→index offset conversion in partition_view
  • Backend: generate block context bridge in kernel wrappers
  • Fix expand_mixed_kernel pass dropping SPMD kwargs when injecting GM pipe buffer arguments
  • Refactor: move SPMD intrinsics from memory.cpp to new spmd.cpp
  • Add system tests with multi-launch pattern (core_num=4,16,24,48)
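
To make the block-identity intrinsics concrete, here is a minimal pure-Python model of SPMD dispatch. It does not use the pypto API: `simulate_spmd_launch`, `add_one_kernel`, and the chunked partitioning scheme are illustrative assumptions that stand in for a launch with `core_num` and for kernels querying `tile.get_block_idx()` / `tile.get_block_num()`.

```python
from typing import Callable, List


def simulate_spmd_launch(kernel: Callable[..., None], *args, core_num: int) -> None:
    """Run `kernel` once per block, passing block identity explicitly.

    Hypothetical stand-in for an SPMD launch: real hardware would run the
    blocks in parallel, this model runs them one after another.
    """
    if core_num <= 0:
        raise ValueError("core_num must be positive")
    for block_idx in range(core_num):
        kernel(*args, block_idx=block_idx, block_num=core_num)


def add_one_kernel(src: List[int], dst: List[int], *, block_idx: int, block_num: int) -> None:
    """Each block handles one contiguous chunk of the input."""
    chunk = (len(src) + block_num - 1) // block_num  # ceiling division
    start = block_idx * chunk
    end = min(start + chunk, len(src))
    for i in range(start, end):  # empty range when start >= end
        dst[i] = src[i] + 1


src = list(range(10))
dst = [0] * len(src)
simulate_spmd_launch(add_one_kernel, src, dst, core_num=4)
```

Every block runs the same kernel body and only its index differs, which is why the kernel needs both the block index and the block count to compute its own slice.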

@coderabbitai

coderabbitai bot commented Apr 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds SPMD support: registers tile.get_block_idx, tile.get_block_num, and tile.get_subblock_idx as INT64 IR intrinsics in a new spmd.cpp; adds a pl.spmd_launch DSL API with parser support; conditionally injects a kernel bridge when those ops are used; emits the launch spec in orchestration codegen; registers the backend PTO ops; and adds tests, docs, and build inclusion of spmd.cpp.

Changes

| Cohort / File(s) | Summary |
|---|---|
| IR SPMD ops: `src/ir/op/tile_ops/spmd.cpp`, `src/ir/op/tile_ops/memory.cpp` | New spmd.cpp registers tile.get_block_idx/get_block_num/get_subblock_idx returning INT64; removed prior registrations from memory.cpp. |
| Build: `CMakeLists.txt` | Added src/ir/op/tile_ops/spmd.cpp to PTO sources. |
| DSL & Language: `python/pypto/language/op/system_ops.py`, `python/pypto/language/__init__.py`, `python/pypto/language/op/tile_ops.py` | Added spmd_launch DSL API (exports in `__all__`); updated get_block_idx docs to INT64 and added get_block_num wrapper. |
| Parser: `python/pypto/language/parser/ast_parser.py` | Parse pl.spmd_launch(...) with required core_num and optional sync_start, validate callable/args, and produce IR Call with spmd kwargs. |
| IR Python bindings & debug: `python/pypto/ir/op/tile_ops.py`, `python/pypto/debug/torch_codegen.py` | Added IR binding get_block_num; changed get_block_idx docs to INT64; debug codegen emits "1" for get_block_num. |
| Backend kernel bridge: `python/pypto/backend/pto_backend.py` | Scan IR for tile SPMD ops; if used, set needs_block_ctx to inject block-context header, runtime [[block_local]] variables, get_blockidx/get_blocknum accessors, and wrapper prologue to populate them. |
| Backend PTO ops & offsets: `src/backend/common/pto_ops_common.cpp` | Added EmitOffsetAsIndex for offset casting; registered backend PTO ops tile.get_block_idx and tile.get_block_num emitting pto.get_block_idx/get_block_num. |
| Orchestration codegen & transforms: `src/codegen/orchestration/orchestration_codegen.cpp`, `src/ir/transforms/expand_mixed_kernel_pass.cpp` | Emit SPMD launch-spec (set_core_num, set_block_num, set_require_sync_start) for task submissions; preserve kwargs_ when rewriting calls. |
| Tests: `tests/st/runtime/test_spmd.py`, `tests/ut/codegen/test_orchestration_codegen.py`, `tests/ut/ir/operators/test_op_registry.py`, `tests/ut/ir/transforms/test_flatten_call_expr_pass.py` | Added SPMD system tests (single/multi-block), orchestration codegen unit tests, registry update, and updated transform expectations for INT64 block index. |
| Docs: `docs/en/.../05-operators.md`, `docs/en/.../02-operation_reference.md`, `docs/zh-cn/...`, `src/ir/op/README.md` | Updated docs to reflect INT64 return type for get_block_idx and document new get_block_num and tile_ops/spmd layout. |

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Parser as AST Parser
    participant IR as IR Generator
    participant Orch as Orchestration Codegen
    participant Backend as PTO Backend
    participant KernelWrap as Kernel Wrapper
    participant Kernel as PTO Kernel

    User->>Parser: pl.spmd_launch(kernel_fn, ..., core_num=N, sync_start=K)
    Parser->>IR: Build Call(spmd_launch, kernel_ref, args, kwargs)
    IR->>Orch: Emit task + attach spmd kwargs
    Orch->>Backend: Generate kernel & wrapper
    Backend->>Backend: Scan IR for tile.get_block_idx/get_block_num
    Backend->>KernelWrap: If needed, inject block_context_override & runtime vars
    KernelWrap->>Kernel: Populate runtime [[block_local]] from launch args
    Kernel->>Kernel: Calls to get_block_idx/get_block_num read runtime vars
    Kernel-->>User: Return results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • PR #951: Similar runtime-bridge/wrapper injection work in python/pypto/backend/pto_backend.py for SPMD/subblock context.
  • PR #887: Related changes to PTO backend operator registration and codegen in pto_ops_common.cpp.
  • PR #501: Overlaps operator registration/type-deduction work for tile ops in the IR registry.

Suggested reviewers

  • Hzfengsy

"🐰
I hopped through ops and CMake lands,
Brought block counts to tiny hands.
INT64 now sings its codey song,
Spmd launches hop along—so strong! ✨"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 49.33% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and concisely summarizes the main change: adding SPMD launch support with block intrinsics and system tests. |
| Description check | ✅ Passed | The description is directly related to the changeset, detailing the key changes across parser, orchestration codegen, kernel codegen, backend, bug fixes, and refactoring. |

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for SPMD (Single Program Multiple Data) kernel launches, including new tile.get_block_idx and tile.get_block_num operations and a pl.spmd_launch orchestration construct. The implementation spans the DSL parser, IR, and both PTO and orchestration backends, with get_block_idx being updated from UINT64 to INT64 for consistency. Feedback focuses on improving the efficiency of IR traversal when checking for SPMD operations, reducing redundant checks in the backend, and allowing more flexible constant expressions for SPMD launch parameters.

Comment thread python/pypto/backend/pto_backend.py Outdated
Comment thread python/pypto/backend/pto_backend.py Outdated
Comment thread python/pypto/language/parser/ast_parser.py Outdated

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 10

🧹 Nitpick comments (3)
python/pypto/debug/torch_codegen.py (1)

297-297: Document the single-block assumption here.

This hardcodes tile.get_block_num() to 1, which is reasonable for the same single-block debug model that maps tile.get_block_idx() to 0 right above. A short comment would make it explicit that torch debug codegen does not emulate multi-block SPMD execution.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/pypto/debug/torch_codegen.py` at line 297, Add a short explanatory
comment next to the mapping m["tile.get_block_num"] = lambda _a, _kw: "1"
stating that this hardcodes the debug model to a single block (matching the
earlier m["tile.get_block_idx"] -> "0" mapping) and that the torch debug codegen
intentionally does not emulate multi-block SPMD execution; reference the
m["tile.get_block_num"] mapping and the tile.get_block_idx mapping when
inserting the comment so future readers understand the single-block assumption.
tests/ut/codegen/test_orchestration_codegen.py (1)

2064-2065: Narrow the negative assertion to a specific exception type.

pytest.raises(Exception, ...) can pass on unrelated failures. Please assert the concrete parser error type for missing core_num to keep this regression test strict.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ut/codegen/test_orchestration_codegen.py` around lines 2064 - 2065,
Replace the broad pytest.raises(Exception, ...) in the test that asserts missing
core_num with the concrete parser exception your parser actually raises (import
that class and use it instead of Exception); locate the pytest.raises call in
tests/ut/codegen/test_orchestration_codegen.py and change Exception to the
specific exception class (e.g., ParserError / ConfigParseError /
OrchestrationParseError) exported by the parser implementation so the test only
catches the intended parse error for missing core_num.
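
A small sketch of why the broad assertion is risky, using only the standard library. `ParserSyntaxError` here is a local stand-in (the name appears in the parser excerpt later in this review, but which class the test should import is an assumption), and `parse_launch`, `buggy_parse_launch`, and `raises` are hypothetical helpers.

```python
class ParserSyntaxError(Exception):
    """Hypothetical stand-in for the parser's own error type."""


def parse_launch(kwargs: dict) -> dict:
    """Toy validator: requires a 'core_num' entry."""
    if "core_num" not in kwargs:
        raise ParserSyntaxError("spmd_launch requires 'core_num' keyword argument")
    return kwargs


def buggy_parse_launch(kwargs: dict) -> dict:
    # Simulated unrelated bug: dict has no such method, so this raises AttributeError.
    return kwargs.missing_method()


def raises(exc_type, fn, *args) -> bool:
    """Return True if fn(*args) raises exc_type; other exceptions propagate."""
    try:
        fn(*args)
    except exc_type:
        return True
    return False


# A broad check passes even though the failure is an unrelated bug.
broad_ok = raises(Exception, buggy_parse_launch, {})
# A narrow check only accepts the intended parser error.
narrow_ok = raises(ParserSyntaxError, parse_launch, {})
```

With the narrow type, running the check against `buggy_parse_launch` lets the `AttributeError` surface instead of silently counting as a pass.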
tests/st/runtime/test_spmd.py (1)

93-98: Consider adding explicit runtime coverage for pl.tile.get_block_num().

This test validates get_block_idx() well, but get_block_num() (also added in this PR scope) is not exercised. A small dedicated kernel/test path using get_block_num() would reduce regression risk for that intrinsic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/st/runtime/test_spmd.py` around lines 93 - 98, Add a small SPMD
test/kernel that explicitly calls pl.tile.get_block_num() to exercise the new
intrinsic: create a variant of the current test (or a new test function) that
queries pl.tile.get_block_num() inside the kernel (e.g., store or compute a
value based on get_block_num() and write it to the output buffer) and assert the
expected block count on the host side; update the test name to reflect coverage
(e.g., test_spmd_block_num) and reuse the existing TILE_H/TILE_W setup so it
runs alongside the get_block_idx() test.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@docs/en/dev/ir/05-operators.md`:
- Line 386: Update the table entry that lists the SPMD intrinsics so the
source-file reference points to the new registration location: change the file
for tile.get_block_idx and tile.get_block_num from tile_ops/memory.cpp to
src/ir/op/tile_ops/spmd.cpp; locate the row containing `tile_ops/memory.cpp` and
the symbols `tile.get_block_idx` / `tile.get_block_num` and update the file name
to the new path so future edits route to the correct implementation.

In `@docs/zh-cn/dev/ir/05-operators.md`:
- Around line 386-387: Update the docs table to reflect the new ownership of
SPMD intrinsics: change the entry that currently lists get_block_idx and
get_block_num under `tile_ops/memory.cpp` so those symbols (`get_block_idx`,
`get_block_num`) point to the new SPMD source file (the SPMD intrinsics file
added in this PR) instead of `tile_ops/memory.cpp`; ensure the table row for
`tile_ops/memory.cpp` no longer mentions those functions and that the new SPMD
file row lists them so contributors are not misled.

In `@docs/zh-cn/user/02-operation_reference.md`:
- Around line 73-74: The docstring for the Python public wrapper in
python/pypto/ir/op/tile_ops.py currently claims get_block_idx() returns UINT64
which contradicts the user docs that now state it returns INT64; update the
wrapper docstring for get_block_idx (and get_block_num if present) to state the
return type is INT64 (or Scalar(INT64) as used elsewhere), ensure any examples
or type annotations in the function/method docstring mention INT64, and keep the
wording consistent with the docs/zh-cn/user/02-operation_reference.md
description.

In `@python/pypto/backend/pto_backend.py`:
- Around line 385-387: The fix is to make the block-context bridge decision at
the translation-unit / shared ptoas_cpp level instead of per-function: update
_emit_group_output to compute a single boolean (using
_needs_block_context_bridge on every group member or directly scanning the
group's functions) and emit the bridge into the shared ptoas_cpp body whenever
any sibling uses tile.get_block_idx/get_block_num; remove relying on the
per-function _needs_block_context_bridge to skip emission for non-SPMD wrappers
so that static helpers referencing get_blockidx()/get_blocknum() are always
available when the shared ptoas_cpp is used.

In `@python/pypto/language/parser/ast_parser.py`:
- Around line 2521-2523: The shorthand matcher currently triggers for any
attribute chain starting with pl.spmd_launch (e.g., pl.spmd_launch.foo(...));
update the condition in ast_parser.py where attrs is inspected so it only
matches the exact two-segment form pl.spmd_launch by requiring len(attrs) == 2
(instead of >= 2) and attrs[0] == "pl" and attrs[1] == "spmd_launch", then call
self._parse_spmd_launch(call) only in that case to avoid routing longer chains
into _parse_spmd_launch.

In `@src/backend/common/pto_ops_common.cpp`:
- Around line 503-520: The generated MLIR sometimes passes non-index integer SSA
values where index is declared; reuse the existing EmitOffsetAsIndex helper
(which handles constants and emits arith.index_cast) for all index operands
instead of directly using GetExprAsCode; specifically update
MakeTileAssembleCodegenPTO, the tile.slice lowering, and any places emitting
pto.tinsert / pto.textract operands (e.g. row/col values derived from
tile.get_block_idx()) to call EmitOffsetAsIndex(expr, codegen) and use its
return value for index-typed operands so the emitted MLIR always uses index or
an index_casted value.

In `@src/codegen/orchestration/orchestration_codegen.cpp`:
- Around line 535-539: In EmitSpmdLaunchSpec, validate the retrieved core_num
from call->GetKwarg<int>("core_num") before emitting to
task_var.launch_spec.set_core_num(...); specifically, after
call->HasKwarg("core_num") read core_num, check that it is > 0 (and optionally
within any maximum if required) and only emit the set_core_num line when valid;
for invalid (<= 0) either skip emitting and emit a diagnostic/log via the
existing logging mechanism or assert/throw a clear error so an invalid launch
configuration is not generated.

In `@src/ir/op/README.md`:
- Line 14: Update the README mapping so the SPMD intrinsics are shown under
spmd.cpp instead of memory.cpp: replace or move the entries listing
get_block_idx and get_block_num from the current tile_ops/memory.cpp line to the
dedicated spmd.cpp entry and ensure the README reflects the actual source layout
(reference symbols: get_block_idx, get_block_num, spmd.cpp, memory.cpp).

In `@src/ir/op/tile_ops/spmd.cpp`:
- Around line 71-99: The new registrations for "tile.get_block_idx" and
"tile.get_subblock_idx" duplicate earlier REGISTER_OP entries; remove the
earlier duplicate REGISTER_OP calls that register "tile.get_block_idx" and
"tile.get_subblock_idx" (leave the single canonical registrations here), or
alternatively keep the original ones and remove the duplicates from this diff;
ensure only one REGISTER_OP per op name and preserve the associated deduce-type
lambdas (e.g., DeduceTileGetBlockIdxType, DeduceTileGetSubblockIdxType) and that
"tile.get_block_num" remains registered exactly once; run a build to confirm no
startup/registration conflicts.

In `@src/ir/transforms/expand_mixed_kernel_pass.cpp`:
- Around line 1112-1114: The rewrite paths rebuild Call nodes without preserving
kwargs, causing launch-config fields like core_num/sync_start to be dropped;
update all Call reconstruction sites (including RewriteCallsForGMBuffer() and
RewriteGroupCaller() as well as the other make_shared<Call> usages) to pass
call->kwargs_ into the std::make_shared<Call>(...) constructor (i.e., replace
constructors that use only op_ + args_ with the variant that includes
call->kwargs_), ensuring every call-rewrite preserves the original kwargs.
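
For the ast_parser.py comment above, the exact two-segment match can be sketched with the standard `ast` module. This is an illustration only: `attr_chain`, `is_spmd_launch_call`, and `first_call` are hypothetical helpers, not the parser's actual functions, and the sketch covers only the `pl.spmd_launch` shorthand.

```python
import ast
from typing import List, Optional


def attr_chain(node: ast.expr) -> Optional[List[str]]:
    """Flatten `a.b.c` into ["a", "b", "c"]; return None for other shapes."""
    parts: List[str] = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None
    parts.append(node.id)
    return parts[::-1]


def is_spmd_launch_call(call: ast.Call) -> bool:
    """Match only the exact two-segment form pl.spmd_launch(...)."""
    return attr_chain(call.func) == ["pl", "spmd_launch"]


def first_call(src: str) -> ast.Call:
    """Parse a single call expression from source text."""
    node = ast.parse(src, mode="eval").body
    assert isinstance(node, ast.Call)
    return node
```

Comparing the whole chain for equality is what rules out longer forms such as `pl.spmd_launch.foo(...)`, which a prefix check (`len(attrs) >= 2`) would let through.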

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 12546818-ae47-4c7c-a6f0-067e172a4d86

📥 Commits

Reviewing files that changed from the base of the PR and between aa8ea46 and 942da7a.

📒 Files selected for processing (22)
  • CMakeLists.txt
  • docs/en/dev/ir/05-operators.md
  • docs/en/user/02-operation_reference.md
  • docs/zh-cn/dev/ir/05-operators.md
  • docs/zh-cn/user/02-operation_reference.md
  • python/pypto/backend/pto_backend.py
  • python/pypto/debug/torch_codegen.py
  • python/pypto/ir/op/tile_ops.py
  • python/pypto/language/__init__.py
  • python/pypto/language/op/system_ops.py
  • python/pypto/language/op/tile_ops.py
  • python/pypto/language/parser/ast_parser.py
  • src/backend/common/pto_ops_common.cpp
  • src/codegen/orchestration/orchestration_codegen.cpp
  • src/ir/op/README.md
  • src/ir/op/tile_ops/memory.cpp
  • src/ir/op/tile_ops/spmd.cpp
  • src/ir/transforms/expand_mixed_kernel_pass.cpp
  • tests/st/runtime/test_spmd.py
  • tests/ut/codegen/test_orchestration_codegen.py
  • tests/ut/ir/operators/test_op_registry.py
  • tests/ut/ir/transforms/test_flatten_call_expr_pass.py

Comment thread docs/en/dev/ir/05-operators.md Outdated
Comment thread docs/zh-cn/dev/ir/05-operators.md Outdated
Comment thread docs/zh-cn/user/02-operation_reference.md
Comment thread python/pypto/backend/pto_backend.py Outdated
Comment thread python/pypto/language/parser/ast_parser.py
Comment thread src/backend/common/pto_ops_common.cpp Outdated
Comment thread src/codegen/orchestration/orchestration_codegen.cpp
Comment thread src/ir/op/README.md Outdated
Comment thread src/ir/op/tile_ops/spmd.cpp
Comment thread src/ir/transforms/expand_mixed_kernel_pass.cpp
- Orchestration codegen: emit set_block_num for 910B, set_core_num for
  950, with core_num > 0 validation
- Fix clang-tidy: replace error.h with logging.h in spmd.cpp
- Preserve kwargs in all Call rewrite paths (RewriteCallsForGMBuffer,
  RewriteGroupCaller) to prevent SPMD config loss
- Refactor _uses_op into single-pass _collect_used_ops; eliminate
  redundant _needs_block_context_bridge calls via precomputed flag
- Scope block context bridge to group translation unit so all members
  get the bridge when any sibling uses SPMD intrinsics
- Apply EmitOffsetAsIndex to tile.assemble and tile.slice for correct
  i64→index casting with SPMD-derived offsets
- Narrow spmd_launch parser matcher to exact 2-segment form
- Use try_eval_expr for spmd_launch kwargs to support closure variables
- Update docs to reflect SPMD ops moved to spmd.cpp
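
The kwargs-preservation item in this commit can be illustrated with a toy call node. This is a sketch under assumptions: the `Call` dataclass and both rewrite helpers are hypothetical and only mirror the shape of the issue (a rewrite that rebuilds a call from its op and args and forgets the keyword attributes).

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class Call:
    """Toy immutable call node: op name, positional args, keyword attributes."""
    op: str
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any] = field(default_factory=dict)


def append_arg_dropping_kwargs(call: Call, extra: Any) -> Call:
    # Buggy rewrite: rebuilds the node from op + args only.
    return Call(call.op, call.args + (extra,))


def append_arg(call: Call, extra: Any) -> Call:
    # Correct rewrite: carries the original kwargs over to the new node.
    return Call(call.op, call.args + (extra,), dict(call.kwargs))


launch = Call("kernel", ("a", "b"), {"core_num": 4, "sync_start": True})
lost = append_arg_dropping_kwargs(launch, "gm_pipe_buf")
kept = append_arg(launch, "gm_pipe_buf")
```

In the buggy variant the launch configuration silently disappears, so nothing downstream can tell that the call was ever an SPMD launch.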

Made-with: Cursor
Tests use BackendType.Ascend910B, so the orchestration codegen now
correctly emits set_block_num instead of set_core_num. Update assertions
to match.

Made-with: Cursor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@python/pypto/backend/pto_backend.py`:
- Around line 368-389: The current _collect_used_ops only inspects top-level
Call objects on EvalStmt/AssignStmt roots and misses Calls nested inside
expressions; update _collect_used_ops to recursively traverse expression trees
(visiting _ir_core.Call nodes within args, operands, left/right, etc.) for a
given _ir_core.Function (still using _ir_core.flatten_to_stmts), so that any
nested uses of _ir_core.Op (e.g., "tile.get_block_idx"/"tile.get_block_num") are
detected; this will make _needs_block_context_bridge correctly return true when
those builtins appear anywhere inside expressions.

In `@python/pypto/language/parser/ast_parser.py`:
- Around line 2963-2985: The parser currently only checks that spmd_launch
kwargs are compile-time constants; update the loop in ast_parser.py (where
spmd_kwargs, valid_kwargs and self.expr_evaluator.try_eval_expr are used) to
validate types before storing: after success, assert that for kw.arg ==
"core_num" the evaluated val is an int (and positive if desired) and for kw.arg
== "sync_start" the val is a bool; if the type is wrong raise a
ParserSyntaxError with a clear message/hint (similar style as existing errors)
instead of inserting the invalid value into spmd_kwargs so downstream code
expecting GetKwarg<int>/GetKwarg<bool> fails at parse time.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 24ac4c6f-20fd-405c-9137-e63b2dd8efa7

📥 Commits

Reviewing files that changed from the base of the PR and between 942da7a and 12c669b.

📒 Files selected for processing (9)
  • docs/en/dev/ir/05-operators.md
  • docs/zh-cn/dev/ir/05-operators.md
  • python/pypto/backend/pto_backend.py
  • python/pypto/language/parser/ast_parser.py
  • src/backend/common/pto_ops_common.cpp
  • src/codegen/orchestration/orchestration_codegen.cpp
  • src/ir/op/README.md
  • src/ir/op/tile_ops/spmd.cpp
  • src/ir/transforms/expand_mixed_kernel_pass.cpp
✅ Files skipped from review due to trivial changes (2)
  • src/ir/op/README.md
  • src/ir/transforms/expand_mixed_kernel_pass.cpp
🚧 Files skipped from review as they are similar to previous changes (3)
  • docs/en/dev/ir/05-operators.md
  • docs/zh-cn/dev/ir/05-operators.md
  • src/ir/op/tile_ops/spmd.cpp

Comment on lines +368 to +389
def _collect_used_ops(func: _ir_core.Function) -> set[str]:
    """Return the set of op names used in the function body (single pass)."""
    used: set[str] = set()
    stmts = _ir_core.flatten_to_stmts(func.body)
    for stmt in stmts:
        call = None
        if isinstance(stmt, _ir_core.EvalStmt):
            call = stmt.expr
        elif isinstance(stmt, _ir_core.AssignStmt):
            call = stmt.value
        if not isinstance(call, _ir_core.Call):
            continue
        op = getattr(call, "op", None)
        if isinstance(op, _ir_core.Op):
            used.add(op.name)
    return used


def _needs_block_context_bridge(func: _ir_core.Function) -> bool:
    """Return whether the kernel needs runtime block context bridge (get_block_idx/get_block_num)."""
    ops = _collect_used_ops(func)
    return "tile.get_block_idx" in ops or "tile.get_block_num" in ops

⚠️ Potential issue | 🟠 Major

Traverse nested expressions when collecting used ops.

This helper only records statement-root Calls. If tile.get_block_idx() / tile.get_block_num() appears inside arithmetic, a comparison, or as an argument to another op, _needs_block_context_bridge() returns false and the wrapper skips the runtime bridge even though the generated kernel still references block-context builtins.
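
The difference between a statement-root scan and a recursive scan can be shown on a toy expression tree. The node classes below (`Const`, `Call`, `BinOp`) and `collect_ops` are hypothetical and do not reflect the real `_ir_core` API; the sketch only demonstrates the traversal idea.

```python
from dataclasses import dataclass
from typing import Set, Tuple, Union


@dataclass
class Const:
    value: int


@dataclass
class Call:
    op: str
    args: Tuple["Expr", ...] = ()


@dataclass
class BinOp:
    kind: str
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Call, BinOp]


def collect_ops(expr: Expr, used: Set[str]) -> None:
    """Record every Call op name reachable from `expr`, at any depth."""
    if isinstance(expr, Call):
        used.add(expr.op)
        for arg in expr.args:
            collect_ops(arg, used)
    elif isinstance(expr, BinOp):
        collect_ops(expr.lhs, used)
        collect_ops(expr.rhs, used)
    # Const has no children, so there is nothing to visit.


# tile.get_block_idx() * 16, passed as an argument to another op
stmt_value = Call("tile.load", (BinOp("mul", Call("tile.get_block_idx"), Const(16)),))

root_only = {stmt_value.op}      # what a statement-root scan records
nested: Set[str] = set()
collect_ops(stmt_value, nested)  # what a recursive scan records
```

Here the root-only set misses `tile.get_block_idx`, which is the case the comment warns about.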

Comment on lines +2963 to +2985
        spmd_kwargs: dict[str, Any] = {}
        valid_kwargs = {"core_num", "sync_start"}
        for kw in call.keywords:
            if kw.arg not in valid_kwargs:
                raise ParserSyntaxError(
                    f"Unknown spmd_launch keyword argument: '{kw.arg}'",
                    span=span,
                    hint=f"Valid keyword arguments: {sorted(valid_kwargs)}",
                )
            success, val = self.expr_evaluator.try_eval_expr(kw.value)
            if not success:
                raise ParserSyntaxError(
                    f"spmd_launch keyword '{kw.arg}' must be a compile-time constant",
                    span=span,
                )
            spmd_kwargs[kw.arg] = val

        if "core_num" not in spmd_kwargs:
            raise ParserSyntaxError(
                "spmd_launch requires 'core_num' keyword argument",
                span=span,
                hint="Usage: pl.system.spmd_launch(self.kernel, arg1, ..., core_num=N)",
            )

⚠️ Potential issue | 🟠 Major

Validate core_num and sync_start types before storing kwargs.

Lines 2972-2978 only check that the kwarg is compile-time evaluable. That still accepts values like core_num=True or sync_start=1, but orchestration codegen later does typed GetKwarg<int>("core_num") / GetKwarg<bool>("sync_start"), so these invalid calls fail late with a much less useful internal error instead of a parser diagnostic.

Suggested fix
         for kw in call.keywords:
             if kw.arg not in valid_kwargs:
                 raise ParserSyntaxError(
                     f"Unknown spmd_launch keyword argument: '{kw.arg}'",
                     span=span,
                     hint=f"Valid keyword arguments: {sorted(valid_kwargs)}",
                 )
             success, val = self.expr_evaluator.try_eval_expr(kw.value)
             if not success:
                 raise ParserSyntaxError(
                     f"spmd_launch keyword '{kw.arg}' must be a compile-time constant",
                     span=span,
                 )
+            if kw.arg == "core_num":
+                if not isinstance(val, int) or isinstance(val, bool):
+                    raise ParserSyntaxError(
+                        "spmd_launch 'core_num' must be an integer constant",
+                        span=self.span_tracker.get_span(kw.value),
+                    )
+            elif kw.arg == "sync_start" and not isinstance(val, bool):
+                raise ParserSyntaxError(
+                    "spmd_launch 'sync_start' must be a boolean constant",
+                    span=self.span_tracker.get_span(kw.value),
+                )
             spmd_kwargs[kw.arg] = val
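
The extra `isinstance(val, bool)` test in the suggested fix is needed because `bool` is a subclass of `int` in Python, so `isinstance(True, int)` is true. A standalone sketch of the same checks, where `validate_spmd_kwarg` and `accepts` are hypothetical helpers and `TypeError` stands in for the parser's own error type:

```python
def validate_spmd_kwarg(name: str, val: object) -> object:
    """Toy validator mirroring the type checks in the suggested fix."""
    if name == "core_num":
        # bool is a subclass of int, so it must be rejected explicitly.
        if not isinstance(val, int) or isinstance(val, bool):
            raise TypeError("core_num must be an integer constant")
    elif name == "sync_start":
        if not isinstance(val, bool):
            raise TypeError("sync_start must be a boolean constant")
    return val


def accepts(name: str, val: object) -> bool:
    """Return True if the value passes validation for the given kwarg."""
    try:
        validate_spmd_kwarg(name, val)
    except TypeError:
        return False
    return True
```

Without the explicit bool rejection, `core_num=True` would pass the integer check and only fail later in codegen.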

@lyfne123 lyfne123 closed this Apr 16, 2026