Skip to content
85 changes: 85 additions & 0 deletions examples/end-to-end/KernelBench/test_kernel_bench.py
Comment thread
rengolin marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# RUN: python %s | FileCheck %s
Comment thread
rengolin marked this conversation as resolved.

# REQUIRES: torch
# REQUIRES: kernel_bench

import subprocess
from pathlib import Path

tests = [
{
"kernel": "level1/1_Square_matrix_multiplication_.py",
"input_shapes": "32x32xf32xrnd,32x32xf32xid",
"output_shape": "32x32xf32x0",
},
{
"kernel": "level1/1_Square_matrix_multiplication_.py",
"input_shapes": "32x32xbf16xrnd,32x32xbf16xid",
"output_shape": "32x32xbf16x0",
},
{
"kernel": "level1/2_Standard_matrix_multiplication_.py",
"input_shapes": "8x16xf32xrnd,16x8xf32xrnd",
"output_shape": "8x8xf32x0",
},
{
"kernel": "level1/2_Standard_matrix_multiplication_.py",
"input_shapes": "8x16xbf16xrnd,16x8xbf16xrnd",
"output_shape": "8x8xbf16x0",
},
]

if __name__ == "__main__":
project_root = Path(__file__).parent.parent.parent.parent
kb_program = project_root / "tools" / "kernel_bench"
kb_path = project_root / "third_party" / "KernelBench" / "KernelBench"

for test in tests:
kb_kernel = kb_path / test["kernel"]
command_line = [
str(kb_program),
str(kb_kernel),
"--input-shapes",
test["input_shapes"],
"--output-shape",
test["output_shape"],
"--print-tensor=1",
"--seed=42",
]
print(f"Running command: {' '.join(command_line)}")
result = subprocess.run(
command_line,
capture_output=True,
text=True,
)

print("STDOUT:")
print(result.stdout)
print("STDERR:")
print(result.stderr)
print(f"Return code: {result.returncode}")
assert result.returncode == 0, "Execution failed"

# CHECK: 1_Square_matrix_multiplication_.mlir
# CHECK 0.37454012 0.9507143 0.7319939 ... 0.04645041 0.60754484 0.17052412
# CHECK: 0.27214515 0.59023064 0.3609739 ... 0.297349 0.9243962 0.97105825

# CHECK-NOT: Execution failed

# CHECK: 1_Square_matrix_multiplication_.mlir
# CHECK 0.375 0.949219 0.730469 ... 0.0463867 0.609375 0.170898
# CHECK: 0.271484 0.589844 0.361328 ... 0.296875 0.925781 0.972656

# CHECK-NOT: Execution failed

# CHECK: 2_Standard_matrix_multiplication_.mlir
# CHECK: 3.120935 3.7697 4.5365195 4.397648 4.4506536 3.2665431 3.5362916
# CHECK: 5.036752 5.312808 5.8109508 4.810084 4.7435184 4.35573 5.311559

# CHECK-NOT: Execution failed

# CHECK: 2_Standard_matrix_multiplication_.mlir
# CHECK: 3.125 3.76562 4.53125 4.40625 4.4375 3.26562 3.53125 3.9375
# CHECK: 5.03125 5.3125 5.8125 4.8125 4.75 4.34375 5.3125 5.5625

# CHECK-NOT: Execution failed
7 changes: 7 additions & 0 deletions lighthouse/execution/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,10 @@ def parse(shape_str: str) -> KernelArgument:
raise ValueError(f"Invalid init type in shape string: {shape_str}")

return KernelArgument(dims, element_type, init_type)

