Skip to content

Commit eacc9d8

Browse files
committed
save work
1 parent 32a345a commit eacc9d8

2 files changed

Lines changed: 42 additions & 62 deletions

File tree

lighthouse/dialects/transform/transform_ext/ops/update_address_space.py

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,51 +32,46 @@ def apply(
3232
state: transform.TransformState,
3333
) -> DiagnosedSilenceableFailure:
3434
# Get the target operations to transform
35-
target_ops = state.get_payload_ops(op.target)
35+
target_op = state.get_payload_ops(op.target)[0]
3636
# Get the address space value from the attribute
3737
address_space_value = ir.IntegerAttr(op.address_space).value
3838
new_ops = []
3939

40-
for target_op in target_ops:
41-
# Verify this is a memref.alloca operation
42-
if target_op.OPERATION_NAME != "memref.alloca":
43-
return DiagnosedSilenceableFailure.emit_silenceable_error(
44-
f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}"
45-
)
46-
47-
# Get the current result type (should be a MemRefType)
48-
old_result_type = target_op.results[0].type
49-
memref_type = ir.MemRefType(old_result_type)
50-
# Create a new memref type with the specified address space
51-
new_memref_type = ir.MemRefType.get(
52-
memref_type.shape,
53-
memref_type.element_type,
54-
layout=memref_type.layout,
55-
memory_space=ir.Attribute.parse(f"{address_space_value}"),
40+
# Verify this is a memref.alloca operation
41+
if target_op.OPERATION_NAME != "memref.alloca":
42+
return DiagnosedSilenceableFailure.emit_silenceable_error(
43+
f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}"
5644
)
5745

58-
# Replace the operation with a new one that has the updated type
59-
with ir.InsertionPoint(target_op):
60-
# Get the operands from the original alloca (dynamic sizes and symbols)
61-
dynamic_sizes = list(
62-
target_op.operands[
63-
: target_op.attributes["operandSegmentSizes"][0]
64-
]
65-
)
66-
symbol_operands = list(
67-
target_op.operands[
68-
target_op.attributes["operandSegmentSizes"][0] :
69-
]
70-
)
71-
# Create a new alloca with the updated type
72-
new_alloca = memref.alloca(
73-
new_memref_type, dynamic_sizes, symbol_operands
74-
)
75-
# Replace all uses of the old operation with the new one
76-
# FIXME: This won't handle operations that consume the memref type and
77-
# return a new memref (such as subview).
78-
rewriter.replace_op(target_op, [new_alloca])
79-
new_ops.append(new_alloca.owner)
46+
# Get the current result type (should be a MemRefType)
47+
old_result_type = target_op.results[0].type
48+
memref_type = ir.MemRefType(old_result_type)
49+
# Create a new memref type with the specified address space
50+
new_memref_type = ir.MemRefType.get(
51+
memref_type.shape,
52+
memref_type.element_type,
53+
layout=memref_type.layout,
54+
memory_space=ir.Attribute.parse(f"{address_space_value}"),
55+
)
56+
57+
# Replace the operation with a new one that has the updated type
58+
with ir.InsertionPoint(target_op):
59+
# Get the operands from the original alloca (dynamic sizes and symbols)
60+
dynamic_sizes = list(
61+
target_op.operands[: target_op.attributes["operandSegmentSizes"][0]]
62+
)
63+
symbol_operands = list(
64+
target_op.operands[target_op.attributes["operandSegmentSizes"][0] :]
65+
)
66+
# Create a new alloca with the updated type
67+
new_alloca = memref.alloca(
68+
new_memref_type, dynamic_sizes, symbol_operands
69+
)
70+
# Replace all uses of the old operation with the new one
71+
# FIXME: This won't handle operations that consume the memref type and
72+
# return a new memref (such as subview).
73+
rewriter.replace_op(target_op, [new_alloca])
74+
new_ops.append(new_alloca.owner)
8075

8176
# Set the results to the new operations
8277
results.set_ops(op.updated_op, new_ops)

lighthouse/schedule/xegpu/softmax_schedule.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
PipelineInterrupt,
1717
)
1818
from lighthouse.schedule.xegpu.helper import bundle_xegpu_to_binary
19+
from lighthouse.dialects.transform import transform_ext
1920

2021

