Skip to content

Commit f3ebeb3

Browse files
committed
Refactor new ops after rebase
1 parent c87dfdd commit f3ebeb3

15 files changed

Lines changed: 525 additions & 868 deletions

lighthouse/dialects/transform/smt_ext/ops/constrain_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mlir.dialects import ext, smt, transform
55

66
from lighthouse.tune import trace
7-
from ..dialect import TransformSMTExtensionDialect
7+
from lighthouse.dialects.transform.smt_ext import TransformSMTExtensionDialect
88

99

1010
class ConstrainParamsOp(

lighthouse/dialects/transform/transform_ext/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,19 @@
55
from .ops.get_named_attribute import get_named_attribute
66
from .ops.param_cmp_eq import param_cmp_eq
77
from .ops.replace import replace
8+
from .ops.convert_func_results_to_args import convert_func_results_to_args
9+
from .ops.extract_handle import extract_handle
10+
from .ops.get_tileable_consumers import get_tileable_consumers
11+
from .ops.get_tiling_sizes import get_tiling_sizes
812

913
__all__ = [
1014
"TransformExtensionDialect",
15+
"convert_func_results_to_args",
16+
"extract_handle",
1117
"get_named_attribute",
18+
"get_named_attribute",
19+
"get_tileable_consumers",
20+
"get_tiling_sizes",
1221
"param_cmp_eq",
1322
"register_and_load",
1423
"replace",
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from mlir import ir
2+
from mlir.dialects import ext, transform, func, bufferization
3+
from mlir.dialects.transform import DiagnosedSilenceableFailure
4+
5+
from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect
6+
from lighthouse.ingress.mlir_gen.utils import emit_buf_to_tensor
7+
from lighthouse.utils.mlir import func_cif
8+
9+
10+
class ConvertFuncResultsToArgsOp(
11+
TransformExtensionDialect.Operation, name="convert_func_results_to_args"
12+
):
13+
"""Converts all function return values to function arguments.
14+
15+
Function return values are placed in the beginning of the argument list,
16+
followed by the original function arguments.
17+
18+
Function arguments are converted to memrefs with appropriate bufferization
19+
annotations for inputs (bufferization.to_tensor with restrict=True) and
20+
outputs (bufferization.materialize_in_destination).
21+
22+
Currently supports only functions with tensor arguments and return values.
23+
"""
24+
25+
target: ext.Operand[transform.AnyOpType]
26+
converted_func: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
27+
28+
@classmethod
29+
def attach_interface_impls(cls, context=None):
30+
cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)
31+
cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=context)
32+
33+
@staticmethod
34+
def convert_func(target: func.FuncOp) -> func.FuncOp:
35+
def memref_t(ttype: ir.Type) -> ir.MemRefType:
36+
return ir.MemRefType.get(ttype.shape, ttype.element_type)
37+
38+
func_name = target.sym_name.value
39+
func_inputs = list(target.type.inputs)
40+
func_results = list(target.type.results)
41+
assert all(isinstance(ty, ir.RankedTensorType) for ty in func_inputs), (
42+
"Only tensors are supported as input types"
43+
)
44+
assert all(isinstance(ty, ir.RankedTensorType) for ty in func_results), (
45+
"Only tensors are supported as return types"
46+
)
47+
48+
nresults = len(func_results)
49+
new_args = [memref_t(ty) for ty in func_results + func_inputs]
50+
51+
@func_cif(*new_args, name=func_name)
52+
def f(*args):
53+
outputs = args[:nresults]
54+
inputs = args[nresults:]
55+
# convert input memrefs to tensors
56+
input_tensors = []
57+
for input in inputs:
58+
t = emit_buf_to_tensor(input, restrict=True)
59+
input_tensors.append(t)
60+
61+
# clone function body and map args and return values
62+
cloned_map = {}
63+
for op in target.regions[0].blocks[0].operations:
64+
if isinstance(op, func.ReturnOp):
65+
# emit materialize_in_destination for each return value
66+
for i, res_val in enumerate(op.operands):
67+
if res_val.owner not in cloned_map:
68+
raise NotImplementedError("Unsupported return value")
69+
iresult = res_val.result_number
70+
new_val = cloned_map[res_val.owner].results[iresult]
71+
bufferization.materialize_in_destination(
72+
None,
73+
new_val,
74+
outputs[i],
75+
restrict=True,
76+
writable=True,
77+
)
78+
else:
79+
new_op = op.clone()
80+
for i, oo in enumerate(op.operands):
81+
if isinstance(oo, ir.BlockArgument):
82+
# operand is func argument
83+
# replace with new input tensors
84+
new_op.operands[i] = input_tensors[oo.arg_number]
85+
else:
86+
# replace operands with cloned values
87+
if oo.owner in cloned_map:
88+
iresult = oo.result_number
89+
new_op.operands[i] = cloned_map[oo.owner].results[
90+
iresult
91+
]
92+
cloned_map[op] = new_op
93+
94+
return f.func_op
95+
96+
class TransformOpInterfaceModel(transform.TransformOpInterface):
97+
@staticmethod
98+
def apply(
99+
op: "ConvertFuncResultsToArgsOp",
100+
_rewriter: transform.TransformRewriter,
101+
results: transform.TransformResults,
102+
state: transform.TransformState,
103+
) -> DiagnosedSilenceableFailure:
104+
targets = state.get_payload_ops(op.target)
105+
converted_funcs = []
106+
107+
for target in targets:
108+
if not isinstance(target, func.FuncOp):
109+
return DiagnosedSilenceableFailure.SilenceableFailure
110+
111+
with ir.InsertionPoint(target), target.location:
112+
new_func = ConvertFuncResultsToArgsOp.convert_func(target)
113+
target.erase()
114+
converted_funcs.append(new_func)
115+
116+
results.set_ops(op.converted_func, converted_funcs)
117+
118+
return DiagnosedSilenceableFailure.Success
119+
120+
@staticmethod
121+
def allow_repeated_handle_operands(_op: "ConvertFuncResultsToArgsOp") -> bool:
122+
return False
123+
124+
class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
125+
@staticmethod
126+
def get_effects(op: "ConvertFuncResultsToArgsOp", effects):
127+
transform.consumes_handle(op.op_operands, effects)
128+
transform.produces_handle(op.results, effects)
129+
transform.modifies_payload(effects)
130+
131+
132+
def convert_func_results_to_args(
133+
target: ir.Value[transform.AnyOpType], bench_name: str | None = None
134+
) -> ir.Value[transform.AnyOpType]:
135+
"""snake_case wrapper to create a ConvertFuncResultsToArgsOp."""
136+
op = ConvertFuncResultsToArgsOp(target=target)
137+
return op.converted_func
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from mlir import ir
2+
from mlir.dialects import ext, transform
3+
from mlir.dialects.transform import DiagnosedSilenceableFailure
4+
5+
from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect
6+
7+
8+
class ExtractHandleOp(TransformExtensionDialect.Operation, name="extract_handle"):
9+
"""
10+
Returns the handle at the specified index in `target`.
11+
12+
Args:
13+
target: Handle(s) to target op(s)
14+
index: Index of the handle to extract. Supports Python-style indexing.
15+
Returns:
16+
The handle at the specified index in `target`.
17+
"""
18+
19+
target: ext.Operand[transform.AnyOpType]
20+
index: ext.Operand[transform.AnyParamType]
21+
ops: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
22+
23+
@classmethod
24+
def attach_interface_impls(cls, ctx=None):
25+
cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
26+
cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
27+
28+
class TransformOpInterfaceModel(transform.TransformOpInterface):
29+
@staticmethod
30+
def apply(
31+
op: "ExtractHandleOp",
32+
_rewriter: transform.TransformRewriter,
33+
results: transform.TransformResults,
34+
state: transform.TransformState,
35+
) -> DiagnosedSilenceableFailure:
36+
target_ops = state.get_payload_ops(op.target)
37+
index_attr = state.get_params(op.index)
38+
if len(index_attr) == 1 and isinstance(index_attr[0], ir.IntegerAttr):
39+
index = index_attr[0].value
40+
else:
41+
return DiagnosedSilenceableFailure.SilenceableFailure
42+
43+
n = len(target_ops)
44+
if index >= n or index < -n:
45+
raise IndexError(
46+
f"extract_handle: Invalid index {index} for target of length {len(target_ops)}"
47+
)
48+
handle = target_ops[index]
49+
results.set_ops(op.ops, [handle])
50+
return DiagnosedSilenceableFailure.Success
51+
52+
@staticmethod
53+
def allow_repeated_handle_operands(_op: "ExtractHandleOp") -> bool:
54+
return False
55+
56+
class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
57+
@staticmethod
58+
def get_effects(op: ir.Operation, effects):
59+
transform.only_reads_handle(op.op_operands, effects)
60+
transform.produces_handle(op.results, effects)
61+
transform.only_reads_payload(effects)
62+
63+
64+
def extract_handle(
65+
target: ir.Value[transform.AnyOpType],
66+
index: int | ir.Value[transform.AnyParamType],
67+
) -> ir.Value:
68+
"""
69+
snake_case wrapper to create a ExtractHandleOp.
70+
71+
Args:
72+
target: Handle(s) to target op(s)
73+
index: Index of the handle to extract. Supports Python-style indexing.
74+
Returns:
75+
The handle at the specified index in `target`.
76+
"""
77+
if isinstance(index, int):
78+
param_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), index)
79+
index = transform.ParamConstantOp(transform.AnyParamType.get(), param_attr)
80+
81+
return ExtractHandleOp(target=target, index=index).result

