1616 PipelineInterrupt ,
1717)
1818from lighthouse .schedule .xegpu .helper import bundle_xegpu_to_binary
19+ from lighthouse .dialects .transform import transform_ext
1920
2021
2122def get_softmax_schedule_module (
@@ -140,7 +141,6 @@ def bundle_xegpu_softmax_schedule(
140141 transform .AnyOpType .get (), func , ops = ["linalg.softmax" ]
141142 )
142143 structured .structured_decompose_interface (anytype , softmax_ops )
143- transform .print_ (target = func , name = "Aftemr structured_decompose_interface" )
144144
145145 linalg_ops = match_and_split (
146146 func , ops = {"linalg.generic" , "linalg.fill" }, nhandles = 6
@@ -157,18 +157,13 @@ def bundle_xegpu_softmax_schedule(
157157 div_op , sizes = [0 , reduction_step_size ]
158158 ).results
159159
160- transform .print_ (target = func , name = "After tiling div op" )
161-
162160 # Fuse max_center_and_exp_op into the div loop
163161 _ , fused_loop = structured .structured_fuse_into_containing_op (
164162 anytype ,
165163 anytype ,
166164 producer_op = max_center_and_exp_op ,
167165 containing_op = div_loop ,
168166 )
169- transform .print_ (
170- target = func , name = "After fusing max_center_and_exp_op into div loop"
171- )
172167
173168 # Tile the sum reduction and fuse the sub+exp producer into it
174169 _ , _ , _ , sum_loop = structured .structured_tile_reduction_using_for (
@@ -180,8 +175,6 @@ def bundle_xegpu_softmax_schedule(
180175 tile_sizes = [0 , reduction_step_size ],
181176 )
182177
183- transform .print_ (target = func , name = "After tiling sum reduction" )
184-
185178 func = transform .get_parent_op (
186179 anytype ,
187180 fused_loop ,
@@ -200,9 +193,6 @@ def bundle_xegpu_softmax_schedule(
200193 producer_op = max_center_and_exp_op ,
201194 containing_op = sum_loop ,
202195 )
203- transform .print_ (
204- target = func , name = "After fusing max_center_and_exp_op into sum loop"
205- )
206196
207197 # Tile the max reduction.
208198 max_reduction = linalg_ops [0 ]
@@ -214,7 +204,6 @@ def bundle_xegpu_softmax_schedule(
214204 target = max_reduction ,
215205 tile_sizes = [0 , reduction_step_size ],
216206 )
217- transform .print_ (target = func , name = "After tiling max reduction" )
218207
219208 # Cleanup after tiling and fusion
220209 transform .apply_cse (func )
@@ -231,8 +220,6 @@ def bundle_xegpu_softmax_schedule(
231220 transform .apply_cse (func )
232221 canonicalize (func )
233222
234- transform .print_ (target = func , name = "After vectorization" )
235-
236223 if stop_at_stage == "vectorized" :
237224 raise PipelineInterrupt ()
238225
@@ -250,8 +237,6 @@ def bundle_xegpu_softmax_schedule(
250237 transform .apply_cse (mod )
251238 canonicalize (mod )
252239
253- transform .print_ (target = mod , name = "After bufferization" )
254-
255240 # promote memref.alloc to memref.alloca in payload function
256241 func = match (mod , ops = {"func.func" })
257242 func = apply_registered_pass (
@@ -263,8 +248,6 @@ def bundle_xegpu_softmax_schedule(
263248 },
264249 )
265250
266- transform .print_ (target = func , name = "After promoting buffers to stack" )
267-
268251 if stop_at_stage == "bufferized" :
269252 raise PipelineInterrupt ()
270253
@@ -294,8 +277,6 @@ def bundle_xegpu_softmax_schedule(
294277 mod = apply_registered_pass (mod , "gpu-kernel-outlining" )
295278 transform .apply_cse (mod )
296279
297- transform .print_ (target = mod , name = "After GPU outlining" )
298-
299280 if stop_at_stage == "gpu-outlining" :
300281 raise PipelineInterrupt ()
301282
@@ -306,12 +287,16 @@ def bundle_xegpu_softmax_schedule(
306287 options = {"O" : "3" , "chip" : "bmg" },
307288 )
308289
309- # convert vector to xegpu
290+ # for each gpu function in the gpu module, change memref.alloca address
291+ # space to 3 (SLM). NOTE: the vector-to-xegpu conversion is currently disabled.
310292 gpu_mod_ops = match_and_split (mod , ops = {"gpu.module" })
311293 for gpu_mod in gpu_mod_ops :
312294 gpu_func = match (gpu_mod , ops = {"gpu.func" })
313- gpu_func = apply_registered_pass (gpu_func , "convert-vector-to-xegpu" )
314- transform .apply_cse (gpu_func )
295+ allocas = match_and_split (gpu_func , ops = {"memref.alloca" })
296+ for alloca in allocas :
297+ transform_ext .update_address_space (alloca , address_space = 3 )
298+ # gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu")
299+ # transform.apply_cse(gpu_func)
315300
316301 # Cleanup.
317302 transform .apply_cse (mod )
0 commit comments