Skip to content

Commit d3e3f02

Browse files
committed
print-after-all
1 parent 8f992e6 commit d3e3f02

4 files changed

Lines changed: 19 additions & 4 deletions

File tree

lighthouse/execution/runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def get_bench_wrapper_schedule(payload_func: str) -> ir.Module:
212212
"""
213213
Get a schedule that wraps the payload function in a benchmarking function.
214214
The function name is defined in Runner and will be used by the runner benchmark method.
215+
This schedule must apply to the module before any other in an optimizing pipeline.
215216
"""
216217
with ir.Location.unknown():
217218
with schedule_boilerplate(result_types=[transform.any_op_t()]) as (

lighthouse/pipeline/driver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@ def add_stage(self, stage: lhs.Stage) -> None:
4646
# Users can derive their own classes from Stage and add them to the pipeline with this method.
4747
self.stages.append(stage)
4848

49-
def apply(self, module: ir.Module) -> ir.Module:
49+
def apply(self, module: ir.Module, print_after_all: bool = False) -> ir.Module:
5050
if module.context != self.context:
5151
raise ValueError("Module context does not match driver context.")
5252
for stage in self.stages:
5353
module = stage.apply(module)
54+
if print_after_all:
55+
print(f"After stage {stage}:\n{module}")
5456
return module
5557

5658
def __len__(self):
@@ -173,14 +175,14 @@ def reset(self) -> None:
173175
self.module = None
174176
self.pipeline_fixed = False
175177

176-
def run(self) -> ir.Module:
178+
def run(self, print_after_all: bool = False) -> ir.Module:
177179
if self.module is None:
178180
raise ValueError("Module must not be empty.")
179181
if len(self.pipeline) == 0:
180182
raise ValueError("Pipeline must have at least one stage.")
181183

182184
# Apply the whole pipeline.
183-
self.pipeline.apply(self.module)
185+
self.pipeline.apply(self.module, print_after_all=print_after_all)
184186

185187
# The pipeline is now fixed and cannot be modified until reset is called.
186188
# This is to prevent accidental modifications to the pipeline after it has been run,

lighthouse/pipeline/stage.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class PassStage(Stage):
147147
def __init__(self, passes: list[Pass], context: ir.Context):
148148
self.context = context
149149
self.pm = PassManager("builtin.module", self.context)
150+
self.passes = passes
150151
add_bundle(self.pm, passes)
151152

152153
def apply(self, module: ir.Module) -> ir.Module:
@@ -157,6 +158,8 @@ def apply(self, module: ir.Module) -> ir.Module:
157158
self.pm.run(module.operation)
158159
return module
159160

161+
def __str__(self) -> str:
162+
return f"PassStage({[str(p) for p in self.passes]})"
160163

161164
class TransformStage(Stage):
162165
"""
@@ -218,3 +221,6 @@ def apply(self, module: ir.Module) -> ir.Module:
218221
raise ValueError("Missing module to apply transformations to.")
219222
self.schedule.apply(module.operation)
220223
return module
224+
225+
def __str__(self) -> str:
226+
return f"TransformStage({self.module})"

tools/lh-run

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ if __name__ == "__main__":
111111
default=0,
112112
help="Print the Nth tensor. Default is 0 (no print).",
113113
)
114+
Parser.add_argument(
115+
"--print-mlir-after-all",
116+
action=argparse.BooleanOptionalAction,
117+
help="Whether to print the MLIR module after all stages. Default is False.",
118+
)
114119
args = Parser.parse_args()
115120

116121
# Initialize the random seed, for stable tests
@@ -125,6 +130,7 @@ if __name__ == "__main__":
125130

126131
if args.benchmark:
127132
# Calling the benchmark wrapper, not the entry point.
133+
# FIXME: Eliminate this cross-dependency between the Runner and the Driver.
128134
with driver.context:
129135
lh_dialects.register_and_load()
130136
bench_wrapper = Runner.get_bench_wrapper_schedule(args.entry_point)
@@ -140,7 +146,7 @@ if __name__ == "__main__":
140146
driver.add_stages(add_lower_to_llvm_stages())
141147

142148
# Run the pipeline to get the optimized module.
143-
optimized_module = driver.run()
149+
optimized_module = driver.run(print_after_all=args.print_mlir_after_all)
144150
if args.print_optimized_module:
145151
print(optimized_module)
146152

0 commit comments

Comments
 (0)