Skip to content

Commit 9a8577a

Browse files
author
Youhezhen
committed
feat(ir): add tile.mscatter op for per-element scatter-store to GM (#921)
Add tile.mscatter operation mapping to the pto.mscatter instruction: mem[idx[i, j]] = src[i, j].

- C++ op registration with type deduction and validation
- PTO codegen emitting partition_view + pto.mscatter
- Python IR and DSL wrappers with pl.mscatter export
- Unit tests covering basic usage and error paths
- ST runtime tests (skipped: PTOAS lacks an NPU mscatter implementation)
1 parent 1589d7e commit 9a8577a

8 files changed

Lines changed: 987 additions & 1 deletion

File tree

python/pypto/ir/op/tile_ops.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,31 @@ def scatter_update(
325325
return _ir_core.create_op_call("tile.scatter_update", op_args, kwargs, actual_span)
326326

327327

328+
def mscatter(
    src: Expr,
    idx: Expr,
    output_tensor: Expr,
    span: Span | None = None,
) -> Call:
    """Scatter-store elements of ``src`` into ``output_tensor`` at per-element indices.

    Semantics: ``output_tensor[idx[i, j]] = src[i, j]``; lowered to the
    PTOAS ``pto.mscatter`` instruction.

    Args:
        src: Source tile (FP16, FP32, INT16, or INT32)
        idx: Index tile (INT32, same rank as src)
        output_tensor: Output tensor (TensorType, same dtype as src)
        span: Optional source span for debugging (auto-captured if not provided)

    Returns:
        Call expression that returns the output tensor
    """
    op_args = [src, idx, output_tensor]
    actual_span = _get_span_or_capture(span)
    return _ir_core.create_op_call("tile.mscatter", op_args, {}, actual_span)
351+
352+
328353
def concat(
329354
src0: Expr,
330355
src1: Expr,

python/pypto/language/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]:
133133
xor,
134134
xors,
135135
)
136+
from .op.tile_ops import (
137+
mscatter as mscatter,
138+
)
136139
from .op.unified_ops import (
137140
add,
138141
cast,

python/pypto/language/op/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
minimum,
5353
mins,
5454
move,
55+
mscatter,
5556
not_,
5657
or_,
5758
ors,
@@ -177,6 +178,7 @@
177178
"shrs",
178179
"maxs",
179180
"mins",
181+
"mscatter",
180182
"prelu",
181183
"not_",
182184
"addc",

python/pypto/language/op/tile_ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
"tpop_from_aiv",
113113
"sort32",
114114
"gather",
115+
"mscatter",
115116
"MaskPattern",
116117
"mrgsort",
117118
]
@@ -1687,6 +1688,28 @@ def gather(
16871688
return Tile(expr=call_expr)
16881689

16891690

1691+
def mscatter(src: Tile, idx: Tile, output_tensor: Tensor) -> Tensor:
    """Scatter-store tile elements into a tensor at per-element indices.

    Semantics: ``output_tensor[idx[i, j]] = src[i, j]``; maps to the
    PTOAS ``pto.mscatter`` instruction.

    Args:
        src: Source tile (FP16, FP32, INT16, or INT32)
        idx: Index tile (INT32, same rank as src)
        output_tensor: Output tensor to scatter into (same dtype as src)

    Returns:
        Tensor wrapping the mscatter operation

    Example:
        >>> result = pl.tile.mscatter(src_tile, idx_tile, out_tensor)
    """
    unwrapped_args = (src.unwrap(), idx.unwrap(), output_tensor.unwrap())
    return Tensor(expr=_ir_ops.mscatter(*unwrapped_args))
1711+
1712+
16901713
@overload
16911714
def mrgsort(src0: Tile, *, block_len: int | Scalar) -> Tile: ...
16921715

src/backend/common/pto_ops_common.cpp

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,84 @@ static std::string MakeTileStoreCodegenPTO(const CallPtr& op, codegen::CodegenBa
713713
return "";
714714
}
715715

716+
// tile.mscatter(src, idx, output_tensor) -> pto.mscatter
// Generates:
//   %pview = pto.partition_view %tensor_view, offsets=[0,...], sizes=[d0,...] : ... -> ...
//   pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
//                outs(%pview : !pto.partition_tensor_view<...>)
// The partition covers the whole tensor (offsets all zero, sizes = shape) since
// mscatter addresses elements via the per-element index tile, not the partition.
static std::string MakeTileMscatterCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) {
  auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base);
  // Type deduction already enforced arity; re-check here as an internal invariant.
  INTERNAL_CHECK(op->args_.size() == 3)
      << "tile.mscatter requires 3 arguments (src, idx, output_tensor), got " << op->args_.size();

  // All three operands must be name-able SSA values (Var or IterArg) so they can
  // be referenced by name in the emitted MLIR.
  auto src = AsVarLike(op->args_[0]);
  INTERNAL_CHECK(src) << "tile.mscatter src must be a Var or IterArg";
  auto idx = AsVarLike(op->args_[1]);
  INTERNAL_CHECK(idx) << "tile.mscatter idx must be a Var or IterArg";
  auto output_tensor = AsVarLike(op->args_[2]);
  INTERNAL_CHECK(output_tensor) << "tile.mscatter output_tensor must be a Var or IterArg";

  auto tensor_type = As<TensorType>(output_tensor->GetType());
  INTERNAL_CHECK(tensor_type) << "tile.mscatter output_tensor must have TensorType";

  std::string src_name = codegen.GetVarName(src);
  std::string idx_name = codegen.GetVarName(idx);
  std::string src_type_annot = codegen.GetExprTypeAnnotation(op->args_[0]);
  std::string idx_type_annot = codegen.GetExprTypeAnnotation(op->args_[1]);

  std::string dtype_str = codegen.GetTypeString(tensor_type->dtype_);
  std::string tensor_view = codegen.GetOrCreateTensorView(output_tensor);
  std::string tensor_view_type = codegen.GetTensorViewTypeString(tensor_type.get());

  // Build pto.partition_view covering the entire tensor (mscatter uses per-element
  // indices, so the partition is the whole tensor — offsets all zero, sizes = shape).
  std::string partition_view = codegen.NewNamedTemp(output_tensor->name_hint_ + "_pview");
  std::ostringstream partition_line;
  partition_line << partition_view << " = pto.partition_view " << tensor_view;
  partition_line << ", offsets = [";
  for (size_t i = 0; i < tensor_type->shape_.size(); ++i) {
    if (i > 0) partition_line << ", ";
    partition_line << codegen.GetIndexConstant(0);
  }
  partition_line << "], sizes = [";
  // The sizes list and the "<d0xd1x...xdtype>" result type are built in lockstep:
  // static dims print their constant value, dynamic dims print an expression in
  // the sizes list and "?" in the type (the MLIR dynamic-dimension marker).
  std::string partition_type = "!pto.partition_tensor_view<";
  for (size_t i = 0; i < tensor_type->shape_.size(); ++i) {
    if (i > 0) {
      partition_line << ", ";
      partition_type += "x";
    }
    if (auto c = As<ir::ConstInt>(tensor_type->shape_[i])) {
      partition_line << codegen.GetIndexConstant(c->value_);
      partition_type += std::to_string(c->value_);
    } else {
      partition_line << codegen.GetExprAsCode(tensor_type->shape_[i]);
      partition_type += "?";
    }
  }
  partition_line << "]";
  partition_type += "x" + dtype_str + ">";
  partition_line << " : " << tensor_view_type << " -> " << partition_type;
  codegen.Emit(partition_line.str());

  // Emit pto.mscatter with partition_view in outs().
  // NOTE(review): the type annotation is emitted only when BOTH annotations are
  // non-empty — presumably PTOAS rejects a partial annotation list; confirm.
  std::ostringstream mscatter_line;
  mscatter_line << "pto.mscatter ins(" << src_name << ", " << idx_name;
  if (!src_type_annot.empty() && !idx_type_annot.empty()) {
    mscatter_line << " : " << src_type_annot << ", " << idx_type_annot;
  }
  mscatter_line << ") outs(" << partition_view << " : " << partition_type << ")";
  codegen.Emit(mscatter_line.str());

  // Propagate tensor_view to the result var so downstream ops see the updated tensor
  // (the op's result aliases the scattered-into tensor; see set_output_reuses_input(2)
  // in the op registration).
  auto result_var = codegen.GetCurrentResultVar();
  if (result_var != nullptr) {
    codegen.RegisterTensorView(result_var, tensor_view);
    codegen.RegisterVarToMlir(result_var, tensor_view);
  }

  // Emission goes through codegen.Emit(); nothing to return inline.
  return "";
}
793+
716794
// Helper function for tile.alloc (no-op: allocation handled elsewhere)
717795
static std::string MakeTileAllocCodegenPTO(const CallPtr& op, codegen::CodegenBase& codegen_base) {
718796
(void)op;
@@ -1171,7 +1249,6 @@ struct SimpleOpEntry {
11711249
static const SimpleOpEntry kSimpleOps[] = {
11721250
// Memory operations
11731251
{"tile.mgather", "pto.tmgather", 2},
1174-
{"tile.mscatter", "pto.tmscatter", 2},
11751252
// Tile x Tile arithmetic operations
11761253
{"tile.add", "pto.tadd", 2},
11771254
{"tile.sub", "pto.tsub", 2},
@@ -1321,6 +1398,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc
13211398
reg("tile.store", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
13221399
return MakeTileStoreCodegenPTO(op, codegen);
13231400
});
1401+
// tile.mscatter: src and idx must be row_major (MTE3 DMA reads UB linearly)
1402+
if (exclude_ops.count("tile.mscatter") == 0) {
1403+
backend.RegisterOp("tile.mscatter")
1404+
.f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
1405+
return MakeTileMscatterCodegenPTO(op, codegen);
1406+
})
1407+
.set_input_layout(0, ir::TileLayout::row_major)
1408+
.set_input_layout(1, ir::TileLayout::row_major);
1409+
}
13241410
reg("tile.alloc", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
13251411
return MakeTileAllocCodegenPTO(op, codegen);
13261412
});

