Skip to content

Commit d2d4c49

Browse files
committed
save work
1 parent b8f5616 commit d2d4c49

2 files changed

Lines changed: 104 additions & 0 deletions

File tree

lighthouse/dialects/transform/transform_ext/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .ops.extract_handle import extract_handle
1010
from .ops.get_tileable_consumers import get_tileable_consumers
1111
from .ops.get_tiling_sizes import get_tiling_sizes
12+
from .ops.update_address_space import update_address_space
1213

1314
__all__ = [
1415
"TransformExtensionDialect",
@@ -22,4 +23,5 @@
2223
"register_and_load",
2324
"replace",
2425
"wrap_in_benching_func",
26+
"update_address_space",
2527
]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from mlir import ir
2+
from mlir.dialects import ext, transform, memref
3+
from mlir.dialects.transform import DiagnosedSilenceableFailure
4+
5+
from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect
6+
7+
8+
class UpdateAddressSpace(TransformExtensionDialect.Operation, name="update_address_space"):
9+
"""Update the address space of a memref allocation operation.
10+
11+
Takes a target memref allocation operation and updates its address space
12+
to the provided value.
13+
"""
14+
15+
target: ext.Operand[transform.AnyOpType]
16+
address_space: ir.IntegerAttr
17+
updated_op: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
18+
19+
@classmethod
20+
def attach_interface_impls(cls, ctx=None):
21+
cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
22+
cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
23+
24+
class TransformOpInterfaceModel(transform.TransformOpInterface):
25+
@staticmethod
26+
def apply(
27+
op: "UpdateAddressSpace",
28+
rewriter: transform.TransformRewriter,
29+
results: transform.TransformResults,
30+
state: transform.TransformState,
31+
) -> DiagnosedSilenceableFailure:
32+
# Get the target operations to transform
33+
target_ops = state.get_payload_ops(op.target)
34+
35+
# Get the address space value from the attribute
36+
address_space_value = ir.IntegerAttr(op.address_space).value
37+
38+
new_ops = []
39+
40+
for target_op in target_ops:
41+
# Verify this is a memref.alloca operation
42+
if target_op.OPERATION_NAME != "memref.alloca":
43+
return DiagnosedSilenceableFailure.emit_silenceable_error(
44+
f"Expected memref.alloca operation, got {target_op.OPERATION_NAME}"
45+
)
46+
47+
# Get the current result type (should be a MemRefType)
48+
old_result_type = target_op.results[0].type
49+
50+
memref_type = ir.MemRefType(old_result_type)
51+
52+
# Create a new memref type with the specified address space
53+
new_memref_type = ir.MemRefType.get(
54+
memref_type.shape,
55+
memref_type.element_type,
56+
layout=memref_type.layout,
57+
memory_space=ir.Attribute.parse(f"{address_space_value}")
58+
)
59+
print(new_memref_type)
60+
61+
# Replace the operation with a new one that has the updated type
62+
with ir.InsertionPoint(target_op):
63+
64+
# Get the operands from the original alloca (dynamic sizes and symbols)
65+
dynamic_sizes = list(target_op.operands[:target_op.attributes["operandSegmentSizes"][0]])
66+
symbol_operands = list(target_op.operands[target_op.attributes["operandSegmentSizes"][0]:])
67+
68+
# Create a new alloca with the updated type
69+
new_alloca = memref.alloca(new_memref_type, dynamic_sizes, symbol_operands)
70+
print(new_alloca)
71+
72+
# Replace all uses of the old operation with the new one
73+
# rewriter.replace_all_uses_with(target_op.results[0], new_alloca.results[0])
74+
75+
# Erase the old operation
76+
rewriter.replace_op(target_op, [new_alloca])
77+
78+
new_ops.append(new_alloca.owner)
79+
80+
# Set the results to the new operations
81+
results.set_ops(op.updated_op, new_ops)
82+
return DiagnosedSilenceableFailure.Success
83+
84+
@staticmethod
85+
def allow_repeated_handle_operands(_op: "UpdateAddressSpace") -> bool:
86+
return False
87+
88+
class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
89+
@staticmethod
90+
def get_effects(op: ir.Operation, effects):
91+
transform.consumes_handle(op.op_operands[:1], effects)
92+
transform.produces_handle(op.results, effects)
93+
transform.modifies_payload(effects)
94+
95+
96+
def update_address_space(
97+
target: ir.Value,
98+
address_space: int | ir.IntegerAttr,
99+
) -> ir.Value:
100+
if not isinstance(address_space, ir.IntegerAttr):
101+
address_space = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), address_space)
102+
return UpdateAddressSpace(target, address_space=address_space).updated_op

0 commit comments

Comments
 (0)