@staticmethod
def parse_all(shape_str: str) -> list[KernelArgument]:
"""
Parse a shape string in the format MxNx...xTypexInit into a list of KernelArguments.
"""
return [KernelArgumentParser.parse(s) for s in shape_str.split(",")]
34 changes: 25 additions & 9 deletions lighthouse/ingress/torch/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ def import_from_file(
model_class_name: str = "Model",
init_args_fn_name: str | None = "get_init_inputs",
init_kwargs_fn_name: str | None = None,
model_init_args: Iterable | None = None,
sample_args_fn_name: str = "get_inputs",
sample_kwargs_fn_name: str | None = None,
sample_args: Iterable | None = None,
state_path: str | Path | None = None,
dialect: OutputType | str = OutputType.LINALG_ON_TENSORS,
ir_context: ir.Context | None = None,
Expand All @@ -131,10 +133,16 @@ def import_from_file(
init_kwargs_fn_name (str | None, optional): The name of the function in the file
that returns the keyword arguments for initializing the model. If None, the model
is initialized without keyword arguments.
model_init_args (Iterable | None, optional): If provided, these are used directly as
initialization arguments instead of calling ``init_args_fn_name`` from the file.
Useful for overriding hard-coded sizes in the model file. Defaults to None.
sample_args_fn_name (str, optional): The name of the function in the file that
returns the sample input arguments for the model. Defaults to "get_inputs".
sample_kwargs_fn_name (str, optional): The name of the function in the file that
returns the sample keyword input arguments for the model. Defaults to None.
sample_args (Iterable | None, optional): If provided, these are used directly as
sample inputs instead of calling ``sample_args_fn_name`` from the file.
Useful for overriding hard-coded sizes in the model file. Defaults to None.
state_path (str | Path | None, optional): Optional path to a file containing
the model's ``state_dict``. Defaults to None.
dialect (torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}, optional):
Expand Down Expand Up @@ -199,22 +207,30 @@ def get_inputs():
if model is None:
raise ValueError(f"Model class '{model_class_name}' not found in {filepath}")

model_init_args = maybe_load_and_run_callable(
module,
init_args_fn_name,
default=tuple(),
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
model_init_args = (
maybe_load_and_run_callable(
module,
init_args_fn_name,
default=tuple(),
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
)
if model_init_args is None
else model_init_args
)
model_init_kwargs = maybe_load_and_run_callable(
module,
init_kwargs_fn_name,
default={},
error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}",
)
sample_args = load_and_run_callable(
module,
sample_args_fn_name,
f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
sample_args = (
load_and_run_callable(
module,
sample_args_fn_name,
f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
)
if sample_args is None
else sample_args
)
sample_kwargs = maybe_load_and_run_callable(
module,
Expand Down
37 changes: 27 additions & 10 deletions lighthouse/pipeline/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class PipelineDescriptor:
...
"""

search_path = {
".py": "../schedule",
".yaml": "./descriptors",
}

def __init__(self, filename: str):
self.filename = filename
with open(filename, "r") as f:
Expand All @@ -41,21 +46,33 @@ def _normalize_include_path(self, filename) -> str:
* The path of the descriptor file that includes it. This allows for relative includes.
* The path of the Lighthouse schedule module, where all the standard pipelines are located.
"""
# First look in the same directory as the including file, to allow for relative includes.
filename = remove_args_and_opts(filename)
descriptor_path = os.path.normpath(os.path.dirname(self.filename))
file = os.path.join(descriptor_path, filename)
if os.path.exists(file):
return file

# If not found, look for an include path, based on the file extension.
file_ext = os.path.splitext(file)[1]
if file_ext not in self.search_path:
raise ValueError(
f"Included pipeline descriptor file does not exist: {filename} \
(searched in {descriptor_path})"
)

# If include path, look in the descriptor/schedule module path.
schedule_module_path = os.path.normpath(
os.path.join(os.path.dirname(__file__), "../schedule")
os.path.join(os.path.dirname(__file__), self.search_path[file_ext])
)
file = os.path.join(schedule_module_path, filename)
if os.path.exists(file):
return file

file = os.path.join(descriptor_path, filename)
if not os.path.exists(file):
file = os.path.join(schedule_module_path, filename)
if not os.path.exists(file):
raise ValueError(
f"Included pipeline descriptor file does not exist: {filename} \
(searched in {descriptor_path} and {schedule_module_path})"
)
return file
raise ValueError(
f"Included pipeline descriptor file does not exist: {filename} \
(searched in {descriptor_path} and {schedule_module_path})"
)

def _parse_stages(self) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Pipeline:
- pass: "eliminate-empty-tensors"
- pass: "one-shot-bufferize{function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries}"
- pass: "drop-equivalent-buffer-results"
- pass: "buffer-deallocation-pipeline"
- pass: "convert-bufferization-to-memref"

11 changes: 11 additions & 0 deletions lighthouse/pipeline/descriptors/llvm_lowering.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Pipeline:
- pass: "convert-linalg-to-loops"
- pass: "fold-memref-alias-ops"
- pass: "expand-strided-metadata"
- pass: "canonicalize"
- pass: "convert-vector-to-scf"
- pass: "lower-affine"
- pass: "convert-scf-to-cf"
- pass: "convert-vector-to-llvm"
- pass: "convert-to-llvm"
- pass: "reconcile-unrealized-casts"
Loading
Loading