@@ -713,6 +713,84 @@ static std::string MakeTileStoreCodegenPTO(const CallPtr& op, codegen::CodegenBa
713713 return " " ;
714714}
715715
716+ // tile.mscatter(src, idx, output_tensor) -> pto.mscatter
717+ // Generates:
718+ // %pview = pto.partition_view %tensor_view, offsets=[0,...], sizes=[d0,...] : ... -> ...
719+ // pto.mscatter ins(%src, %idx : !pto.tile_buf<...>, !pto.tile_buf<...>)
720+ // outs(%pview : !pto.partition_tensor_view<...>)
721+ static std::string MakeTileMscatterCodegenPTO (const CallPtr& op, codegen::CodegenBase& codegen_base) {
722+ auto & codegen = dynamic_cast <codegen::PTOCodegen&>(codegen_base);
723+ INTERNAL_CHECK (op->args_ .size () == 3 )
724+ << " tile.mscatter requires 3 arguments (src, idx, output_tensor), got " << op->args_ .size ();
725+
726+ auto src = AsVarLike (op->args_ [0 ]);
727+ INTERNAL_CHECK (src) << " tile.mscatter src must be a Var or IterArg" ;
728+ auto idx = AsVarLike (op->args_ [1 ]);
729+ INTERNAL_CHECK (idx) << " tile.mscatter idx must be a Var or IterArg" ;
730+ auto output_tensor = AsVarLike (op->args_ [2 ]);
731+ INTERNAL_CHECK (output_tensor) << " tile.mscatter output_tensor must be a Var or IterArg" ;
732+
733+ auto tensor_type = As<TensorType>(output_tensor->GetType ());
734+ INTERNAL_CHECK (tensor_type) << " tile.mscatter output_tensor must have TensorType" ;
735+
736+ std::string src_name = codegen.GetVarName (src);
737+ std::string idx_name = codegen.GetVarName (idx);
738+ std::string src_type_annot = codegen.GetExprTypeAnnotation (op->args_ [0 ]);
739+ std::string idx_type_annot = codegen.GetExprTypeAnnotation (op->args_ [1 ]);
740+
741+ std::string dtype_str = codegen.GetTypeString (tensor_type->dtype_ );
742+ std::string tensor_view = codegen.GetOrCreateTensorView (output_tensor);
743+ std::string tensor_view_type = codegen.GetTensorViewTypeString (tensor_type.get ());
744+
745+ // Build pto.partition_view covering the entire tensor (mscatter uses per-element
746+ // indices, so the partition is the whole tensor — offsets all zero, sizes = shape).
747+ std::string partition_view = codegen.NewNamedTemp (output_tensor->name_hint_ + " _pview" );
748+ std::ostringstream partition_line;
749+ partition_line << partition_view << " = pto.partition_view " << tensor_view;
750+ partition_line << " , offsets = [" ;
751+ for (size_t i = 0 ; i < tensor_type->shape_ .size (); ++i) {
752+ if (i > 0 ) partition_line << " , " ;
753+ partition_line << codegen.GetIndexConstant (0 );
754+ }
755+ partition_line << " ], sizes = [" ;
756+ std::string partition_type = " !pto.partition_tensor_view<" ;
757+ for (size_t i = 0 ; i < tensor_type->shape_ .size (); ++i) {
758+ if (i > 0 ) {
759+ partition_line << " , " ;
760+ partition_type += " x" ;
761+ }
762+ if (auto c = As<ir::ConstInt>(tensor_type->shape_ [i])) {
763+ partition_line << codegen.GetIndexConstant (c->value_ );
764+ partition_type += std::to_string (c->value_ );
765+ } else {
766+ partition_line << codegen.GetExprAsCode (tensor_type->shape_ [i]);
767+ partition_type += " ?" ;
768+ }
769+ }
770+ partition_line << " ]" ;
771+ partition_type += " x" + dtype_str + " >" ;
772+ partition_line << " : " << tensor_view_type << " -> " << partition_type;
773+ codegen.Emit (partition_line.str ());
774+
775+ // Emit pto.mscatter with partition_view in outs()
776+ std::ostringstream mscatter_line;
777+ mscatter_line << " pto.mscatter ins(" << src_name << " , " << idx_name;
778+ if (!src_type_annot.empty () && !idx_type_annot.empty ()) {
779+ mscatter_line << " : " << src_type_annot << " , " << idx_type_annot;
780+ }
781+ mscatter_line << " ) outs(" << partition_view << " : " << partition_type << " )" ;
782+ codegen.Emit (mscatter_line.str ());
783+
784+ // Propagate tensor_view to the result var so downstream ops see the updated tensor
785+ auto result_var = codegen.GetCurrentResultVar ();
786+ if (result_var != nullptr ) {
787+ codegen.RegisterTensorView (result_var, tensor_view);
788+ codegen.RegisterVarToMlir (result_var, tensor_view);
789+ }
790+
791+ return " " ;
792+ }
793+
716794// Helper function for tile.alloc (no-op: allocation handled elsewhere)
717795static std::string MakeTileAllocCodegenPTO (const CallPtr& op, codegen::CodegenBase& codegen_base) {
718796 (void )op;
@@ -1171,7 +1249,6 @@ struct SimpleOpEntry {
11711249static const SimpleOpEntry kSimpleOps [] = {
11721250 // Memory operations
11731251 {" tile.mgather" , " pto.tmgather" , 2 },
1174- {" tile.mscatter" , " pto.tmscatter" , 2 },
11751252 // Tile x Tile arithmetic operations
11761253 {" tile.add" , " pto.tadd" , 2 },
11771254 {" tile.sub" , " pto.tsub" , 2 },
@@ -1321,6 +1398,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc
13211398 reg (" tile.store" , [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
13221399 return MakeTileStoreCodegenPTO (op, codegen);
13231400 });
1401+ // tile.mscatter: src and idx must be row_major (MTE3 DMA reads UB linearly)
1402+ if (exclude_ops.count (" tile.mscatter" ) == 0 ) {
1403+ backend.RegisterOp (" tile.mscatter" )
1404+ .f_codegen ([](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
1405+ return MakeTileMscatterCodegenPTO (op, codegen);
1406+ })
1407+ .set_input_layout (0 , ir::TileLayout::row_major)
1408+ .set_input_layout (1 , ir::TileLayout::row_major);
1409+ }
13241410 reg (" tile.alloc" , [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
13251411 return MakeTileAllocCodegenPTO (op, codegen);
13261412 });
0 commit comments