lighthouse/dialects/transform/transform_ext/ops/get_named_attribute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from mlir.dialects import ext, transform
33
from mlir.dialects.transform import DiagnosedSilenceableFailure
44

5-
from ..dialect import TransformExtensionDialect
5+
from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect
66

77

88
class GetNamedAttributeOp(
@@ -16,9 +16,9 @@ class GetNamedAttributeOp(
1616
with the name `attr_name`, the operation fails.
1717
"""
1818

19-
param: ext.Result[transform.AnyParamType[()]]
2019
target: ext.Operand[transform.AnyOpType]
2120
attr_name: ir.StringAttr
21+
param: ext.Result[transform.AnyParamType[()]] = ext.result(infer_type=True)
2222

2323
@classmethod
2424
def attach_interface_impls(cls, context=None):
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from mlir import ir
2+
from mlir.dialects import ext, transform, linalg
3+
from mlir.dialects.transform import DiagnosedSilenceableFailure
4+
5+
from lighthouse.dialects.transform.transform_ext import TransformExtensionDialect
6+
7+
8+
class GetTileableConsumersOp(
9+
TransformExtensionDialect.Operation, name="get_tileable_consumers"
10+
):
11+
"""
12+
Find consumer ops of the `target` operation that are tileable linalg ops.
13+
14+
If no such consumers are found, the operation returns the target itself.
15+
16+
Args:
17+
target: Handle to target op
18+
Returns:
19+
List of tileable consumer ops, or the target op itself.
20+
"""
21+
22+
target: ext.Operand[transform.AnyOpType]
23+
ops: ext.Result[transform.AnyOpType[()]] = ext.result(infer_type=True)
24+
25+
@classmethod
26+
def attach_interface_impls(cls, ctx=None):
27+
cls.TransformOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
28+
cls.MemoryEffectsOpInterfaceModel.attach(cls.OPERATION_NAME, context=ctx)
29+
30+
@staticmethod
31+
def get_op_users(val: ir.Value) -> list[ir.Operation]:
32+
op_users = []
33+
for use in val.uses:
34+
user = use.owner
35+
if not isinstance(user, ir.OpView):
36+
continue
37+
op_users.append(user.operation)
38+
return op_users
39+
40+
@staticmethod
41+
def is_tileable_op(op: ir.Operation) -> bool:
42+
# TODO expand list as needed and/or check traits/interfaces
43+
linalg_ops = [
44+
linalg.ElementwiseOp,
45+
linalg.AddOp,
46+
linalg.SubOp,
47+
linalg.MulOp,
48+
linalg.DivOp,
49+
linalg.ExpOp,
50+
linalg.MaxOp,
51+
linalg.MinOp,
52+
linalg.FillOp,
53+
linalg.MatmulOp,
54+
linalg.GenericOp,
55+
]
56+
return isinstance(op.opview, tuple(linalg_ops))
57+
58+
class TransformOpInterfaceModel(transform.TransformOpInterface):
59+
@staticmethod
60+
def apply(
61+
op: "GetTileableConsumersOp",
62+
_rewriter: transform.TransformRewriter,
63+
results: transform.TransformResults,
64+
state: transform.TransformState,
65+
) -> DiagnosedSilenceableFailure:
66+
target_ops = state.get_payload_ops(op.target)
67+
68+
if len(target_ops) != 1:
69+
return DiagnosedSilenceableFailure.SilenceableFailure
70+
71+
new_ops = []
72+
target: ir.Operation = target_ops[0]
73+
op_res = target.results
74+
while len(op_res) == 1:
75+
users = op.get_op_users(op_res[0])
76+
if len(users) != 1:
77+
break
78+
user = users[0]
79+
if not op.is_tileable_op(user):
80+
break
81+
new_ops.append(user)
82+
op_res = user.results
83+
84+
if not new_ops:
85+
new_ops = [target]
86+
results.set_ops(op.ops, new_ops)
87+
return DiagnosedSilenceableFailure.Success
88+
89+
@staticmethod
90+
def allow_repeated_handle_operands(_op: "GetTileableConsumersOp") -> bool:
91+
return False
92+
93+
class MemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface):
94+
@staticmethod
95+
def get_effects(op: ir.Operation, effects):
96+
transform.only_reads_handle(op.op_operands, effects)
97+
transform.produces_handle(op.results, effects)
98+
transform.only_reads_payload(effects)
99+
100+
101+
def get_tileable_consumers(
102+
target: ir.Value[transform.AnyOpType],
103+
) -> ir.Value:
104+
"""
105+
snake_case wrapper to create a GetTileableConsumersOp.
106+
107+
Args:
108+
target: Handle to target op
109+
Returns:
110+
List of tileable consumer ops, or the target op itself.
111+
"""
112+
return GetTileableConsumersOp(target=target).ops

0 commit comments

Comments
 (0)