2122
def get_softmax_schedule_module(
@@ -140,7 +141,6 @@ def bundle_xegpu_softmax_schedule(
140141
transform.AnyOpType.get(), func, ops=["linalg.softmax"]
141142
)
142143
structured.structured_decompose_interface(anytype, softmax_ops)
143-
transform.print_(target=func, name="Aftemr structured_decompose_interface")
144144

145145
linalg_ops = match_and_split(
146146
func, ops={"linalg.generic", "linalg.fill"}, nhandles=6
@@ -157,18 +157,13 @@ def bundle_xegpu_softmax_schedule(
157157
div_op, sizes=[0, reduction_step_size]
158158
).results
159159

160-
transform.print_(target=func, name="After tiling div op")
161-
162160
# Fuse max_center_and_exp_op into the div loop
163161
_, fused_loop = structured.structured_fuse_into_containing_op(
164162
anytype,
165163
anytype,
166164
producer_op=max_center_and_exp_op,
167165
containing_op=div_loop,
168166
)
169-
transform.print_(
170-
target=func, name="After fusing max_center_and_exp_op into div loop"
171-
)
172167

173168
# Tile the sum reduction and fuse the sub+exp producer into it
174169
_, _, _, sum_loop = structured.structured_tile_reduction_using_for(
@@ -180,8 +175,6 @@ def bundle_xegpu_softmax_schedule(
180175
tile_sizes=[0, reduction_step_size],
181176
)
182177

183-
transform.print_(target=func, name="After tiling sum reduction")
184-
185178
func = transform.get_parent_op(
186179
anytype,
187180
fused_loop,
@@ -200,9 +193,6 @@ def bundle_xegpu_softmax_schedule(
200193
producer_op=max_center_and_exp_op,
201194
containing_op=sum_loop,
202195
)
203-
transform.print_(
204-
target=func, name="After fusing max_center_and_exp_op into sum loop"
205-
)
206196

207197
# Tile the max reduction.
208198
max_reduction = linalg_ops[0]
@@ -214,7 +204,6 @@ def bundle_xegpu_softmax_schedule(
214204
target=max_reduction,
215205
tile_sizes=[0, reduction_step_size],
216206
)
217-
transform.print_(target=func, name="After tiling max reduction")
218207

219208
# Cleanup after tiling and fusion
220209
transform.apply_cse(func)
@@ -231,8 +220,6 @@ def bundle_xegpu_softmax_schedule(
231220
transform.apply_cse(func)
232221
canonicalize(func)
233222

234-
transform.print_(target=func, name="After vectorization")
235-
236223
if stop_at_stage == "vectorized":
237224
raise PipelineInterrupt()
238225

@@ -250,8 +237,6 @@ def bundle_xegpu_softmax_schedule(
250237
transform.apply_cse(mod)
251238
canonicalize(mod)
252239

253-
transform.print_(target=mod, name="After bufferization")
254-
255240
# promote memref.alloc to memref.alloca in payload function
256241
func = match(mod, ops={"func.func"})
257242
func = apply_registered_pass(
@@ -263,8 +248,6 @@ def bundle_xegpu_softmax_schedule(
263248
},
264249
)
265250

266-
transform.print_(target=func, name="After promoting buffers to stack")
267-
268251
if stop_at_stage == "bufferized":
269252
raise PipelineInterrupt()
270253

@@ -294,8 +277,6 @@ def bundle_xegpu_softmax_schedule(
294277
mod = apply_registered_pass(mod, "gpu-kernel-outlining")
295278
transform.apply_cse(mod)
296279

297-
transform.print_(target=mod, name="After GPU outlining")
298-
299280
if stop_at_stage == "gpu-outlining":
300281
raise PipelineInterrupt()
301282

@@ -306,12 +287,16 @@ def bundle_xegpu_softmax_schedule(
306287
options={"O": "3", "chip": "bmg"},
307288
)
308289

309-
# convert vector to xegpu
290+
# for each gpu function in the gpu module, change memref.alloca address
291+
# space to 3 (SLM) and convert vector to xegpu.
310292
gpu_mod_ops = match_and_split(mod, ops={"gpu.module"})
311293
for gpu_mod in gpu_mod_ops:
312294
gpu_func = match(gpu_mod, ops={"gpu.func"})
313-
gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu")
314-
transform.apply_cse(gpu_func)
295+
allocas = match_and_split(gpu_func, ops={"memref.alloca"})
296+
for alloca in allocas:
297+
transform_ext.update_address_space(alloca, address_space=3)
298+
# gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu")
299+
# transform.apply_cse(gpu_func)
315300

316301
# Cleanup.
317302
transform.apply_cse(mod)

0 commit comments

Comments
 (0)