Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.5
rev: v0.15.9
hooks:
# Run the linter.
- id: ruff-check
Expand Down
3 changes: 2 additions & 1 deletion examples/cpu/x86/matmul.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# RUN: %PYTHON %s --dump-kernel=vectorized | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=vectorized --tile-size=64 | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=vectorized --dtype=bf16 --avx512 | FileCheck %s --check-prefix=AVX512
# RUN: %PYTHON %s --dump-kernel=vectorized --dtype=bf16 --avx512 \
# RUN: | FileCheck %s --check-prefix=AVX512

# CHECK: vector.broadcast
# CHECK: vector.fma
Expand Down
4 changes: 3 additions & 1 deletion examples/feed-forward-mpi/feed-forward-mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def parse_cla():
type=int,
default=[WORLD_SIZE],
nargs="+",
help="The shape of the device grid (1 or 2 dimensions). The product of the grid dimensions must match the number of MPI ranks. Use '0' if 2d grid dimensions should be inferred automatically.",
help="The shape of the device grid (1 or 2 dimensions). The product of the grid dimensions \
must match the number of MPI ranks. Use '0' if 2d grid dimensions should be inferred \
automatically.",
)
parser.add_argument(
"--nruns",
Expand Down
3 changes: 2 additions & 1 deletion examples/ingress/convert-kernel-bench-to-mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
sys.exit(1)


# ruff: disable[E501]
# The following kernels won't get converted:
level1, level2 = Path("level1"), Path("level2")
ignore_list = [
Expand Down Expand Up @@ -121,6 +121,7 @@
level2
/ "92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp.py", # error: failed to legalize operation 'torch.constant.int'
]
# ruff: enable[E501]


@dataclass
Expand Down
11 changes: 6 additions & 5 deletions examples/ingress/torch/mlp_from_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
without initializing the model class on the user's side.

The script uses 'lighthouse.ingress.torch.import_from_file' function that
takes a path to a Python file containing the model definition (a Python class derived from 'nn.Module'),
along with the names of functions to get model init arguments and sample inputs. The function
imports the model class on its own, initializes it, and passes it to torch_mlir
to get a MLIR module in the specified dialect.
takes a path to a Python file containing the model definition (a Python class
derived from 'nn.Module'), along with the names of functions to get model init
arguments and sample inputs. The function imports the model class on its own,
initializes it, and passes it to torch_mlir to get a MLIR module in the specified
dialect.

The script uses the model from 'MLPModel/model.py' as an example.
"""
Expand Down Expand Up @@ -39,7 +40,7 @@
model_path, # Path to the Python file containing the model
model_class_name="MLPModel", # Name of the PyTorch nn.Module class to convert
init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__()
sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)'
sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to model
dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types)
ir_context=ir_context, # MLIR context for the conversion
)
Expand Down
7 changes: 4 additions & 3 deletions examples/llama/ref_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# This software may be used and distributed in accordance with
# the terms of the Llama 3 Community License Agreement.


## This is a modified version of the LLaMA 3 model implementation.
## It doesn't use any FairScale components
# This is a modified version of the LLaMA 3 model implementation.
# It doesn't use any FairScale components

import math as pymath
from dataclasses import dataclass
Expand Down
11 changes: 7 additions & 4 deletions examples/llama/test_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ def get_rotary_emb(
static_output_shape=xq_reshaped_shape,
)

# View xq as complex: (batch, seq_len, n_heads, head_dim//2, 2) -> (batch, seq_len, n_heads, head_dim//2) complex
# View xq as complex: (batch, seq_len, n_heads, head_dim//2, 2)
# -> (batch, seq_len, n_heads, head_dim//2) complex
xq_complex_shape = [batch, seq_len, n_heads, head_dim // 2]
xq_complex_uninit = tensor.empty(xq_complex_shape, ir.ComplexType.get(elty))
xq_complex = get_view_as_complex(xq_reshaped, xq_complex_uninit)
Expand Down Expand Up @@ -757,7 +758,8 @@ def get_attention(

# Compute attention scores: matmul(xq, keys.transpose(-2, -1))
# xq_transposed: (batch, n_heads, seq_len, head_dim)
# keys_transposed: (batch, n_heads, seq_len, head_dim) -> transpose to (batch, n_heads, head_dim, seq_len)
# keys_transposed: (batch, n_heads, seq_len, head_dim) -> transpose to
# (batch, n_heads, head_dim, seq_len)
# scores: (batch, n_heads, seq_len, seq_len)
scores_shape = [batch, n_heads, seq_len, seq_len]
scores_uninit = tensor.empty(scores_shape, elty)
Expand Down Expand Up @@ -964,8 +966,9 @@ def get_transformer(
get_outer: torch.outer,
get_linear: torch.nn.functional.linear,
get_repeat_kv: repeat_kv,
get_l2_norm: lambda x, eps: x
* torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + eps),
get_l2_norm: lambda x, eps: (
x * torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + eps)
),
get_rotary_emb: apply_rotary_emb,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ def example_payload() -> Module:
# NB: Do the CHECKing on the transformed output:
# CHECK-LABEL: result of applying schedule to payload
# CHECK: func.func @fold_add_on_two_matmuls
# CHECK-SAME: (%[[MATRIX_A:.*]]: {{.*}}, %[[MATRIX_B:.*]]: {{.*}}, %[[WEIGHTS:.*]]: {{.*}})
# CHECK-SAME: (%[[MATRIX_A:.*]]: {{.*}}, %[[MATRIX_B:.*]]: {{.*}},
# CHECK-SAME: %[[WEIGHTS:.*]]: {{.*}})
@func.func(matrixType, matrixType, matrixType)
def fold_add_on_two_matmuls(matrixA, matrixB, weights):
empty = tensor.empty(matrixType.shape, matrixType.element_type)
c0 = arith.constant(F32Type.get(), 0.0)
# CHECK: %[[ZERO_INIT:.*]] = linalg.fill
zero_init = linalg.fill(c0, outs=[empty])
# CHECK: %[[A_X_WEIGHTS:.*]] = linalg.matmul ins(%[[MATRIX_A]], %[[WEIGHTS]]{{.*}}) outs(%[[ZERO_INIT]]
# CHECK: %[[A_X_WEIGHTS:.*]] = linalg.matmul ins(%[[MATRIX_A]], %[[WEIGHTS]]
# CHECK-SAME: outs(%[[ZERO_INIT]]
A_x_weights = linalg.matmul(matrixA, weights, outs=[zero_init])
# CHECK: %[[RES:.*]] = linalg.matmul ins(%[[MATRIX_B]], %[[WEIGHTS]]{{.*}}) outs(%[[A_X_WEIGHTS]]
# CHECK: %[[RES:.*]] = linalg.matmul ins(%[[MATRIX_B]], %[[WEIGHTS]]
# CHECK-SAME: outs(%[[A_X_WEIGHTS]]
B_x_weights = linalg.matmul(matrixB, weights, outs=[zero_init])
# CHECK-NOT: linalg.add
added = linalg.add(A_x_weights, B_x_weights, outs=[empty])
Expand Down
3 changes: 2 additions & 1 deletion examples/xegpu/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --relu | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --bias | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --accumulate-c | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --bias --relu --accumulate-c | FileCheck %s
# RUN: %PYTHON %s --dump-kernel=xegpu-wg --hidden-sizes 1024 1024 --bias --relu --accumulate-c \
# RUN: | FileCheck %s
# CHECK: module attributes {gpu.container_module} {

"""
Expand Down
6 changes: 4 additions & 2 deletions lighthouse/dialects/transform_tune_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def wrapper(*args, **kwargs):
func_def_ast = func_ast.body[0]

# TODO: in case of multiple decorators, remove just @KnobValue.ast_rewrite
func_def_ast.decorator_list.clear() # Remove the decorator to avoid infinite recursion.
# Remove the decorator to avoid infinite recursion.
func_def_ast.decorator_list.clear()
if in_exprs:
# Apply the rewriting of `in` expressions.
func_def_ast.body = [
Expand All @@ -97,7 +98,8 @@ def wrapper(*args, **kwargs):
mod = compile(func_ast, filename=source_file, mode="exec")
frame = inspect.currentframe()
assert frame and frame.f_back
# Make the original function's globals and locals available to the rewritten function.
# Make the original function's globals and locals available
# to the rewritten function.
temp_globals = frame.f_back.f_globals.copy()
temp_globals |= frame.f_back.f_locals.copy()
temp_locals = frame.f_back.f_locals.copy()
Expand Down
6 changes: 4 additions & 2 deletions lighthouse/execution/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ class KernelArgument:
"""
A kernel argument, initialized according to the specified type.
The argument value is stored in the `arg` attribute, which is a numpy array.
It will be initialized at construction time, so that the argument value is ready to use after construction.
It will be initialized at construction time, so that the argument value is ready
to use after construction.

Arguments are:
* dims: list of dimensions of the argument > 0 (e.g., [M, N, K])
* element_type: NumPy data type of the argument (e.g., np.float32, np.int64, "f16", "bf16", etc.)
* element_type: NumPy data type of the argument
(e.g., np.float32, np.int64, "f16", "bf16", etc.)
* init_type: type of initialization (InitType)
TODO: Add support for distribution parameters on random.
"""
Expand Down
3 changes: 2 additions & 1 deletion lighthouse/execution/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _find_shared_libs(self, shared_libs: list[str]) -> list[str]:

def _get_engine(self) -> ExecutionEngine:
"""
Get an execution engine for the given payload module, loading the necessary shared libraries.
Get an execution engine for the given payload module,
loading the necessary shared libraries.
"""
execution_engine = ExecutionEngine(
self.payload, opt_level=self.opt_level, shared_libs=self.shared_libs
Expand Down
3 changes: 2 additions & 1 deletion lighthouse/ingress/mlir_gen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def csints(s: str) -> Sequence[int]:
"--layers",
type=csints,
default=(128, 256, 512),
help="the number of neurons in each layer - the first layer is the input layer and the last layer is the output layer",
help="the number of neurons in each layer - the first layer is the input layer \
and the last layer is the output layer",
)

parser.add_argument(
Expand Down
11 changes: 6 additions & 5 deletions lighthouse/ingress/torch/importer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import importlib.util
from pathlib import Path
from typing import Iterable, Mapping
Expand Down Expand Up @@ -52,10 +51,11 @@ def import_from_model(
``OutputType.LINALG_ON_TENSORS``.
ir_context (ir.Context, optional): An optional MLIR context to use for parsing
the module. If not provided, the module is returned as a string.
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import``.

Returns:
str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
str | ir.Module: The imported MLIR module as a string or an ir.Module
if `ir_context` is provided.

Examples:
>>> import torch
Expand Down Expand Up @@ -142,10 +142,11 @@ def import_from_file(
``OutputType.LINALG_ON_TENSORS``.
ir_context (ir.Context, optional): An optional MLIR context to use for parsing
the module. If not provided, the module is returned as a string.
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import``.

Returns:
str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
str | ir.Module: The imported MLIR module as a string or an ir.Module
if `ir_context` is provided.

Examples:
Given a file `path/to/model_file.py` with the following content:
Expand Down
9 changes: 6 additions & 3 deletions lighthouse/pipeline/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def __init__(self, filename: str):
def _normalize_include_path(self, filename) -> str:
"""
Finds the file in some standard locations, in order:
* The path of the descriptor file that includes it. This allows for relative includes.
* The path of the Lighthouse schedule module, where all the standard pipelines are located.
* The path of the descriptor file that includes it.
This allows for relative includes.
* The path of the Lighthouse schedule module,
where all the standard pipelines are located.
"""
filename = remove_args_and_opts(filename)
descriptor_path = os.path.normpath(os.path.dirname(self.filename))
Expand Down Expand Up @@ -93,7 +95,8 @@ def _parse_stages(self) -> None:

else:
raise ValueError(
f"Invalid stage in pipeline description: {stage}. Must be one of 'pass', 'transform', 'bundle' or 'include'."
f"Invalid stage in pipeline description: {stage}. Must be one of 'pass', \
'transform', 'bundle' or 'include'."
)

def _include_pipeline(self, filename: str) -> None:
Expand Down
23 changes: 15 additions & 8 deletions lighthouse/pipeline/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
class PipelineDriver:
"""
A simple driver that runs the optimization pipeline on a given workload.
Helps create a list of Stages (passes, transforms, bundles) to apply to the module, and runs them in sequence.
Helps create a list of Stages (passes, transforms, bundles) to apply to the module,
and runs them in sequence.
"""

stages: list[lhs.Stage]
Expand All @@ -24,7 +25,7 @@ def add_pass(self, name: str) -> None:
self.stages.append(lhs.PassStage([lhs.Pass(name)], self.context))

def add_transform(self, stage: str | ir.Module) -> None:
# Transform will figure out if this is MLIR, Python or Module, and will handle it accordingly.
# Transform will figure out if it is MLIR, Python or Module, and will handle it accordingly.
if isinstance(stage, ir.Module):
# This is a transform already in module form. Assume it has been verified already.
if stage.context != self.context:
Expand All @@ -43,7 +44,8 @@ def add_bundle(self, name: str) -> None:

def add_stage(self, stage: lhs.Stage) -> None:
# A generit stage that isn't covered by the existing infrastructure.
# Users can derive their own classes from Stage and add them to the pipeline with this method.
# Users can derive their own classes from Stage and add them to
# the pipeline with this method.
self.stages.append(stage)

def apply(self, module: ir.Module, print_after_all: bool = False) -> ir.Module:
Expand All @@ -65,7 +67,8 @@ def reset(self):
class TransformDriver(PipelineDriver):
"""
A simple driver that runs a sequence of transform modules on a given workload.
This is a thin wrapper around PipelineDriver that is used to run a sequence of transform modules on a given workload.
This is a thin wrapper around PipelineDriver that is used to run
a sequence of transform modules on a given workload.
"""

def __init__(self, schedules: list[ir.Module]):
Expand All @@ -85,14 +88,18 @@ class CompilerDriver:
This is a high-level interface that abstracts away the details of the optimization pipeline,
and provides a simple interface for running the pipeline on a given workload.

The pipeline is flexible until the first time it is run, at which point it becomes fixed and cannot be modified until reset is called.
This is to allow running the same pipeline on different modules, without accidentally modifying the pipeline after it has been run.
The pipeline is flexible until the first time it is run, at which point it becomes fixed and
cannot be modified until reset is called.
This is to allow running the same pipeline on different modules, without accidentally modifying
the pipeline after it has been run.

Calling reset() will clear the pipeline and the module, allowing for a new pipeline to be constructed and run on a new module.
Calling reset() will clear the pipeline and the module, allowing for a new pipeline
to be constructed and run on a new module.
"""

def __init__(self, filename: str, stages: list[str] = []):
# The context is shared across the entire pipeline, and is used to create the PassManager and Transform Schedules.
# The context is shared across the entire pipeline, and is used to create the PassManager
# and Transform Schedules.
# The module is owned by the Driver to encapsulate its use through the pipeline.
# It is returned at the end of run() after being transformed by the stages in the pipeline.
self.context = ir.Context()
Expand Down
23 changes: 15 additions & 8 deletions lighthouse/pipeline/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def __str__(self) -> str:
# Predefined pass bundles for common transformations.
# These are not exhaustive and can be extended as needed.
# The idea is to group together passes that are commonly used together in a pipeline,
# so that they can be easily added to a PassManager or Transform Schedule with a single function call.
# so that they can be easily added to a PassManager or Transform Schedule with
# a single function call.
# FIXME: Deprecate bundles in favor of YAML pipeline descriptors.
PassBundles = {
# All in one bufferization bundle.
Expand Down Expand Up @@ -127,21 +128,24 @@ def __str__(self) -> str:

class Stage:
"""
A stage in the optimization pipeline. Each stage will apply a specific set of transformations to the module,
and will keep track of the current state of the module after the transformations are applied.
A stage in the optimization pipeline. Each stage will apply a specific
set of transformations to the module, and will keep track of the current
state of the module after the transformations are applied.
"""

@abstractmethod
def apply(self, module: ir.Module) -> ir.Module:
"""
Apply the transformations for this stage to the given module, and return the transformed module.
Apply the transformations for this stage to the given module,
and return the transformed module.
"""
pass


class PassStage(Stage):
"""
A stage that applies a predefined set of passes to the module. This is a simple wrapper around a PassManager.
A stage that applies a predefined set of passes to the module.
This is a simple wrapper around a PassManager.
"""

def __init__(self, passes: list[Pass], context: ir.Context):
Expand Down Expand Up @@ -197,11 +201,13 @@ def __init__(self, transform: Transform | ir.Module, context: ir.Context):
spec.loader.exec_module(transform_module)
if not hasattr(transform_module, transform.generator):
raise ValueError(
f"Transform module '{transform.filename}' does not define a '{transform.generator}' generator function."
f"Transform module '{transform.filename}' does not define \
a '{transform.generator}' generator function."
)
self.generator = getattr(transform_module, transform.generator)

# Run the function with the dictionary as the options that will create the named sequence.
# Run the function with the dictionary as the options
# that will create the named sequence.
with context, ir.Location.unknown():
self.module = self.generator(transform.options)
else:
Expand All @@ -210,7 +216,8 @@ def __init__(self, transform: Transform | ir.Module, context: ir.Context):
# Check if the imported module contains at least one schedule
if TransformStage.MLIR_ATTRIBUTE not in self.module.operation.attributes:
raise ValueError(
f"Transform module {transform.filename} does not define a {TransformStage.MLIR_ATTRIBUTE} attribute."
f"Transform module {transform.filename} does not define \
a {TransformStage.MLIR_ATTRIBUTE} attribute."
)

# Assume the first (or only) sequence.
Expand Down
Loading
Loading