src/ir/op/tile_ops/memory.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,65 @@ REGISTER_OP("tile.store")
535535
return DeduceTileStoreType(args, kwargs, "tile.store");
536536
});
537537

538+
// ============================================================================
539+
// tile.mscatter: scatter-store tile elements to tensor via per-element indices
540+
// Maps to pto.mscatter: mem[idx[i, j]] = src[i, j]
541+
// ============================================================================
542+
543+
TypePtr DeduceTileMscatterType(const std::vector<ExprPtr>& args,
544+
const std::vector<std::pair<std::string, std::any>>& kwargs,
545+
const std::string& op_name) {
546+
CHECK(args.size() == 3) << "The operator " << op_name
547+
<< " requires 3 arguments (src, idx, output_tensor), but got " << args.size();
548+
549+
// First arg: src tile (FP16/FP32/INT16/INT32)
550+
auto src_type = As<TileType>(args[0]->GetType());
551+
CHECK(src_type) << "The operator " << op_name << " requires first argument to be a TileType, but got "
552+
<< args[0]->GetType()->TypeName();
553+
CHECK(src_type->dtype_ == DataType::FP16 || src_type->dtype_ == DataType::FP32 ||
554+
src_type->dtype_ == DataType::INT16 || src_type->dtype_ == DataType::INT32)
555+
<< "The operator " << op_name << " requires src dtype to be FP16, FP32, INT16, or INT32, but got "
556+
<< src_type->dtype_.ToString();
557+
558+
// Second arg: idx tile (INT32, same rank as src)
559+
auto idx_type = As<TileType>(args[1]->GetType());
560+
CHECK(idx_type) << "The operator " << op_name << " requires second argument to be a TileType, but got "
561+
<< args[1]->GetType()->TypeName();
562+
CHECK(idx_type->dtype_ == DataType::INT32)
563+
<< "The operator " << op_name << " requires idx dtype to be INT32, but got "
564+
<< idx_type->dtype_.ToString();
565+
CHECK(idx_type->shape_.size() == src_type->shape_.size())
566+
<< "The operator " << op_name << " requires idx rank to match src rank (" << src_type->shape_.size()
567+
<< "), but got " << idx_type->shape_.size();
568+
569+
// Third arg: output tensor (same dtype as src)
570+
auto tensor_type = As<TensorType>(args[2]->GetType());
571+
CHECK(tensor_type) << "The operator " << op_name << " requires third argument to be a TensorType, but got "
572+
<< args[2]->GetType()->TypeName();
573+
CHECK(tensor_type->dtype_ == src_type->dtype_)
574+
<< "The operator " << op_name << " requires output_tensor dtype (" << tensor_type->dtype_.ToString()
575+
<< ") to match src dtype (" << src_type->dtype_.ToString() << ")";
576+
577+
// mscatter returns the output tensor (same type)
578+
return tensor_type;
579+
}
580+
581+
// Registration for tile.mscatter; type checking is delegated to
// DeduceTileMscatterType, codegen is registered separately in the PTO backend.
REGISTER_OP("tile.mscatter")
    .set_op_category("TileOp")
    .set_description(
        "Scatter-store elements from src tile to tensor at per-element indices "
        "(maps to pto.mscatter)")
    .add_argument("src", "Source tile (FP16, FP32, INT16, or INT32)")
    .add_argument("idx", "Index tile (INT32, same rank as src)")
    .add_argument("output_tensor", "Output tensor (TensorType, same dtype as src)")
    // Both tile operands are placed in the Vec memory space.
    .set_input_memory(0, MemorySpace::Vec)
    .set_input_memory(1, MemorySpace::Vec)
    // The result aliases argument 2 — mscatter writes into the given tensor
    // rather than allocating a fresh output.
    .set_output_reuses_input(2)
    .f_deduce_type([](const std::vector<ExprPtr>& args,
                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
      return DeduceTileMscatterType(args, kwargs, "tile.mscatter");
    });
});
596+
538597
REGISTER_OP("tile.move")
539598
.set_op_category("TileOp")
540599
.set_description("Move tile between memory levels (Vec/Mat/Left/Right)")

0 commit comments

Comments
 (0)