From aa8c12a716d6164cec9ea755accc0cfb6e0c934e Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Tue, 2 Jun 2026 09:28:23 +0000 Subject: [PATCH 1/5] feat(compiler): Add custom LLVM pass pipeline and plugin support to JIT Signed-off-by: fsx950223 --- python/flydsl/compiler/backends/rocm.py | 29 ++- python/flydsl/compiler/external_llvm.py | 130 +++++++++++++ python/flydsl/compiler/jit_function.py | 124 +++++++++++- python/flydsl/utils/env.py | 12 ++ tests/kernels/test_llvm_pass_plugin_e2e.py | 208 +++++++++++++++++++++ tests/unit/test_llvm_pass_pipeline.py | 75 ++++++++ 6 files changed, 569 insertions(+), 9 deletions(-) create mode 100644 tests/kernels/test_llvm_pass_plugin_e2e.py create mode 100644 tests/unit/test_llvm_pass_pipeline.py diff --git a/python/flydsl/compiler/backends/rocm.py b/python/flydsl/compiler/backends/rocm.py index c32a328bf..ace00b648 100644 --- a/python/flydsl/compiler/backends/rocm.py +++ b/python/flydsl/compiler/backends/rocm.py @@ -33,11 +33,9 @@ def _format_pass_opts(opts: dict) -> str: """Format {key: value, ...} as 'key=value key2=value2' for MLIR pass options.""" return " ".join(f"{k}={v}" for k, v in opts.items()) - def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]: - chip = self.target.arch + def _bin_cli_opts(self, *, compile_hints: dict) -> List[str]: waves_per_eu = compile_hints.get("waves_per_eu") maxnreg = compile_hints.get("maxnreg") - bin_cli_opts = [] if env.debug.enable_debug_info: bin_cli_opts.append("-g") @@ -45,9 +43,12 @@ def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]: bin_cli_opts.append(f"--amdgpu-waves-per-eu={waves_per_eu}") if maxnreg: bin_cli_opts.append(f"--amdgpu-num-vgpr={maxnreg}") + return bin_cli_opts - rocdl_opts = { - "O": 2, + def _rocdl_opts(self, *, compile_hints: dict, opt_level: int = 2) -> dict: + chip = self.target.arch + return { + "O": opt_level, "abi": 600, "chip": chip, "correct-sqrt": "true", @@ -61,6 +62,24 @@ def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]: "wave64": "false" if is_rdna_arch(chip) else "true", } + def llvm_recodegen_fragments(self, *, compile_hints: dict, opt_level: int = 0) -> Tuple[str, str]: + """Fragments to re-codegen an already-LLVM-dialect ``gpu.module`` that has + NO target attached: attach a ROCDL target at ``opt_level`` then emit the + device binary. Used by the custom-LLVM-pass path, which has already run + its own ``opt`` pipeline, so codegen runs at ``O=0`` to avoid re-optimizing. + """ + rocdl_opts = self._rocdl_opts(compile_hints=compile_hints, opt_level=opt_level) + bin_cli_opts = self._bin_cli_opts(compile_hints=compile_hints) + attach_fragment = f"rocdl-attach-target{{{self._format_pass_opts(rocdl_opts)}}}" + binary_fragment = f'gpu-module-to-binary{{format=fatbin opts="{" ".join(bin_cli_opts)}"}}' + return attach_fragment, binary_fragment + + def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]: + chip = self.target.arch + + bin_cli_opts = self._bin_cli_opts(compile_hints=compile_hints) + rocdl_opts = self._rocdl_opts(compile_hints=compile_hints, opt_level=2) + pre_binary_fragments = [ "fly-rewrite-func-signature", "fly-canonicalize", diff --git a/python/flydsl/compiler/external_llvm.py b/python/flydsl/compiler/external_llvm.py index 69993a748..eac41d118 100644 --- a/python/flydsl/compiler/external_llvm.py +++ b/python/flydsl/compiler/external_llvm.py @@ -204,3 +204,133 @@ def run_mlir_opt(*, pass_pipeline: str, input_path: Path, output_path: Path) -> finally: if tmp_dir_obj is not None: tmp_dir_obj.cleanup() + + +def llvm_opt_fingerprint(pipeline: str, plugins: Optional[list] = None) -> str: + """Cache fingerprint for a custom LLVM-opt configuration: the pipeline + string plus each plugin's path and content hash, so editing a plugin .so + (or the pipeline) invalidates cached artifacts.""" + parts = [f"llvm-opt:{pipeline}"] + for p in plugins or []: + path = Path(p).expanduser() + try: + parts.append(f"{path}:{_file_hash(path.resolve())}") + except OSError: + parts.append(f"{path}:") + return ";".join(parts) + + +def _run_tool(cmd: list, *, prefix: Path, what: str, work_dir: Path) -> None: + try: + subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=600, env=_subprocess_env(prefix)) + except subprocess.TimeoutExpired as exc: + raise ExternalLLVMError( + f"{what} timed out after 600s.\ncommand: {' '.join(cmd)}\nwork_dir: {work_dir}" + ) from exc + except subprocess.CalledProcessError as exc: + raise ExternalLLVMError( + f"{what} failed.\nllvm_dir: {prefix}\ncommand: {' '.join(cmd)}\n" + f"work_dir: {work_dir}\nstdout:\n{exc.stdout}\nstderr:\n{exc.stderr}" + ) from exc + + +def run_llvm_opt_then_binary( + module: ir.Module, + *, + llvm_ir: str, + attach_fragment: str, + binary_fragment: str, + pipeline: str, + plugins: Optional[list] = None, + llvm_options: Optional[dict] = None, + work_dir: Optional[Path] = None, + stage_prefix: str = "llvm_opt", +) -> None: + """Run a custom LLVM new-PM pass pipeline on the device kernel's (pre-link) + LLVM IR, then re-codegen the device binary and splice it back into *module*. + + Flow: ``opt --passes`` (with optional ``--load-pass-plugin``) on ``llvm_ir`` + -> ``mlir-translate --import-llvm`` -> wrap into a ``gpu.module`` -> external + ``mlir-opt`` running ``attach_fragment`` (ROCDL target at O=0) then + ``binary_fragment`` (``gpu-module-to-binary``) -> replace the in-process + ``gpu.module`` with the produced ``gpu.binary``. + """ + prefix = _llvm_dir() + opt = _tool(prefix, "opt") + mlir_translate = _tool(prefix, "mlir-translate") + mlir_opt = _tool(prefix, "mlir-opt") + + gpu_module = _single_top_level_op(module, "gpu.module") + name = _symbol_name(gpu_module) + data_layout = None + if "llvm.data_layout" in gpu_module.attributes: + try: + data_layout = ir.StringAttr(gpu_module.attributes["llvm.data_layout"]).value + except Exception: + data_layout = None + + llvm_cli_args = _format_llvm_cli_options(llvm_options) if llvm_options else [] + + tmp_dir_obj = None + if work_dir is None: + tmp_dir_obj = tempfile.TemporaryDirectory(prefix="flydsl_llvm_opt_") + work_dir = Path(tmp_dir_obj.name) + else: + work_dir.mkdir(parents=True, exist_ok=True) + + in_ll = work_dir / f"{stage_prefix}_pre_opt.ll" + out_ll = work_dir / f"{stage_prefix}_post_opt.ll" + imported_path = work_dir / f"{stage_prefix}_imported.mlir" + wrapped_path = work_dir / f"{stage_prefix}_wrapped.mlir" + bin_path = work_dir / f"{stage_prefix}_binary.mlir" + + try: + in_ll.write_text(llvm_ir, encoding="utf-8") + + plugin_args = [f"--load-pass-plugin={Path(p).expanduser()}" for p in (plugins or [])] + _run_tool( + [str(opt), str(in_ll), "-S", f"--passes={pipeline}", *plugin_args, *llvm_cli_args, "-o", str(out_ll)], + prefix=prefix, + what="LLVM opt pass pipeline", + work_dir=work_dir, + ) + + _run_tool( + [str(mlir_translate), "--import-llvm", str(out_ll), "-o", str(imported_path)], + prefix=prefix, + what="mlir-translate --import-llvm", + work_dir=work_dir, + ) + + # Wrap the re-imported LLVM-dialect IR back into a gpu.module (no target; + # attach_fragment adds it). The original gpu.module's data layout is + # re-applied; gpu-module-to-binary will produce gpu.binary @. + imported = ir.Module.parse(imported_path.read_text(encoding="utf-8"), context=module.context) + body = "\n".join(op.operation.get_asm() for op in imported.body.operations) + dl_attr = f' attributes {{llvm.data_layout = "{data_layout}"}}' if data_layout else "" + wrapped_path.write_text( + f"module attributes {{gpu.container_module}} {{\n" f" gpu.module @{name}{dl_attr} {{\n{body}\n }}\n}}\n", + encoding="utf-8", + ) + + _run_tool( + [ + str(mlir_opt), + str(wrapped_path), + f"--pass-pipeline=builtin.module({attach_fragment},{binary_fragment})", + *llvm_cli_args, + "-o", + str(bin_path), + ], + prefix=prefix, + what="external gpu-module-to-binary codegen", + work_dir=work_dir, + ) + + if not bin_path.is_file(): + raise ExternalLLVMError(f"external codegen did not create output file: {bin_path}") + binary_module = ir.Module.parse(bin_path.read_text(encoding="utf-8"), context=module.context) + _replace_gpu_module_with_binary_op(module, binary_module) + finally: + if tmp_dir_obj is not None: + tmp_dir_obj.cleanup() diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index d4adfda83..4f2edc29b 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -122,6 +122,8 @@ def _create_mlir_context(*, load_dialects=True): "HSA_OVERRIDE_GFX_VERSION", "FLYDSL_DEBUG_ENABLE_DEBUG_INFO", "FLYDSL_EXTRA_SOURCE_DIRS", + "FLYDSL_COMPILE_LLVM_PASS_PIPELINE", + "FLYDSL_COMPILE_LLVM_PASS_PLUGINS", ) @@ -745,6 +747,20 @@ class PipelineConfig: binary_fragment: Optional[str] llvm_opts: Optional[dict] external: bool + llvm_pass_pipeline: str = "" + llvm_pass_plugins: Optional[list] = None + + +def _effective_llvm_pass_config(hints: dict): + """Resolve the custom LLVM pass pipeline + plugins, preferring the + @flyc.jit compile_hints over the FLYDSL_COMPILE_LLVM_PASS_* env vars.""" + pipeline = hints.get("llvm_pass_pipeline") + if pipeline is None: + pipeline = env.compile.llvm_pass_pipeline + plugins = hints.get("llvm_pass_plugins") + if plugins is None: + plugins = env.compile.llvm_pass_plugins + return (pipeline or "").strip(), list(plugins or []) def _pipeline_fragments_for_mode(backend) -> PipelineConfig: @@ -753,6 +769,22 @@ def _pipeline_fragments_for_mode(backend) -> PipelineConfig: hints = CompilationContext.get_compile_hints() llvm_opts = hints.get("llvm_options") + llvm_pass_pipeline, llvm_pass_plugins = _effective_llvm_pass_config(hints) + + # Custom LLVM pass pipeline: split off the binary fragment so we can extract + # LLVM IR, run `opt`, and re-codegen externally (see MlirCompiler.compile). + if llvm_pass_pipeline: + pre_binary_fragments, binary_fragment = backend.external_binary_pipeline_fragments(compile_hints=hints) + return PipelineConfig( + fragments=[*pre_binary_fragments, binary_fragment], + pre_binary=pre_binary_fragments, + binary_fragment=binary_fragment, + llvm_opts=llvm_opts, + external=False, + llvm_pass_pipeline=llvm_pass_pipeline, + llvm_pass_plugins=llvm_pass_plugins, + ) + if _use_external_binary_codegen(): pre_binary_fragments, binary_fragment = backend.external_binary_pipeline_fragments(compile_hints=hints) return PipelineConfig( @@ -809,6 +841,9 @@ def compile( "use embedded codegen for kernels that require #fly.explicit_module." ) + if cfg.llvm_pass_pipeline and link_libs: + raise RuntimeError("custom llvm_pass_pipeline does not support extern link_libs yet.") + if link_libs: link_opt = _format_link_lib_options(link_libs) fragments, found_attach_target = _append_link_lib_options_to_attach_targets(fragments, link_opt) @@ -826,6 +861,10 @@ def compile( dump_dir = Path(env.debug.dump_dir).resolve() with _llvm_ctx: + if cfg.llvm_pass_pipeline: + return cls._compile_with_llvm_opt( + module, backend, cfg, func_name=func_name, dump_enabled=dump_enabled, dump_dir=dump_dir + ) if dump_enabled: asm = module.operation.get_asm(enable_debug_info=True) kernel_names = _infer_kernel_names_from_asm(asm) @@ -937,6 +976,53 @@ def compile( return module + @classmethod + def _compile_with_llvm_opt( + cls, module, backend, cfg, *, func_name: str, dump_enabled: bool, dump_dir: Path + ) -> ir.Module: + """Custom-LLVM-pass path: run the pre-binary fragments in-process, extract + the device LLVM IR, run the user's ``opt`` pipeline (+ plugins) on it, then + re-codegen the binary externally and splice it back.""" + from .external_llvm import run_llvm_opt_then_binary + from .kernel_function import CompilationContext + + hints = CompilationContext.get_compile_hints() + work_dir = None + if dump_enabled: + asm = module.operation.get_asm(enable_debug_info=True) + kernel_names = _infer_kernel_names_from_asm(asm) + subdir = kernel_names[0] if len(kernel_names) == 1 else (func_name or "module") + work_dir = dump_dir / _sanitize_path_component(subdir) + print(f"[flydsl.compile] FLYDSL_DUMP_IR=1 (llvm_pass_pipeline) dir={work_dir}") + + # Run everything up to (but not including) gpu-module-to-binary in-process. + _run_pipeline( + module, + cfg.pre_binary, + verifier=env.debug.enable_verifier, + print_after_all=env.debug.print_after_all, + ) + + llvm_ir = _extract_llvm_ir(module) + if llvm_ir is None: + raise FlyDSLCompileError( + "llvm_pass_pipeline is set but the device LLVM IR could not be extracted from the gpu.module." + ) + + attach_fragment, binary_fragment = backend.llvm_recodegen_fragments(compile_hints=hints, opt_level=0) + run_llvm_opt_then_binary( + module, + llvm_ir=llvm_ir, + attach_fragment=attach_fragment, + binary_fragment=binary_fragment, + pipeline=cfg.llvm_pass_pipeline, + plugins=cfg.llvm_pass_plugins, + llvm_options=cfg.llvm_opts, + work_dir=work_dir, + ) + module.operation.verify() + return module + class JitCacheManager: """Directory-based cache manager with multi-process safety. @@ -1344,6 +1430,14 @@ def _resolve_and_make_cache_key(self, bound_args): key_parts = [("_env_", _cache_invalidating_env_values()), ("_target_", self._backend_target)] if self.compile_hints: key_parts.append(("_hints_", tuple(sorted((k, str(v)) for k, v in self.compile_hints.items())))) + # Fold the effective custom LLVM pass pipeline + plugin content hashes + # (from hints or env) into the key so editing a plugin .so or the + # pipeline invalidates cached artifacts. + eff_pipeline, eff_plugins = _effective_llvm_pass_config(self.compile_hints) + if eff_pipeline: + from .external_llvm import llvm_opt_fingerprint + + key_parts.append(("_llvm_pass_", llvm_opt_fingerprint(eff_pipeline, eff_plugins))) for name, arg in bound_args.items(): param = sig.parameters.get(name) @@ -1661,11 +1755,33 @@ def _ensure_stream_arg(jit_args: list) -> bool: return False -def jit(func: Optional[Callable] = None) -> JitFunction: - """JIT decorator for host launcher functions.""" +def jit( + func: Optional[Callable] = None, + *, + llvm_pass_pipeline: Optional[str] = None, + llvm_pass_plugins: Optional[list] = None, +) -> JitFunction: + """JIT decorator for host launcher functions. + + ``llvm_pass_pipeline``: optional LLVM new-PM pass pipeline (e.g. + ``"default,my-pass"``) run on the device kernel IR before codegen. + ``llvm_pass_plugins``: optional list of LLVM pass plugin ``.so`` paths + loaded (``opt --load-pass-plugin``) before running that pipeline. Both + require ``FLYDSL_COMPILE_LLVM_DIR`` and override the + ``FLYDSL_COMPILE_LLVM_PASS_*`` env vars. + """ + hints = {} + if llvm_pass_pipeline is not None: + hints["llvm_pass_pipeline"] = llvm_pass_pipeline + if llvm_pass_plugins is not None: + hints["llvm_pass_plugins"] = list(llvm_pass_plugins) + + def _make(f: Callable) -> JitFunction: + return JitFunction(f, compile_hints=hints or None) + if func is None: - return lambda f: JitFunction(f) - return JitFunction(func) + return _make + return _make(func) class CompiledFunction: diff --git a/python/flydsl/utils/env.py b/python/flydsl/utils/env.py index a36716710..a8592ba1f 100644 --- a/python/flydsl/utils/env.py +++ b/python/flydsl/utils/env.py @@ -230,6 +230,18 @@ class CompileEnvManager(EnvManager): arch = OptStr("", env_var="ARCH", description="Override target GPU architecture (e.g. gfx942, gfx950)") backend = OptStr("rocm", description="GPU compile backend id (e.g. rocm)") llvm_dir = OptStr("", description="External LLVM/MLIR install prefix for final code generation") + llvm_pass_pipeline = OptStr( + "", + description="Custom LLVM new-PM pass pipeline run on the device kernel IR before codegen " + "(e.g. 'default,my-pass'); requires FLYDSL_COMPILE_LLVM_DIR. Overridden by " + "@flyc.jit(llvm_pass_pipeline=...).", + ) + llvm_pass_plugins = OptList( + [], + separator=":", + description="Colon-separated LLVM pass plugin .so paths loaded (opt --load-pass-plugin) " + "before running llvm_pass_pipeline. Overridden by @flyc.jit(llvm_pass_plugins=...).", + ) class DebugEnvManager(EnvManager): diff --git a/tests/kernels/test_llvm_pass_plugin_e2e.py b/tests/kernels/test_llvm_pass_plugin_e2e.py new file mode 100644 index 000000000..adf87b570 --- /dev/null +++ b/tests/kernels/test_llvm_pass_plugin_e2e.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""End-to-end test for the custom-LLVM-pass JIT path with a real pass plugin. + +Builds a minimal LLVM new-PM pass plugin (``.so``) that registers a module pass +named ``flydsl-print-tid``. The pass injects, at the entry of every +``amdgpu_kernel`` function, a device ``printf("...threadIdx.x=%d...", tid)`` call +(``tid`` from ``llvm.amdgcn.workitem.id.x``). The kernel is then driven through +``@flyc.jit(llvm_pass_pipeline=..., llvm_pass_plugins=[...])`` so the full +``opt --load-pass-plugin`` -> re-codegen -> run chain is exercised, and the +injected device print is observed in the captured output. + +Requires a ROCm GPU, ``FLYDSL_COMPILE_LLVM_DIR`` (for ``opt``/``mlir-translate``/ +``mlir-opt`` + LLVM headers), and a host C++ compiler; skipped otherwise. +""" + +import os +import shutil +import subprocess +from pathlib import Path + +import pytest + +pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] + +torch = pytest.importorskip("torch") + +import flydsl.compiler as flyc # noqa: E402 +import flydsl.expr as fx # noqa: E402 +from flydsl.compiler.external_llvm import ExternalLLVMError # noqa: E402 + +# An LLVM new-PM module pass registered under ``flydsl-print-tid`` via the +# pass-plugin C API. At the entry of every amdgpu_kernel it emits FlyDSL's exact +# hostcall device-printf sequence (``__ockl_printf_begin`` / ``append_string_n`` / +# ``append_args``) printing ``threadIdx.x``. Using the same ockl ABI as +# ``fx.printf`` means the ROCm runtime FlyDSL already sets up services it, and +# ``ockl`` is linked during the O=0 re-codegen. (Note: the C ``printf`` + +# ``amdgpu-printf-runtime-binding`` route instead emits the buffered +# ``__printf_alloc`` path, which FlyDSL's runtime does not service.) +PLUGIN_SRC = r""" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Plugins/PassPlugin.h" +using namespace llvm; +namespace { +struct PrintTidPass : PassInfoMixin { + PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { + LLVMContext &C = M.getContext(); + auto *i64 = Type::getInt64Ty(C); + auto *i32 = Type::getInt32Ty(C); + auto *ptr = PointerType::get(C, 0); + FunctionCallee beginF = + M.getOrInsertFunction("__ockl_printf_begin", FunctionType::get(i64, {i64}, false)); + FunctionCallee strF = M.getOrInsertFunction( + "__ockl_printf_append_string_n", FunctionType::get(i64, {i64, ptr, i64, i32}, false)); + FunctionCallee argsF = M.getOrInsertFunction( + "__ockl_printf_append_args", + FunctionType::get(i64, {i64, i32, i64, i64, i64, i64, i64, i64, i64, i32}, false)); + Function *widx = Intrinsic::getOrInsertDeclaration(&M, Intrinsic::amdgcn_workitem_id_x); + bool changed = false; + for (Function &F : M) { + if (F.isDeclaration() || F.getCallingConv() != CallingConv::AMDGPU_KERNEL) + continue; + IRBuilder<> B(&*F.getEntryBlock().getFirstInsertionPt()); + Constant *str = ConstantDataArray::getString(C, "flydsl-pass: threadIdx.x=%d\n", true); + // Format string must live in addrspace 0 (matches the ockl append ABI). + auto *gv = new GlobalVariable(M, str->getType(), true, GlobalValue::InternalLinkage, str, + "flydsl_tid_fmt", nullptr, GlobalValue::NotThreadLocal, 0); + uint64_t len = cast(str->getType())->getNumElements(); + Value *tid = B.CreateZExt(B.CreateCall(widx, {}), i64); + Value *z = ConstantInt::get(i64, 0); + Value *h0 = B.CreateCall(beginF, {z}); + Value *h1 = B.CreateCall(strF, {h0, gv, ConstantInt::get(i64, len), ConstantInt::get(i32, 0)}); + B.CreateCall(argsF, {h1, ConstantInt::get(i32, 1), tid, z, z, z, z, z, z, ConstantInt::get(i32, 1)}); + changed = true; + } + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); + } +}; +} // namespace +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo() { + return {LLVM_PLUGIN_API_VERSION, "FlydslPrintTid", LLVM_VERSION_STRING, [](PassBuilder &PB) { + PB.registerPipelineParsingCallback( + [](StringRef N, ModulePassManager &MPM, ArrayRef) { + if (N == "flydsl-print-tid") { MPM.addPass(PrintTidPass()); return true; } + return false; + }); + }}; +} +""" + + +def _gpu_available() -> bool: + try: + from flydsl.runtime.device import get_rocm_device_count + + return get_rocm_device_count() > 0 + except Exception: + return False + + +@pytest.fixture(scope="module") +def print_tid_plugin(tmp_path_factory) -> str: + """Compile the print-tid pass plugin against the LLVM prefix whose ``opt`` + will load it (ABI must match), or skip if the toolchain is unavailable.""" + raw = os.environ.get("FLYDSL_COMPILE_LLVM_DIR", "").strip() + if not raw: + pytest.skip("FLYDSL_COMPILE_LLVM_DIR not set; required to build/load an LLVM pass plugin") + prefix = Path(raw).expanduser().resolve() + llvm_config = prefix / "bin" / "llvm-config" + header = prefix / "include" / "llvm" / "Plugins" / "PassPlugin.h" + cxx = shutil.which("clang++") or shutil.which("g++") + if not llvm_config.is_file() or not header.is_file() or cxx is None: + pytest.skip("LLVM headers/llvm-config or a C++ compiler not available for plugin build") + + cxxflags = subprocess.check_output([str(llvm_config), "--cxxflags"], text=True).split() + work = tmp_path_factory.mktemp("llvm_plugin") + src = work / "flydsl_print_tid.cpp" + src.write_text(PLUGIN_SRC, encoding="utf-8") + so = work / "libFlydslPrintTid.so" + subprocess.run([cxx, "-shared", "-fPIC", *cxxflags, str(src), "-o", str(so)], check=True) + assert so.is_file() + return str(so) + + +@flyc.kernel +def _add_kernel(A: fx.Tensor, B: fx.Tensor, C: fx.Tensor, block_dim: fx.Constexpr[int]): + bid = fx.block_idx.x + tid = fx.thread_idx.x + A = fx.rocdl.make_buffer_tensor(A) + tA = fx.logical_divide(A, fx.make_layout(block_dim, 1)) + tB = fx.logical_divide(B, fx.make_layout(block_dim, 1)) + tC = fx.logical_divide(C, fx.make_layout(block_dim, 1)) + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + tA = fx.logical_divide(tA, fx.make_layout(1, 1)) + tB = fx.logical_divide(tB, fx.make_layout(1, 1)) + tC = fx.logical_divide(tC, fx.make_layout(1, 1)) + ca = fx.make_copy_atom(fx.UniversalCopy32b(), fx.Float32) + cab = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), fx.Float32) + rA = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32) + rB = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32) + rC = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32) + fx.copy_atom_call(cab, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(ca, fx.slice(tB, (None, tid)), rB) + vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB)) + fx.memref_store_vec(vC, rC) + fx.copy_atom_call(ca, rC, fx.slice(tC, (None, tid))) + + +def _make_add_jit(**jit_kwargs): + @flyc.jit(**jit_kwargs) + def add(A: fx.Tensor, B: fx.Tensor, C, n: fx.Int32, stream: fx.Stream = fx.Stream(None)): + block_dim = 64 + grid_x = (n + block_dim - 1) // block_dim + _add_kernel(A, B, C, block_dim).launch(grid=(grid_x, 1, 1), block=[block_dim, 1, 1], stream=stream) + + return add + + +@pytest.mark.skipif(not _gpu_available(), reason="requires a ROCm GPU") +def test_print_tid_plugin_injects_device_printf(print_tid_plugin, monkeypatch): + """Positive: with the plugin loaded and the pipeline naming the plugin pass, + the printf-injected kernel compiles, links (ockl), and runs correctly. + + The injected device ``printf`` (``flydsl-pass: threadIdx.x=...``, one line per + lane) is written by the ROCm hostcall consumer to a file descriptor HIP + cached at init, so pytest's in-process capture does not see it; run with + ``pytest -s`` to view it on the terminal. Compile + link + correct execution + of the injected ``__ockl_printf_*`` IR is what this asserts.""" + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + add = _make_add_jit(llvm_pass_pipeline="default,flydsl-print-tid", llvm_pass_plugins=[print_tid_plugin]) + + n = 64 # one block of 64 lanes -> 64 valid threads + A = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + B = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + C = torch.zeros(n, dtype=torch.float32).cuda() + tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + add(tA, B, C, n, stream=torch.cuda.Stream()) + torch.cuda.synchronize() + + assert torch.allclose(C, A + B) + + +def test_print_tid_pipeline_without_plugin_fails(print_tid_plugin, monkeypatch): + """Negative: the same pipeline naming ``flydsl-print-tid`` *without* loading + the plugin must fail at the ``opt`` step — proving the plugin provides the + pass. (Fails during compile, before any GPU execution.)""" + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + add = _make_add_jit(llvm_pass_pipeline="default,flydsl-print-tid") # no plugins + + n = 64 + A = torch.zeros(n, dtype=torch.float32) + B = torch.zeros(n, dtype=torch.float32) + C = torch.zeros(n, dtype=torch.float32) + if _gpu_available(): + A, B, C = A.cuda(), B.cuda(), C.cuda() + tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + + with pytest.raises(ExternalLLVMError) as excinfo: + add(tA, B, C, n, stream=torch.cuda.Stream() if _gpu_available() else fx.Stream(None)) + assert "flydsl-print-tid" in str(excinfo.value) diff --git a/tests/unit/test_llvm_pass_pipeline.py b/tests/unit/test_llvm_pass_pipeline.py new file mode 100644 index 000000000..cd7a71c9e --- /dev/null +++ b/tests/unit/test_llvm_pass_pipeline.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Custom LLVM pass pipeline JIT plumbing (no kernel execution / external toolchain needed).""" + +import flydsl.compiler as flyc +from flydsl.compiler.backends.rocm import RocmBackend +from flydsl.compiler.external_llvm import llvm_opt_fingerprint +from flydsl.compiler.jit_function import _effective_llvm_pass_config + + +def test_recodegen_fragments_use_opt_level_zero_by_default(): + backend = RocmBackend(RocmBackend.detect_target()) + arch = backend.target.arch + attach, binary = backend.llvm_recodegen_fragments(compile_hints={}) + assert attach.startswith("rocdl-attach-target{") + assert "O=0" in attach + assert f"chip={arch}" in attach + assert binary.startswith("gpu-module-to-binary{") + assert "format=fatbin" in binary + + +def test_recodegen_fragments_opt_level_override(): + backend = RocmBackend(RocmBackend.detect_target()) + attach, _ = backend.llvm_recodegen_fragments(compile_hints={}, opt_level=3) + assert "O=3" in attach + + +def test_jit_decorator_records_llvm_pass_hints(): + @flyc.jit(llvm_pass_pipeline="default,my-pass", llvm_pass_plugins=["/tmp/libMy.so"]) + def f(): # pragma: no cover - never executed + pass + + assert f.compile_hints["llvm_pass_pipeline"] == "default,my-pass" + assert f.compile_hints["llvm_pass_plugins"] == ["/tmp/libMy.so"] + + +def test_jit_decorator_without_llvm_pass_has_no_hints(): + @flyc.jit + def f(): # pragma: no cover - never executed + pass + + assert "llvm_pass_pipeline" not in f.compile_hints + + +def test_effective_config_prefers_hints_over_env(monkeypatch): + monkeypatch.setenv("FLYDSL_COMPILE_LLVM_PASS_PIPELINE", "default") + monkeypatch.setenv("FLYDSL_COMPILE_LLVM_PASS_PLUGINS", "/env/a.so:/env/b.so") + + # hints win + pipe, plugins = _effective_llvm_pass_config({"llvm_pass_pipeline": "default", "llvm_pass_plugins": ["/h.so"]}) + assert pipe == "default" + assert plugins == ["/h.so"] + + # env fallback when hints absent + pipe, plugins = _effective_llvm_pass_config({}) + assert pipe == "default" + assert plugins == ["/env/a.so", "/env/b.so"] + + +def test_fingerprint_changes_with_pipeline_and_plugins(tmp_path): + assert llvm_opt_fingerprint("default") != llvm_opt_fingerprint("default") + + so = tmp_path / "libP.so" + so.write_bytes(b"v1") + fp1 = llvm_opt_fingerprint("default", [str(so)]) + so.write_bytes(b"v2-changed") + fp2 = llvm_opt_fingerprint("default", [str(so)]) + assert fp1 != fp2 # plugin content edit invalidates + assert str(so) in fp1 + + +def test_fingerprint_tolerates_missing_plugin(): + fp = llvm_opt_fingerprint("default", ["/does/not/exist.so"]) + assert "" in fp From 2204698db344409fb4822c01978137f90020be79 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Wed, 3 Jun 2026 05:29:22 +0000 Subject: [PATCH 2/5] feat(compiler): Add fly-llc codegen path for custom MIR passes Signed-off-by: fsx950223 --- python/flydsl/compiler/external_llvm.py | 136 ++++++++ python/flydsl/compiler/jit_function.py | 108 ++++++- python/flydsl/utils/env.py | 14 + tests/kernels/test_llvm_codegen_pass_e2e.py | 336 ++++++++++++++++++++ tests/unit/test_llvm_pass_pipeline.py | 41 ++- tools/CMakeLists.txt | 1 + tools/fly-llc/CMakeLists.txt | 21 ++ tools/fly-llc/fly-llc.cpp | 125 ++++++++ 8 files changed, 769 insertions(+), 13 deletions(-) create mode 100644 tests/kernels/test_llvm_codegen_pass_e2e.py create mode 100644 tools/fly-llc/CMakeLists.txt create mode 100644 tools/fly-llc/fly-llc.cpp diff --git a/python/flydsl/compiler/external_llvm.py b/python/flydsl/compiler/external_llvm.py index eac41d118..338486ee1 100644 --- a/python/flydsl/compiler/external_llvm.py +++ b/python/flydsl/compiler/external_llvm.py @@ -334,3 +334,139 @@ def run_llvm_opt_then_binary( finally: if tmp_dir_obj is not None: tmp_dir_obj.cleanup() + + +# --------------------------------------------------------------------------- +# Custom-codegen path: fly-llc (IR -> obj with injectable MIR passes) + ld.lld +# --------------------------------------------------------------------------- + + +def _fly_llc_path() -> Path: + raw = env.compile.fly_llc.strip() + if raw: + return Path(raw).expanduser() + cand = _llvm_dir() / "bin" / "fly-llc" + if cand.is_file(): + return cand + raise ExternalLLVMError( + "fly-llc tool not found: set FLYDSL_COMPILE_FLY_LLC or build fly-llc into /bin." + ) + + +def _lld_path() -> Path: + raw = env.compile.lld.strip() + if raw: + return Path(raw).expanduser() + cand = _llvm_dir() / "bin" / "ld.lld" + if cand.is_file(): + return cand + raise ExternalLLVMError( + "fly-llc codegen path needs ld.lld: set FLYDSL_COMPILE_LLD or place ld.lld in /bin." + ) + + +def fly_llc_codegen_fingerprint(passes: Optional[list] = None, plugins: Optional[list] = None) -> str: + """Cache fingerprint for a fly-llc codegen configuration: the pass names plus + the fly-llc binary's and each plugin's content hash.""" + parts = ["fly-llc-codegen:" + ",".join(passes or [])] + try: + parts.append(_file_hash(_fly_llc_path().resolve())) + except OSError: + parts.append("") + except ExternalLLVMError: + parts.append("") + for p in plugins or []: + path = Path(p).expanduser() + try: + parts.append(f"{path}:{_file_hash(path.resolve())}") + except OSError: + parts.append(f"{path}:") + return ";".join(parts) + + +def _gpu_binary_module_text(name: str, target_cpu: str, hsaco: bytes) -> str: + """Build a ``builtin.module`` text embedding *hsaco* as a ``gpu.binary @name`` + (every byte escaped as ``\\XX`` for the MLIR string attribute).""" + esc = "".join("\\%02X" % b for b in hsaco) + return ( + "module attributes {gpu.container_module} {\n" + f' gpu.binary @{name} [#gpu.object<#rocdl.target, kernels = <>, bin = "{esc}">]\n' + "}\n" + ) + + +def run_fly_llc_codegen( + module: ir.Module, + *, + llvm_ir: str, + codegen_passes: list, + codegen_plugins: Optional[list] = None, + target_triple: str, + target_cpu: str, + work_dir: Optional[Path] = None, + stage_prefix: str = "fly_llc", +) -> None: + """Codegen the device kernel's LLVM IR with injectable MIR passes and splice + the result back into *module*. + + Flow: ``fly-llc -o --load= --pre-emit-pass=`` + (custom MIR passes run pre-emit in the standard codegen) -> ``ld.lld -shared`` + -> wrap the HSACO bytes into a ``gpu.binary`` -> replace the in-process + ``gpu.module``. + """ + fly_llc = _fly_llc_path() + lld = _lld_path() + prefix = _llvm_dir() + + gpu_module = _single_top_level_op(module, "gpu.module") + name = _symbol_name(gpu_module) + + tmp_dir_obj = None + if work_dir is None: + tmp_dir_obj = tempfile.TemporaryDirectory(prefix="flydsl_fly_llc_") + work_dir = Path(tmp_dir_obj.name) + else: + work_dir.mkdir(parents=True, exist_ok=True) + + in_ll = work_dir / f"{stage_prefix}_pre_codegen.ll" + obj = work_dir / f"{stage_prefix}.o" + hsaco = work_dir / f"{stage_prefix}.hsaco" + bin_mlir = work_dir / f"{stage_prefix}_binary.mlir" + + try: + in_ll.write_text(llvm_ir, encoding="utf-8") + + plugin_args = [f"--load={Path(p).expanduser()}" for p in (codegen_plugins or [])] + pass_args = [f"--pre-emit-pass={n}" for n in (codegen_passes or [])] + _run_tool( + [ + str(fly_llc), + str(in_ll), + "-o", + str(obj), + f"-mtriple={target_triple}", + f"-mcpu={target_cpu}", + *plugin_args, + *pass_args, + ], + prefix=prefix, + what="fly-llc codegen", + work_dir=work_dir, + ) + + _run_tool( + [str(lld), "-shared", str(obj), "-o", str(hsaco)], + prefix=prefix, + what="ld.lld HSACO link", + work_dir=work_dir, + ) + + if not hsaco.is_file(): + raise ExternalLLVMError(f"ld.lld did not create HSACO: {hsaco}") + text = _gpu_binary_module_text(name, target_cpu, hsaco.read_bytes()) + bin_mlir.write_text(text, encoding="utf-8") + binary_module = ir.Module.parse(text, context=module.context) + _replace_gpu_module_with_binary_op(module, binary_module) + finally: + if tmp_dir_obj is not None: + tmp_dir_obj.cleanup() diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index 4f2edc29b..826a9f2f3 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -124,6 +124,8 @@ def _create_mlir_context(*, load_dialects=True): "FLYDSL_EXTRA_SOURCE_DIRS", "FLYDSL_COMPILE_LLVM_PASS_PIPELINE", "FLYDSL_COMPILE_LLVM_PASS_PLUGINS", + "FLYDSL_COMPILE_LLVM_CODEGEN_PASSES", + "FLYDSL_COMPILE_LLVM_CODEGEN_PLUGINS", ) @@ -749,6 +751,8 @@ class PipelineConfig: external: bool llvm_pass_pipeline: str = "" llvm_pass_plugins: Optional[list] = None + llvm_codegen_passes: Optional[list] = None + llvm_codegen_plugins: Optional[list] = None def _effective_llvm_pass_config(hints: dict): @@ -763,6 +767,18 @@ def _effective_llvm_pass_config(hints: dict): return (pipeline or "").strip(), list(plugins or []) +def _effective_llvm_codegen_config(hints: dict): + """Resolve the custom MIR codegen passes + plugins (fly-llc path), preferring + @flyc.jit compile_hints over the FLYDSL_COMPILE_LLVM_CODEGEN_* env vars.""" + passes = hints.get("llvm_codegen_passes") + if passes is None: + passes = env.compile.llvm_codegen_passes + plugins = hints.get("llvm_codegen_plugins") + if plugins is None: + plugins = env.compile.llvm_codegen_plugins + return list(passes or []), list(plugins or []) + + def _pipeline_fragments_for_mode(backend) -> PipelineConfig: """Return pipeline configuration including optional external split.""" from .kernel_function import CompilationContext @@ -770,10 +786,12 @@ def _pipeline_fragments_for_mode(backend) -> PipelineConfig: hints = CompilationContext.get_compile_hints() llvm_opts = hints.get("llvm_options") llvm_pass_pipeline, llvm_pass_plugins = _effective_llvm_pass_config(hints) + llvm_codegen_passes, llvm_codegen_plugins = _effective_llvm_codegen_config(hints) - # Custom LLVM pass pipeline: split off the binary fragment so we can extract - # LLVM IR, run `opt`, and re-codegen externally (see MlirCompiler.compile). - if llvm_pass_pipeline: + # Custom LLVM IR pass pipeline (opt) and/or custom MIR codegen passes (fly-llc): + # both need the binary fragment split off so we can extract LLVM IR and run an + # external codegen tail (see MlirCompiler.compile). + if llvm_pass_pipeline or llvm_codegen_passes: pre_binary_fragments, binary_fragment = backend.external_binary_pipeline_fragments(compile_hints=hints) return PipelineConfig( fragments=[*pre_binary_fragments, binary_fragment], @@ -783,6 +801,8 @@ def _pipeline_fragments_for_mode(backend) -> PipelineConfig: external=False, llvm_pass_pipeline=llvm_pass_pipeline, llvm_pass_plugins=llvm_pass_plugins, + llvm_codegen_passes=llvm_codegen_passes, + llvm_codegen_plugins=llvm_codegen_plugins, ) if _use_external_binary_codegen(): @@ -841,8 +861,8 @@ def compile( "use embedded codegen for kernels that require #fly.explicit_module." ) - if cfg.llvm_pass_pipeline and link_libs: - raise RuntimeError("custom llvm_pass_pipeline does not support extern link_libs yet.") + if (cfg.llvm_pass_pipeline or cfg.llvm_codegen_passes) and link_libs: + raise RuntimeError("custom llvm_pass_pipeline / llvm_codegen_passes do not support extern link_libs yet.") if link_libs: link_opt = _format_link_lib_options(link_libs) @@ -861,6 +881,10 @@ def compile( dump_dir = Path(env.debug.dump_dir).resolve() with _llvm_ctx: + if cfg.llvm_codegen_passes: + return cls._compile_with_fly_llc( + module, backend, cfg, func_name=func_name, dump_enabled=dump_enabled, dump_dir=dump_dir + ) if cfg.llvm_pass_pipeline: return cls._compile_with_llvm_opt( module, backend, cfg, func_name=func_name, dump_enabled=dump_enabled, dump_dir=dump_dir @@ -1023,6 +1047,52 @@ def _compile_with_llvm_opt( module.operation.verify() return module + @classmethod + def _compile_with_fly_llc( + cls, module, backend, cfg, *, func_name: str, dump_enabled: bool, dump_dir: Path + ) -> ir.Module: + """Custom-codegen path: run the pre-binary fragments in-process, extract the + device LLVM IR, then codegen via ``fly-llc`` (injecting the requested MIR + passes pre-emit) + ``ld.lld``, and splice the resulting binary back.""" + from .external_llvm import run_fly_llc_codegen + from .kernel_function import CompilationContext + + hints = CompilationContext.get_compile_hints() + work_dir = None + if dump_enabled: + asm = module.operation.get_asm(enable_debug_info=True) + kernel_names = _infer_kernel_names_from_asm(asm) + subdir = kernel_names[0] if len(kernel_names) == 1 else (func_name or "module") + work_dir = dump_dir / _sanitize_path_component(subdir) + print(f"[flydsl.compile] FLYDSL_DUMP_IR=1 (llvm_codegen_passes) dir={work_dir}") + + # Run everything up to (but not including) gpu-module-to-binary in-process. + _run_pipeline( + module, + cfg.pre_binary, + verifier=env.debug.enable_verifier, + print_after_all=env.debug.print_after_all, + ) + + llvm_ir = _extract_llvm_ir(module) + if llvm_ir is None: + raise FlyDSLCompileError( + "llvm_codegen_passes is set but the device LLVM IR could not be extracted from the gpu.module." + ) + + rocdl_opts = backend._rocdl_opts(compile_hints=hints) + run_fly_llc_codegen( + module, + llvm_ir=llvm_ir, + codegen_passes=cfg.llvm_codegen_passes, + codegen_plugins=cfg.llvm_codegen_plugins, + target_triple=rocdl_opts["triple"], + target_cpu=rocdl_opts["chip"], + work_dir=work_dir, + ) + module.operation.verify() + return module + class JitCacheManager: """Directory-based cache manager with multi-process safety. @@ -1438,6 +1508,11 @@ def _resolve_and_make_cache_key(self, bound_args): from .external_llvm import llvm_opt_fingerprint key_parts.append(("_llvm_pass_", llvm_opt_fingerprint(eff_pipeline, eff_plugins))) + eff_cg_passes, eff_cg_plugins = _effective_llvm_codegen_config(self.compile_hints) + if eff_cg_passes: + from .external_llvm import fly_llc_codegen_fingerprint + + key_parts.append(("_llvm_codegen_", fly_llc_codegen_fingerprint(eff_cg_passes, eff_cg_plugins))) for name, arg in bound_args.items(): param = sig.parameters.get(name) @@ -1760,21 +1835,32 @@ def jit( *, llvm_pass_pipeline: Optional[str] = None, llvm_pass_plugins: Optional[list] = None, + llvm_codegen_passes: Optional[list] = None, + llvm_codegen_plugins: Optional[list] = None, ) -> JitFunction: """JIT decorator for host launcher functions. - ``llvm_pass_pipeline``: optional LLVM new-PM pass pipeline (e.g. - ``"default,my-pass"``) run on the device kernel IR before codegen. - ``llvm_pass_plugins``: optional list of LLVM pass plugin ``.so`` paths - loaded (``opt --load-pass-plugin``) before running that pipeline. Both - require ``FLYDSL_COMPILE_LLVM_DIR`` and override the - ``FLYDSL_COMPILE_LLVM_PASS_*`` env vars. + ``llvm_pass_pipeline`` / ``llvm_pass_plugins``: optional LLVM new-PM **IR** + pass pipeline (e.g. ``"default,my-pass"``) and pass plugin ``.so`` paths + run on the device kernel IR before codegen (``opt --load-pass-plugin``). + + ``llvm_codegen_passes`` / ``llvm_codegen_plugins``: optional **MIR/codegen** + pass names inserted pre-emit, and legacy pass plugin ``.so`` paths, applied + via the ``fly-llc`` codegen path (``fly-llc --load`` + ``--pre-emit-pass``, + then ``ld.lld``). + + All require ``FLYDSL_COMPILE_LLVM_DIR`` and override the corresponding + ``FLYDSL_COMPILE_LLVM_*`` env vars. """ hints = {} if llvm_pass_pipeline is not None: hints["llvm_pass_pipeline"] = llvm_pass_pipeline if llvm_pass_plugins is not None: hints["llvm_pass_plugins"] = list(llvm_pass_plugins) + if llvm_codegen_passes is not None: + hints["llvm_codegen_passes"] = list(llvm_codegen_passes) + if llvm_codegen_plugins is not None: + hints["llvm_codegen_plugins"] = list(llvm_codegen_plugins) def _make(f: Callable) -> JitFunction: return JitFunction(f, compile_hints=hints or None) diff --git a/python/flydsl/utils/env.py b/python/flydsl/utils/env.py index a8592ba1f..0f2e571f1 100644 --- a/python/flydsl/utils/env.py +++ b/python/flydsl/utils/env.py @@ -242,6 +242,20 @@ class CompileEnvManager(EnvManager): description="Colon-separated LLVM pass plugin .so paths loaded (opt --load-pass-plugin) " "before running llvm_pass_pipeline. Overridden by @flyc.jit(llvm_pass_plugins=...).", ) + llvm_codegen_passes = OptList( + [], + description="Comma-separated MIR pass names inserted pre-emit by the fly-llc codegen path " + "(requires fly-llc + ld.lld + FLYDSL_COMPILE_LLVM_DIR). Overridden by " + "@flyc.jit(llvm_codegen_passes=...).", + ) + llvm_codegen_plugins = OptList( + [], + separator=":", + description="Colon-separated legacy MIR pass plugin .so paths loaded by fly-llc " + "(fly-llc --load). Overridden by @flyc.jit(llvm_codegen_plugins=...).", + ) + fly_llc = OptStr("", description="Path to the fly-llc tool (default: /bin/fly-llc).") + lld = OptStr("", description="Path to ld.lld for the fly-llc codegen path (default: /bin/ld.lld).") class DebugEnvManager(EnvManager): diff --git a/tests/kernels/test_llvm_codegen_pass_e2e.py b/tests/kernels/test_llvm_codegen_pass_e2e.py new file mode 100644 index 000000000..02f070d8d --- /dev/null +++ b/tests/kernels/test_llvm_codegen_pass_e2e.py @@ -0,0 +1,336 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""End-to-end test for the custom-MIR-codegen JIT path (fly-llc). + +Builds a legacy MachineFunction (MIR) pass plugin (``.so``) registering a pass +``fly-mir-pass`` that runs during codegen (pre-emit) and prints the machine +function name. The kernel is driven through +``@flyc.jit(llvm_codegen_passes=["fly-mir-pass"], llvm_codegen_plugins=[...])`` so +the full chain is exercised: + + device .ll -> fly-llc (--load + --pre-emit-pass) -> obj -> ld.lld -> HSACO + -> gpu.binary -> splice -> run + +Requires a ROCm GPU, ``FLYDSL_COMPILE_LLVM_DIR`` (LLVM headers + llvm-config), a +host C++ compiler, the ``fly-llc`` tool, and ``ld.lld``; skipped otherwise. +""" + +import os +import shutil +import subprocess +from pathlib import Path + +import pytest + +pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] + +torch = pytest.importorskip("torch") + +import flydsl.compiler as flyc # noqa: E402 +import flydsl.expr as fx # noqa: E402 +from flydsl.compiler.external_llvm import ExternalLLVMError # noqa: E402 + +# Legacy MachineFunctionPass plugin: registers "fly-mir-pass" via RegisterPass, +# runs pre-emit during codegen, prints the MF name (observable under `pytest -s`). +PLUGIN_SRC = r""" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; +namespace { +struct FlyMirPass : public MachineFunctionPass { + static char ID; + FlyMirPass() : MachineFunctionPass(ID) {} + bool runOnMachineFunction(MachineFunction &MF) override { + errs() << "fly-mir-pass: ran on " << MF.getName() << "\n"; + return false; + } + StringRef getPassName() const override { return "Fly demo MIR pass"; } +}; +char FlyMirPass::ID = 0; +} // namespace +static RegisterPass X("fly-mir-pass", "Fly demo MIR pass", false, false); +""" + +# A MIR pass that *modifies* the machine code: inserts 8 ``s_nop`` at the entry +# of every kernel. The opcode is found by name so no AMDGPU target headers are +# needed. 8 nops/function is a distinctive, measurable change in the ASM. +PLUGIN_SRC_NOP = r""" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/Pass.h" +using namespace llvm; +namespace { +struct FlyInsertNopPass : public MachineFunctionPass { + static char ID; + FlyInsertNopPass() : MachineFunctionPass(ID) {} + bool runOnMachineFunction(MachineFunction &MF) override { + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + unsigned NopOpc = ~0u; + for (unsigned i = 0, e = TII->getNumOpcodes(); i < e; ++i) + if (TII->getName(i) == "S_NOP") { NopOpc = i; break; } + if (NopOpc == ~0u || MF.empty()) return false; + MachineBasicBlock &MBB = MF.front(); + auto It = MBB.begin(); + for (int k = 0; k < 8; ++k) + BuildMI(MBB, It, DebugLoc(), TII->get(NopOpc)).addImm(0); + return true; + } + StringRef getPassName() const override { return "Fly insert NOP MIR pass"; } +}; +char FlyInsertNopPass::ID = 0; +} // namespace +static RegisterPass X("fly-insert-nop", "Fly insert NOP", false, false); +""" + + +def _gpu_available() -> bool: + try: + from flydsl.runtime.device import get_rocm_device_count + + return get_rocm_device_count() > 0 + except Exception: + return False + + +def _resolve_tool(env_var: str, name: str): + raw = os.environ.get(env_var, "").strip() + if raw and Path(raw).expanduser().is_file(): + return Path(raw).expanduser() + llvm_dir = os.environ.get("FLYDSL_COMPILE_LLVM_DIR", "").strip() + if llvm_dir: + cand = Path(llvm_dir).expanduser() / "bin" / name + if cand.is_file(): + return cand + return None + + +def _build_codegen_plugin(tmp_path_factory, *, src: str, name: str) -> str: + """Skip unless the codegen toolchain is present, then compile a plugin .so.""" + raw = os.environ.get("FLYDSL_COMPILE_LLVM_DIR", "").strip() + if not raw: + pytest.skip("FLYDSL_COMPILE_LLVM_DIR not set; required to build/load a codegen pass plugin") + prefix = Path(raw).expanduser().resolve() + llvm_config = prefix / "bin" / "llvm-config" + cxx = shutil.which("clang++") or shutil.which("g++") + if not llvm_config.is_file() or cxx is None: + pytest.skip("llvm-config or a C++ compiler not available for plugin build") + if _resolve_tool("FLYDSL_COMPILE_FLY_LLC", "fly-llc") is None: + pytest.skip("fly-llc not found; set FLYDSL_COMPILE_FLY_LLC or build it into /bin") + if _resolve_tool("FLYDSL_COMPILE_LLD", "ld.lld") is None: + pytest.skip("ld.lld not found; set FLYDSL_COMPILE_LLD or place ld.lld in /bin") + + cxxflags = subprocess.check_output([str(llvm_config), "--cxxflags"], text=True).split() + work = tmp_path_factory.mktemp("codegen_plugin") + cpp = work / (name + ".cpp") + cpp.write_text(src, encoding="utf-8") + so = work / ("lib" + name + ".so") + subprocess.run([cxx, "-shared", "-fPIC", *cxxflags, str(cpp), "-o", str(so)], check=True) + assert so.is_file() + return str(so) + + +@pytest.fixture(scope="module") +def mir_pass_plugin(tmp_path_factory) -> str: + """Compile the print-only MIR pass plugin (and ensure fly-llc + ld.lld exist).""" + return _build_codegen_plugin(tmp_path_factory, src=PLUGIN_SRC, name="FlyMir") + + +@pytest.fixture(scope="module") +def nop_pass_plugin(tmp_path_factory) -> str: + """Compile the s_nop-inserting MIR pass plugin (modifies the machine code).""" + return _build_codegen_plugin(tmp_path_factory, src=PLUGIN_SRC_NOP, name="FlyNop") + + +@flyc.kernel +def _add_kernel(A: fx.Tensor, B: fx.Tensor, C: fx.Tensor, block_dim: fx.Constexpr[int]): + bid = fx.block_idx.x + tid = fx.thread_idx.x + A = fx.rocdl.make_buffer_tensor(A) + tA = fx.logical_divide(A, fx.make_layout(block_dim, 1)) + tB = fx.logical_divide(B, fx.make_layout(block_dim, 1)) + tC = fx.logical_divide(C, fx.make_layout(block_dim, 1)) + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + tA = fx.logical_divide(tA, fx.make_layout(1, 1)) + tB = fx.logical_divide(tB, fx.make_layout(1, 1)) + tC = fx.logical_divide(tC, fx.make_layout(1, 1)) + ca = fx.make_copy_atom(fx.UniversalCopy32b(), fx.Float32) + cab = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), fx.Float32) + rA = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32) + rB = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32) + rC = fx.make_rmem_tensor(fx.make_layout(1, 1), fx.Float32) + fx.copy_atom_call(cab, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(ca, fx.slice(tB, (None, tid)), rB) + vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB)) + fx.memref_store_vec(vC, rC) + fx.copy_atom_call(ca, rC, fx.slice(tC, (None, tid))) + + +def _make_add_jit(**jit_kwargs): + @flyc.jit(**jit_kwargs) + def add(A: fx.Tensor, B: fx.Tensor, C, n: fx.Int32, stream: fx.Stream = fx.Stream(None)): + block_dim = 64 + grid_x = (n + block_dim - 1) // block_dim + _add_kernel(A, B, C, block_dim).launch(grid=(grid_x, 1, 1), block=[block_dim, 1, 1], stream=stream) + + return add + + +@pytest.mark.skipif(not _gpu_available(), reason="requires a ROCm GPU") +def test_codegen_mir_pass_compiles_and_runs(mir_pass_plugin, monkeypatch): + """Positive: a custom MIR pass injected pre-emit via fly-llc; the kernel + codegens through fly-llc + ld.lld, links, and runs correctly. + + The pass's ``fly-mir-pass: ran on ...`` line is printed to stderr during + codegen; run with ``pytest -s`` to view it.""" + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + add = _make_add_jit(llvm_codegen_passes=["fly-mir-pass"], llvm_codegen_plugins=[mir_pass_plugin]) + + n = 64 + A = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + B = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + C = torch.zeros(n, dtype=torch.float32).cuda() + tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + add(tA, B, C, n, stream=torch.cuda.Stream()) + torch.cuda.synchronize() + + assert torch.allclose(C, A + B) + + +def test_codegen_unknown_mir_pass_fails(mir_pass_plugin, monkeypatch): + """Negative: naming the MIR pass without loading the plugin must fail in + fly-llc (unknown pass) — proving the plugin provides the pass.""" + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + add = _make_add_jit(llvm_codegen_passes=["fly-mir-pass"]) # no plugins + + n = 64 + A = torch.zeros(n, dtype=torch.float32) + B = torch.zeros(n, dtype=torch.float32) + C = torch.zeros(n, dtype=torch.float32) + if _gpu_available(): + A, B, C = A.cuda(), B.cuda(), C.cuda() + tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + + with pytest.raises(ExternalLLVMError) as excinfo: + add(tA, B, C, n, stream=torch.cuda.Stream() if _gpu_available() else fx.Stream(None)) + assert "fly-mir-pass" in str(excinfo.value) + + +def _max_entry_nop_run(disasm: str) -> int: + """Largest run of leading ``s_nop`` immediately after any ``:`` label. + + Anchoring at the function entry avoids counting trailing alignment padding + (which the disassembler also renders as ``s_nop``) — the injected sled lands + at the entry, so this isolates the pass's effect.""" + import re + + lines = disasm.splitlines() + best = 0 + for i, ln in enumerate(lines): + if re.match(r"^[0-9a-fA-F]+ <.+>:", ln): + run = 0 + for j in range(i + 1, len(lines)): + if "s_nop" in lines[j]: + run += 1 + elif lines[j].strip() == "": + continue + else: + break + best = max(best, run) + return best + + +def _disasm(objdump: Path, mcpu: str, obj: Path) -> str: + return subprocess.check_output([str(objdump), "-d", f"--mcpu={mcpu}", str(obj)], text=True) + + +def _unescape_mlir_bytes(s: str) -> bytes: + """Decode an MLIR string-attribute body (``\\XX`` hex + ``\\\\`` / ``\\"``).""" + out = bytearray() + i, n = 0, len(s) + while i < n: + if s[i] == "\\": + nxt = s[i + 1] + if nxt in ("\\", '"'): + out.append(ord(nxt)) + i += 2 + elif nxt == "n": + out.append(0x0A) + i += 2 + elif nxt == "t": + out.append(0x09) + i += 2 + else: + out.append(int(s[i + 1 : i + 3], 16)) + i += 3 + else: + out.append(ord(s[i])) + i += 1 + return bytes(out) + + +def _extract_gpu_binary(dump_dir: Path) -> bytes: + """Pull the embedded device HSACO bytes out of a dumped ``gpu.binary`` op.""" + import re + + for mlir in sorted(dump_dir.rglob("*.mlir")): + txt = mlir.read_text(encoding="utf-8", errors="replace") + if "gpu.binary" not in txt: + continue + m = re.search(r'bin = "((?:[^"\\]|\\.)*)"', txt, re.S) + if m: + return _unescape_mlir_bytes(m.group(1)) + raise AssertionError(f"no gpu.binary found under {dump_dir}") + + +@pytest.mark.skipif(not _gpu_available(), reason="requires a ROCm GPU") +def test_codegen_pass_modifies_asm(nop_pass_plugin, monkeypatch, tmp_path): + """A codegen pass can change the emitted ASM. Compile the same kernel through + the JIT twice — with the s_nop-inserting pass and without it — and disassemble + the device binary each produced. Only the with-pass binary begins with the + NOP_PER_FUNC sled at the function entry, proving the pass modified the ASM.""" + objdump = _resolve_tool("FLYDSL_COMPILE_LLVM_OBJDUMP", "llvm-objdump") + if objdump is None: + pytest.skip("llvm-objdump not found; set FLYDSL_COMPILE_LLVM_OBJDUMP or place it in /bin") + + from flydsl.compiler.backends.rocm import RocmBackend + + mcpu = RocmBackend.detect_target().arch + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + monkeypatch.setenv("FLYDSL_DUMP_IR", "1") + + def _run(add, dump: Path): + monkeypatch.setenv("FLYDSL_DUMP_DIR", str(dump)) + n = 64 + A = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + B = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + C = torch.zeros(n, dtype=torch.float32).cuda() + tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + add(tA, B, C, n, stream=torch.cuda.Stream()) + torch.cuda.synchronize() + assert torch.allclose(C, A + B) # s_nop is harmless: result still correct + + mod_dump = tmp_path / "mod" + base_dump = tmp_path / "base" + _run(_make_add_jit(llvm_codegen_passes=["fly-insert-nop"], llvm_codegen_plugins=[nop_pass_plugin]), mod_dump) + _run(_make_add_jit(), base_dump) # baseline: same kernel, no codegen pass + + mod_hsaco = next(mod_dump.rglob("fly_llc.hsaco"), None) + assert mod_hsaco is not None, "fly-llc HSACO dump not found" + base_hsaco = tmp_path / "base.hsaco" + base_hsaco.write_bytes(_extract_gpu_binary(base_dump)) + + # The pass injects an NOP_PER_FUNC sled at each function entry; the baseline + # kernel does not begin with s_nop. (Total s_nop count / object size are NOT + # reliable: the entry sled merely displaces trailing alignment-padding nops.) + base_run = _max_entry_nop_run(_disasm(objdump, mcpu, base_hsaco)) + mod_run = _max_entry_nop_run(_disasm(objdump, mcpu, mod_hsaco)) + + assert mod_run > base_run, f"codegen pass did not modify the ASM: entry s_nop run base={base_run} mod={mod_run}" diff --git a/tests/unit/test_llvm_pass_pipeline.py b/tests/unit/test_llvm_pass_pipeline.py index cd7a71c9e..0fd7e99a8 100644 --- a/tests/unit/test_llvm_pass_pipeline.py +++ b/tests/unit/test_llvm_pass_pipeline.py @@ -5,8 +5,8 @@ import flydsl.compiler as flyc from flydsl.compiler.backends.rocm import RocmBackend -from flydsl.compiler.external_llvm import llvm_opt_fingerprint -from flydsl.compiler.jit_function import _effective_llvm_pass_config +from flydsl.compiler.external_llvm import fly_llc_codegen_fingerprint, llvm_opt_fingerprint +from flydsl.compiler.jit_function import _effective_llvm_codegen_config, _effective_llvm_pass_config def test_recodegen_fragments_use_opt_level_zero_by_default(): @@ -73,3 +73,40 @@ def test_fingerprint_changes_with_pipeline_and_plugins(tmp_path): def test_fingerprint_tolerates_missing_plugin(): fp = llvm_opt_fingerprint("default", ["/does/not/exist.so"]) assert "" in fp + + +# --- custom MIR codegen (fly-llc) plumbing --------------------------------- + + +def test_jit_decorator_records_llvm_codegen_hints(): + @flyc.jit(llvm_codegen_passes=["fly-mir-pass"], llvm_codegen_plugins=["/tmp/libFlyMir.so"]) + def f(): # pragma: no cover - never executed + pass + + assert f.compile_hints["llvm_codegen_passes"] == ["fly-mir-pass"] + assert f.compile_hints["llvm_codegen_plugins"] == ["/tmp/libFlyMir.so"] + + +def test_effective_codegen_config_prefers_hints_over_env(monkeypatch): + monkeypatch.setenv("FLYDSL_COMPILE_LLVM_CODEGEN_PASSES", "env-pass") + monkeypatch.setenv("FLYDSL_COMPILE_LLVM_CODEGEN_PLUGINS", "/env/a.so:/env/b.so") + + passes, plugins = _effective_llvm_codegen_config({"llvm_codegen_passes": ["h"], "llvm_codegen_plugins": ["/h.so"]}) + assert passes == ["h"] + assert plugins == ["/h.so"] + + passes, plugins = _effective_llvm_codegen_config({}) + assert passes == ["env-pass"] + assert plugins == ["/env/a.so", "/env/b.so"] + + +def test_codegen_fingerprint_changes_with_passes_and_plugins(tmp_path): + assert fly_llc_codegen_fingerprint(["a"]) != fly_llc_codegen_fingerprint(["b"]) + + so = tmp_path / "libP.so" + so.write_bytes(b"v1") + fp1 = fly_llc_codegen_fingerprint(["a"], [str(so)]) + so.write_bytes(b"v2-changed") + fp2 = fly_llc_codegen_fingerprint(["a"], [str(so)]) + assert fp1 != fp2 # plugin content edit invalidates + assert str(so) in fp1 diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index fe9953eb0..a1cb5e4b6 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(fly-opt) +add_subdirectory(fly-llc) diff --git a/tools/fly-llc/CMakeLists.txt b/tools/fly-llc/CMakeLists.txt new file mode 100644 index 000000000..7e633b51a --- /dev/null +++ b/tools/fly-llc/CMakeLists.txt @@ -0,0 +1,21 @@ +# fly-llc: IR -> object emitter with injectable MIR passes (codegen plugin path). +add_llvm_tool(fly-llc + fly-llc.cpp + EXPORT_SYMBOLS # export symbols so dlopen'd MIR pass plugins resolve LLVM symbols +) + +llvm_map_components_to_libnames(fly_llc_libs + AllTargetsCodeGens + AllTargetsAsmParsers + AllTargetsDescs + AllTargetsInfos + CodeGen + Core + IRReader + MC + Passes + Support + Target + TargetParser +) +target_link_libraries(fly-llc PRIVATE ${fly_llc_libs} ${CMAKE_DL_LIBS}) diff --git a/tools/fly-llc/fly-llc.cpp b/tools/fly-llc/fly-llc.cpp new file mode 100644 index 000000000..cc206f3f4 --- /dev/null +++ b/tools/fly-llc/fly-llc.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// +// fly-llc: minimal LLVM IR -> object emitter that allows injecting custom +// MachineFunction (MIR) passes into the standard codegen pipeline, at the +// pre-emit slot. This is the piece MLIR's `gpu-module-to-binary` does not +// expose: it mirrors `CodeGenTargetMachineImpl::addPassesToEmitFile`, but adds +// named legacy MIR passes (loaded from `--load` plugins) after +// `addMachinePasses()` and before the asm printer. +// +// Usage: +// fly-llc -o -mtriple=... -mcpu=... \ +// [--load=lib.so ...] [--pre-emit-pass=name ...] + +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/CodeGen/CodeGenTargetMachineImpl.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/PassRegistry.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" + +#include + +using namespace llvm; + +static cl::opt InputFile(cl::Positional, cl::Required, cl::desc("")); +static cl::opt OutputFile("o", cl::Required, cl::desc("output object file")); +static cl::opt MTriple("mtriple", cl::init("amdgcn-amd-amdhsa"), + cl::desc("target triple")); +static cl::opt MCPU("mcpu", cl::init(""), cl::desc("target cpu (e.g. gfx942)")); +static cl::list LoadLib("load", + cl::desc("dlopen a legacy MIR pass plugin .so (repeatable)")); +static cl::list + PreEmitPass("pre-emit-pass", cl::desc("named MIR pass to insert pre-emit (repeatable)")); + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmPrinters(); + InitializeAllAsmParsers(); + cl::ParseCommandLineOptions(argc, argv, "fly-llc: IR -> object with injectable MIR passes\n"); + + // Legacy pass plugins self-register into the global PassRegistry on load. + for (auto &lib : LoadLib) + if (!dlopen(lib.c_str(), RTLD_NOW | RTLD_GLOBAL)) { + errs() << "fly-llc: dlopen failed: " << dlerror() << "\n"; + return 1; + } + + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M = parseIRFile(InputFile, Err, Ctx); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + + Triple TT(MTriple); + std::string E; + const Target *T = TargetRegistry::lookupTarget(TT, E); + if (!T) { + errs() << "fly-llc: " << E << "\n"; + return 1; + } + TargetOptions O; + std::unique_ptr TM( + T->createTargetMachine(TT, MCPU, "", O, Reloc::PIC_, std::nullopt, CodeGenOptLevel::Default)); + M->setDataLayout(TM->createDataLayout()); + M->setTargetTriple(TT); + + // Replicate addPassesToEmitFile so a custom MIR pass can be injected after + // the standard machine passes and before the asm printer. + auto &CG = static_cast(*TM); + legacy::PassManager PM; + auto *MMIWP = new MachineModuleInfoWrapperPass(&CG); + TargetPassConfig *PC = CG.createPassConfig(PM); + PC->setDisableVerify(true); + PM.add(PC); + PM.add(MMIWP); + TargetLibraryInfoImpl TLII(TT); + PM.add(new TargetLibraryInfoWrapperPass(TLII)); + if (PC->addISelPasses()) { + errs() << "fly-llc: addISelPasses failed\n"; + return 1; + } + PC->addMachinePasses(); + for (auto &name : PreEmitPass) { + const PassInfo *PI = PassRegistry::getPassRegistry()->getPassInfo(StringRef(name)); + if (!PI || !PI->getNormalCtor()) { + errs() << "fly-llc: unknown MIR pass: " << name << "\n"; + return 1; + } + PM.add(PI->getNormalCtor()()); + } + PC->setInitialized(); + + std::error_code EC; + raw_fd_ostream Out(OutputFile, EC, sys::fs::OF_None); + if (EC) { + errs() << "fly-llc: " << EC.message() << "\n"; + return 1; + } + if (CG.addAsmPrinter(PM, Out, nullptr, CodeGenFileType::ObjectFile, + MMIWP->getMMI().getContext())) { + errs() << "fly-llc: addAsmPrinter failed (object emission unsupported)\n"; + return 1; + } + PM.add(createFreeMachineFunctionPass()); + PM.run(*M); + Out.flush(); + return 0; +} From e8275ddc666b03c4d6d8b1af6079801ecd9b38c1 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Wed, 3 Jun 2026 05:45:23 +0000 Subject: [PATCH 3/5] test(compiler): Add e2e tests for ASM-modifying codegen passes Signed-off-by: fsx950223 --- tests/kernels/test_llvm_codegen_pass_e2e.py | 147 ++++++++++++++++++-- 1 file changed, 134 insertions(+), 13 deletions(-) diff --git a/tests/kernels/test_llvm_codegen_pass_e2e.py b/tests/kernels/test_llvm_codegen_pass_e2e.py index 02f070d8d..e0828af10 100644 --- a/tests/kernels/test_llvm_codegen_pass_e2e.py +++ b/tests/kernels/test_llvm_codegen_pass_e2e.py @@ -88,6 +88,64 @@ static RegisterPass X("fly-insert-nop", "Fly insert NOP", false, false); """ +# A MIR pass that *schedules* (reorders) instructions: within each block it swaps +# adjacent instructions whenever that is provably safe — neither has memory/side +# effects and no def of one overlaps any register operand (def or use, explicit +# or implicit) of the other. Semantics are preserved (results stay correct) but +# the emitted instruction order changes. +PLUGIN_SRC_REORDER = r""" +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/Pass.h" +using namespace llvm; +namespace { +static bool unsafe(const MachineInstr &MI) { + return MI.mayLoadOrStore() || MI.hasUnmodeledSideEffects() || MI.isCall() || + MI.isTerminator() || MI.isBranch() || MI.isInlineAsm() || MI.isMetaInstruction(); +} +static bool canSwap(const MachineInstr &A, const MachineInstr &B, const TargetRegisterInfo *TRI) { + if (unsafe(A) || unsafe(B)) return false; + auto defConflicts = [&](const MachineInstr &X, const MachineInstr &Y) { + for (const MachineOperand &dx : X.operands()) { + if (!dx.isReg() || !dx.getReg() || !dx.isDef()) continue; + for (const MachineOperand &oy : Y.operands()) + if (oy.isReg() && oy.getReg() && TRI->regsOverlap(dx.getReg(), oy.getReg())) return true; + } + return false; + }; + return !defConflicts(A, B) && !defConflicts(B, A); +} +struct FlyReorderPass : public MachineFunctionPass { + static char ID; + FlyReorderPass() : MachineFunctionPass(ID) {} + bool runOnMachineFunction(MachineFunction &MF) override { + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + bool changed = false; + for (MachineBasicBlock &MBB : MF) { + for (auto it = MBB.begin(); it != MBB.end();) { + auto nxt = std::next(it); + if (nxt != MBB.end() && canSwap(*it, *nxt, TRI)) { + MBB.splice(it, &MBB, nxt); // move B before A + changed = true; + it = std::next(it); // it still == A; skip past the swapped pair + } else { + ++it; + } + } + } + return changed; + } + StringRef getPassName() const override { return "Fly reorder MIR pass"; } +}; +char FlyReorderPass::ID = 0; +} // namespace +static RegisterPass X("fly-reorder", "Fly reorder", false, false); +""" + def _gpu_available() -> bool: try: @@ -147,6 +205,12 @@ def nop_pass_plugin(tmp_path_factory) -> str: return _build_codegen_plugin(tmp_path_factory, src=PLUGIN_SRC_NOP, name="FlyNop") +@pytest.fixture(scope="module") +def reorder_pass_plugin(tmp_path_factory) -> str: + """Compile the instruction-reordering (scheduling) MIR pass plugin.""" + return _build_codegen_plugin(tmp_path_factory, src=PLUGIN_SRC_REORDER, name="FlyReorder") + + @flyc.kernel def _add_kernel(A: fx.Tensor, B: fx.Tensor, C: fx.Tensor, block_dim: fx.Constexpr[int]): bid = fx.block_idx.x @@ -290,6 +354,37 @@ def _extract_gpu_binary(dump_dir: Path) -> bytes: raise AssertionError(f"no gpu.binary found under {dump_dir}") +def _jit_run_add(add, dump: Path, monkeypatch) -> None: + """Compile+run the add kernel into *dump*, asserting the result is correct.""" + monkeypatch.setenv("FLYDSL_DUMP_DIR", str(dump)) + n = 64 + A = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + B = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() + C = torch.zeros(n, dtype=torch.float32).cuda() + tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + add(tA, B, C, n, stream=torch.cuda.Stream()) + torch.cuda.synchronize() + assert torch.allclose(C, A + B) # the codegen pass must preserve correctness + + +def _kernel_instr_seq(disasm: str) -> list: + """Ordered list of disassembled instructions (mnemonic + operands, with the + trailing ``// addr: encoding`` comment stripped) across all functions.""" + import re + + seq = [] + in_func = False + for ln in disasm.splitlines(): + if re.match(r"^[0-9a-fA-F]+ <.+>:", ln): + in_func = True + continue + if in_func and "\t" in ln: + ins = ln.split("//")[0].strip() + if ins and not ins.startswith("."): + seq.append(ins) + return seq + + @pytest.mark.skipif(not _gpu_available(), reason="requires a ROCm GPU") def test_codegen_pass_modifies_asm(nop_pass_plugin, monkeypatch, tmp_path): """A codegen pass can change the emitted ASM. Compile the same kernel through @@ -306,21 +401,11 @@ def test_codegen_pass_modifies_asm(nop_pass_plugin, monkeypatch, tmp_path): monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") monkeypatch.setenv("FLYDSL_DUMP_IR", "1") - def _run(add, dump: Path): - monkeypatch.setenv("FLYDSL_DUMP_DIR", str(dump)) - n = 64 - A = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() - B = torch.randint(0, 10, (n,), dtype=torch.float32).cuda() - C = torch.zeros(n, dtype=torch.float32).cuda() - tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) - add(tA, B, C, n, stream=torch.cuda.Stream()) - torch.cuda.synchronize() - assert torch.allclose(C, A + B) # s_nop is harmless: result still correct - mod_dump = tmp_path / "mod" base_dump = tmp_path / "base" - _run(_make_add_jit(llvm_codegen_passes=["fly-insert-nop"], llvm_codegen_plugins=[nop_pass_plugin]), mod_dump) - _run(_make_add_jit(), base_dump) # baseline: same kernel, no codegen pass + add_mod = _make_add_jit(llvm_codegen_passes=["fly-insert-nop"], llvm_codegen_plugins=[nop_pass_plugin]) + _jit_run_add(add_mod, mod_dump, monkeypatch) + _jit_run_add(_make_add_jit(), base_dump, monkeypatch) # baseline: no codegen pass mod_hsaco = next(mod_dump.rglob("fly_llc.hsaco"), None) assert mod_hsaco is not None, "fly-llc HSACO dump not found" @@ -334,3 +419,39 @@ def _run(add, dump: Path): mod_run = _max_entry_nop_run(_disasm(objdump, mcpu, mod_hsaco)) assert mod_run > base_run, f"codegen pass did not modify the ASM: entry s_nop run base={base_run} mod={mod_run}" + + +@pytest.mark.skipif(not _gpu_available(), reason="requires a ROCm GPU") +def test_codegen_pass_reorders_instructions(reorder_pass_plugin, mir_pass_plugin, monkeypatch, tmp_path): + """A custom codegen *scheduling* pass can reorder instructions. Both sides go + through the same fly-llc codegen driver (the baseline uses the no-op print + pass, so the *only* difference is the reordering — not regalloc/ISel that would + differ across codegen drivers). The reorder run must (a) keep results correct, + (b) emit the *same multiset* of instructions (pure reorder), yet (c) in a + *different order*.""" + objdump = _resolve_tool("FLYDSL_COMPILE_LLVM_OBJDUMP", "llvm-objdump") + if objdump is None: + pytest.skip("llvm-objdump not found; set FLYDSL_COMPILE_LLVM_OBJDUMP or place it in /bin") + + from flydsl.compiler.backends.rocm import RocmBackend + + mcpu = RocmBackend.detect_target().arch + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + monkeypatch.setenv("FLYDSL_DUMP_IR", "1") + + mod_dump = tmp_path / "mod" + base_dump = tmp_path / "base" + add_mod = _make_add_jit(llvm_codegen_passes=["fly-reorder"], llvm_codegen_plugins=[reorder_pass_plugin]) + add_base = _make_add_jit(llvm_codegen_passes=["fly-mir-pass"], llvm_codegen_plugins=[mir_pass_plugin]) + _jit_run_add(add_mod, mod_dump, monkeypatch) # correctness preserved under reordering + _jit_run_add(add_base, base_dump, monkeypatch) # same fly-llc driver, no-op pass + + mod_hsaco = next(mod_dump.rglob("fly_llc.hsaco"), None) + base_hsaco = next(base_dump.rglob("fly_llc.hsaco"), None) + assert mod_hsaco is not None and base_hsaco is not None, "fly-llc HSACO dump not found" + + base_seq = _kernel_instr_seq(_disasm(objdump, mcpu, base_hsaco)) + mod_seq = _kernel_instr_seq(_disasm(objdump, mcpu, mod_hsaco)) + assert base_seq and mod_seq, "no instructions disassembled" + assert sorted(base_seq) == sorted(mod_seq), "reorder must not add or remove instructions" + assert base_seq != mod_seq, "scheduling pass did not change instruction order" From da39572fe9d885fa3c67cc8aacceab280dc39991 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Thu, 4 Jun 2026 08:32:16 +0000 Subject: [PATCH 4/5] feat(compiler): Add fly-llc multi-stage MIR pass insertion points Signed-off-by: fsx950223 --- python/flydsl/compiler/external_llvm.py | 18 +++-- python/flydsl/compiler/jit_function.py | 38 ++++++--- python/flydsl/utils/env.py | 7 ++ tests/kernels/test_llvm_codegen_pass_e2e.py | 88 ++++++++++++++++++++- tests/unit/test_llvm_pass_pipeline.py | 18 ++++- tools/fly-llc/fly-llc.cpp | 41 +++++++++- 6 files changed, 185 insertions(+), 25 deletions(-) diff --git a/python/flydsl/compiler/external_llvm.py b/python/flydsl/compiler/external_llvm.py index 338486ee1..4bd7f4fa4 100644 --- a/python/flydsl/compiler/external_llvm.py +++ b/python/flydsl/compiler/external_llvm.py @@ -365,10 +365,12 @@ def _lld_path() -> Path: ) -def fly_llc_codegen_fingerprint(passes: Optional[list] = None, plugins: Optional[list] = None) -> str: +def fly_llc_codegen_fingerprint( + passes: Optional[list] = None, plugins: Optional[list] = None, insert_after: Optional[list] = None +) -> str: """Cache fingerprint for a fly-llc codegen configuration: the pass names plus the fly-llc binary's and each plugin's content hash.""" - parts = ["fly-llc-codegen:" + ",".join(passes or [])] + parts = ["fly-llc-codegen:" + ",".join(passes or []) + "|after:" + ",".join(insert_after or [])] try: parts.append(_file_hash(_fly_llc_path().resolve())) except OSError: @@ -399,8 +401,9 @@ def run_fly_llc_codegen( module: ir.Module, *, llvm_ir: str, - codegen_passes: list, + codegen_passes: Optional[list] = None, codegen_plugins: Optional[list] = None, + codegen_insert_after: Optional[list] = None, target_triple: str, target_cpu: str, work_dir: Optional[Path] = None, @@ -409,9 +412,10 @@ def run_fly_llc_codegen( """Codegen the device kernel's LLVM IR with injectable MIR passes and splice the result back into *module*. - Flow: ``fly-llc -o --load= --pre-emit-pass=`` - (custom MIR passes run pre-emit in the standard codegen) -> ``ld.lld -shared`` - -> wrap the HSACO bytes into a ``gpu.binary`` -> replace the in-process + Flow: ``fly-llc -o --load= [--pre-emit-pass=] + [--insert-after==]`` (custom MIR passes run inside the standard + codegen — pre-emit and/or at named earlier stages) -> ``ld.lld -shared`` -> + wrap the HSACO bytes into a ``gpu.binary`` -> replace the in-process ``gpu.module``. """ fly_llc = _fly_llc_path() @@ -438,6 +442,7 @@ def run_fly_llc_codegen( plugin_args = [f"--load={Path(p).expanduser()}" for p in (codegen_plugins or [])] pass_args = [f"--pre-emit-pass={n}" for n in (codegen_passes or [])] + insert_after_args = [f"--insert-after={spec}" for spec in (codegen_insert_after or [])] _run_tool( [ str(fly_llc), @@ -448,6 +453,7 @@ def run_fly_llc_codegen( f"-mcpu={target_cpu}", *plugin_args, *pass_args, + *insert_after_args, ], prefix=prefix, what="fly-llc codegen", diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index 826a9f2f3..9e4dabdc9 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -126,6 +126,7 @@ def _create_mlir_context(*, load_dialects=True): "FLYDSL_COMPILE_LLVM_PASS_PLUGINS", "FLYDSL_COMPILE_LLVM_CODEGEN_PASSES", "FLYDSL_COMPILE_LLVM_CODEGEN_PLUGINS", + "FLYDSL_COMPILE_LLVM_CODEGEN_INSERT_AFTER", ) @@ -753,6 +754,7 @@ class PipelineConfig: llvm_pass_plugins: Optional[list] = None llvm_codegen_passes: Optional[list] = None llvm_codegen_plugins: Optional[list] = None + llvm_codegen_insert_after: Optional[list] = None def _effective_llvm_pass_config(hints: dict): @@ -776,7 +778,10 @@ def _effective_llvm_codegen_config(hints: dict): plugins = hints.get("llvm_codegen_plugins") if plugins is None: plugins = env.compile.llvm_codegen_plugins - return list(passes or []), list(plugins or []) + insert_after = hints.get("llvm_codegen_insert_after") + if insert_after is None: + insert_after = env.compile.llvm_codegen_insert_after + return list(passes or []), list(plugins or []), list(insert_after or []) def _pipeline_fragments_for_mode(backend) -> PipelineConfig: @@ -786,12 +791,12 @@ def _pipeline_fragments_for_mode(backend) -> PipelineConfig: hints = CompilationContext.get_compile_hints() llvm_opts = hints.get("llvm_options") llvm_pass_pipeline, llvm_pass_plugins = _effective_llvm_pass_config(hints) - llvm_codegen_passes, llvm_codegen_plugins = _effective_llvm_codegen_config(hints) + llvm_codegen_passes, llvm_codegen_plugins, llvm_codegen_insert_after = _effective_llvm_codegen_config(hints) # Custom LLVM IR pass pipeline (opt) and/or custom MIR codegen passes (fly-llc): # both need the binary fragment split off so we can extract LLVM IR and run an # external codegen tail (see MlirCompiler.compile). - if llvm_pass_pipeline or llvm_codegen_passes: + if llvm_pass_pipeline or llvm_codegen_passes or llvm_codegen_insert_after: pre_binary_fragments, binary_fragment = backend.external_binary_pipeline_fragments(compile_hints=hints) return PipelineConfig( fragments=[*pre_binary_fragments, binary_fragment], @@ -803,6 +808,7 @@ def _pipeline_fragments_for_mode(backend) -> PipelineConfig: llvm_pass_plugins=llvm_pass_plugins, llvm_codegen_passes=llvm_codegen_passes, llvm_codegen_plugins=llvm_codegen_plugins, + llvm_codegen_insert_after=llvm_codegen_insert_after, ) if _use_external_binary_codegen(): @@ -861,8 +867,8 @@ def compile( "use embedded codegen for kernels that require #fly.explicit_module." ) - if (cfg.llvm_pass_pipeline or cfg.llvm_codegen_passes) and link_libs: - raise RuntimeError("custom llvm_pass_pipeline / llvm_codegen_passes do not support extern link_libs yet.") + if (cfg.llvm_pass_pipeline or cfg.llvm_codegen_passes or cfg.llvm_codegen_insert_after) and link_libs: + raise RuntimeError("custom llvm_pass_pipeline / llvm_codegen_* do not support extern link_libs yet.") if link_libs: link_opt = _format_link_lib_options(link_libs) @@ -881,7 +887,7 @@ def compile( dump_dir = Path(env.debug.dump_dir).resolve() with _llvm_ctx: - if cfg.llvm_codegen_passes: + if cfg.llvm_codegen_passes or cfg.llvm_codegen_insert_after: return cls._compile_with_fly_llc( module, backend, cfg, func_name=func_name, dump_enabled=dump_enabled, dump_dir=dump_dir ) @@ -1077,7 +1083,7 @@ def _compile_with_fly_llc( llvm_ir = _extract_llvm_ir(module) if llvm_ir is None: raise FlyDSLCompileError( - "llvm_codegen_passes is set but the device LLVM IR could not be extracted from the gpu.module." + "llvm_codegen_* is set but the device LLVM IR could not be extracted from the gpu.module." ) rocdl_opts = backend._rocdl_opts(compile_hints=hints) @@ -1086,6 +1092,7 @@ def _compile_with_fly_llc( llvm_ir=llvm_ir, codegen_passes=cfg.llvm_codegen_passes, codegen_plugins=cfg.llvm_codegen_plugins, + codegen_insert_after=cfg.llvm_codegen_insert_after, target_triple=rocdl_opts["triple"], target_cpu=rocdl_opts["chip"], work_dir=work_dir, @@ -1508,11 +1515,13 @@ def _resolve_and_make_cache_key(self, bound_args): from .external_llvm import llvm_opt_fingerprint key_parts.append(("_llvm_pass_", llvm_opt_fingerprint(eff_pipeline, eff_plugins))) - eff_cg_passes, eff_cg_plugins = _effective_llvm_codegen_config(self.compile_hints) - if eff_cg_passes: + eff_cg_passes, eff_cg_plugins, eff_cg_insert_after = _effective_llvm_codegen_config(self.compile_hints) + if eff_cg_passes or eff_cg_insert_after: from .external_llvm import fly_llc_codegen_fingerprint - key_parts.append(("_llvm_codegen_", fly_llc_codegen_fingerprint(eff_cg_passes, eff_cg_plugins))) + key_parts.append( + ("_llvm_codegen_", fly_llc_codegen_fingerprint(eff_cg_passes, eff_cg_plugins, eff_cg_insert_after)) + ) for name, arg in bound_args.items(): param = sig.parameters.get(name) @@ -1837,6 +1846,7 @@ def jit( llvm_pass_plugins: Optional[list] = None, llvm_codegen_passes: Optional[list] = None, llvm_codegen_plugins: Optional[list] = None, + llvm_codegen_insert_after: Optional[list] = None, ) -> JitFunction: """JIT decorator for host launcher functions. @@ -1849,6 +1859,12 @@ def jit( via the ``fly-llc`` codegen path (``fly-llc --load`` + ``--pre-emit-pass``, then ``ld.lld``). + ``llvm_codegen_insert_after``: optional ``"ANCHOR=PASS"`` entries that inject a + MIR pass right after the named codegen pass ANCHOR (``fly-llc --insert-after``), + reaching earlier pipeline stages than pre-emit (e.g. + ``"machine-scheduler=my-pass"`` for pre-RA, ``"virtregrewriter=my-pass"`` for + post-RA). + All require ``FLYDSL_COMPILE_LLVM_DIR`` and override the corresponding ``FLYDSL_COMPILE_LLVM_*`` env vars. """ @@ -1861,6 +1877,8 @@ def jit( hints["llvm_codegen_passes"] = list(llvm_codegen_passes) if llvm_codegen_plugins is not None: hints["llvm_codegen_plugins"] = list(llvm_codegen_plugins) + if llvm_codegen_insert_after is not None: + hints["llvm_codegen_insert_after"] = list(llvm_codegen_insert_after) def _make(f: Callable) -> JitFunction: return JitFunction(f, compile_hints=hints or None) diff --git a/python/flydsl/utils/env.py b/python/flydsl/utils/env.py index 0f2e571f1..adb568c90 100644 --- a/python/flydsl/utils/env.py +++ b/python/flydsl/utils/env.py @@ -254,6 +254,13 @@ class CompileEnvManager(EnvManager): description="Colon-separated legacy MIR pass plugin .so paths loaded by fly-llc " "(fly-llc --load). Overridden by @flyc.jit(llvm_codegen_plugins=...).", ) + llvm_codegen_insert_after = OptList( + [], + description="Comma-separated 'ANCHOR=PASS' entries inserting MIR pass PASS right after " + "codegen pass ANCHOR (fly-llc --insert-after), reaching earlier stages than pre-emit " + "(e.g. 'machine-scheduler=my-pass' for pre-RA). Overridden by " + "@flyc.jit(llvm_codegen_insert_after=...).", + ) fly_llc = OptStr("", description="Path to the fly-llc tool (default: /bin/fly-llc).") lld = OptStr("", description="Path to ld.lld for the fly-llc codegen path (default: /bin/ld.lld).") diff --git a/tests/kernels/test_llvm_codegen_pass_e2e.py b/tests/kernels/test_llvm_codegen_pass_e2e.py index e0828af10..19fd2d31c 100644 --- a/tests/kernels/test_llvm_codegen_pass_e2e.py +++ b/tests/kernels/test_llvm_codegen_pass_e2e.py @@ -54,9 +54,10 @@ static RegisterPass X("fly-mir-pass", "Fly demo MIR pass", false, false); """ -# A MIR pass that *modifies* the machine code: inserts 8 ``s_nop`` at the entry -# of every kernel. The opcode is found by name so no AMDGPU target headers are -# needed. 8 nops/function is a distinctive, measurable change in the ASM. +# A MIR pass that *modifies* the machine code: inserts NOP_PER_FUNC ``s_nop`` at +# the entry of every kernel. The opcode is found by name so no AMDGPU target +# headers are needed; the count is a distinctive, measurable change in the ASM. +NOP_PER_FUNC = 8 PLUGIN_SRC_NOP = r""" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineFunctionPass.h" @@ -385,6 +386,27 @@ def _kernel_instr_seq(disasm: str) -> list: return seq +def _func_body_snop(disasm: str) -> int: + """Count ``s_nop`` in function bodies only (between the label and the first + ``s_endpgm``), excluding trailing alignment padding — robust even when the + injected nops get scheduled away from the entry.""" + import re + + n = 0 + in_func = False + for ln in disasm.splitlines(): + if re.match(r"^[0-9a-fA-F]+ <.+>:", ln): + in_func = True + continue + if in_func: + if "s_endpgm" in ln: + in_func = False + continue + if "s_nop" in ln: + n += 1 + return n + + @pytest.mark.skipif(not _gpu_available(), reason="requires a ROCm GPU") def test_codegen_pass_modifies_asm(nop_pass_plugin, monkeypatch, tmp_path): """A codegen pass can change the emitted ASM. Compile the same kernel through @@ -455,3 +477,63 @@ def test_codegen_pass_reorders_instructions(reorder_pass_plugin, mir_pass_plugin assert base_seq and mod_seq, "no instructions disassembled" assert sorted(base_seq) == sorted(mod_seq), "reorder must not add or remove instructions" assert base_seq != mod_seq, "scheduling pass did not change instruction order" + + +@pytest.mark.skipif(not _gpu_available(), reason="requires a ROCm GPU") +def test_codegen_insert_after_runs_at_earlier_stage(nop_pass_plugin, mir_pass_plugin, monkeypatch, tmp_path): + """`--insert-after` injects a pass at an *earlier* codegen stage than pre-emit. + Insert the nop pass right after `machine-scheduler` (pre-RA): its 8 nops/func + survive register allocation + later scheduling, so the device binary's function + body has NOP_PER_FUNC more `s_nop` than the no-op baseline (same fly-llc + driver) — proving the pass ran via the earlier injection point.""" + objdump = _resolve_tool("FLYDSL_COMPILE_LLVM_OBJDUMP", "llvm-objdump") + if objdump is None: + pytest.skip("llvm-objdump not found; set FLYDSL_COMPILE_LLVM_OBJDUMP or place it in /bin") + + from flydsl.compiler.backends.rocm import RocmBackend + + mcpu = RocmBackend.detect_target().arch + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + monkeypatch.setenv("FLYDSL_DUMP_IR", "1") + + mod_dump = tmp_path / "mod" + base_dump = tmp_path / "base" + add_mod = _make_add_jit( + llvm_codegen_plugins=[nop_pass_plugin], + llvm_codegen_insert_after=["machine-scheduler=fly-insert-nop"], # pre-RA + ) + add_base = _make_add_jit(llvm_codegen_passes=["fly-mir-pass"], llvm_codegen_plugins=[mir_pass_plugin]) + _jit_run_add(add_mod, mod_dump, monkeypatch) + _jit_run_add(add_base, base_dump, monkeypatch) + + mod_hsaco = next(mod_dump.rglob("fly_llc.hsaco"), None) + base_hsaco = next(base_dump.rglob("fly_llc.hsaco"), None) + assert mod_hsaco is not None and base_hsaco is not None, "fly-llc HSACO dump not found" + + base_n = _func_body_snop(_disasm(objdump, mcpu, base_hsaco)) + mod_n = _func_body_snop(_disasm(objdump, mcpu, mod_hsaco)) + delta = mod_n - base_n + assert ( + delta >= NOP_PER_FUNC and delta % NOP_PER_FUNC == 0 + ), f"pre-RA insert-after pass effect not observed: body s_nop base={base_n} mod={mod_n}" + + +def test_codegen_insert_after_unknown_anchor_fails(nop_pass_plugin, monkeypatch): + """An unknown anchor pass name is rejected by fly-llc.""" + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + add = _make_add_jit( + llvm_codegen_plugins=[nop_pass_plugin], + llvm_codegen_insert_after=["no-such-anchor=fly-insert-nop"], + ) + + n = 64 + A = torch.zeros(n, dtype=torch.float32) + B = torch.zeros(n, dtype=torch.float32) + C = torch.zeros(n, dtype=torch.float32) + if _gpu_available(): + A, B, C = A.cuda(), B.cuda(), C.cuda() + tA = flyc.from_dlpack(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + + with pytest.raises(ExternalLLVMError) as excinfo: + add(tA, B, C, n, stream=torch.cuda.Stream() if _gpu_available() else fx.Stream(None)) + assert "no-such-anchor" in str(excinfo.value) diff --git a/tests/unit/test_llvm_pass_pipeline.py b/tests/unit/test_llvm_pass_pipeline.py index 0fd7e99a8..d1f08713f 100644 --- a/tests/unit/test_llvm_pass_pipeline.py +++ b/tests/unit/test_llvm_pass_pipeline.py @@ -79,29 +79,41 @@ def test_fingerprint_tolerates_missing_plugin(): def test_jit_decorator_records_llvm_codegen_hints(): - @flyc.jit(llvm_codegen_passes=["fly-mir-pass"], llvm_codegen_plugins=["/tmp/libFlyMir.so"]) + @flyc.jit( + llvm_codegen_passes=["fly-mir-pass"], + llvm_codegen_plugins=["/tmp/libFlyMir.so"], + llvm_codegen_insert_after=["machine-scheduler=fly-mir-pass"], + ) def f(): # pragma: no cover - never executed pass assert f.compile_hints["llvm_codegen_passes"] == ["fly-mir-pass"] assert f.compile_hints["llvm_codegen_plugins"] == ["/tmp/libFlyMir.so"] + assert f.compile_hints["llvm_codegen_insert_after"] == ["machine-scheduler=fly-mir-pass"] def test_effective_codegen_config_prefers_hints_over_env(monkeypatch): monkeypatch.setenv("FLYDSL_COMPILE_LLVM_CODEGEN_PASSES", "env-pass") monkeypatch.setenv("FLYDSL_COMPILE_LLVM_CODEGEN_PLUGINS", "/env/a.so:/env/b.so") + monkeypatch.setenv("FLYDSL_COMPILE_LLVM_CODEGEN_INSERT_AFTER", "greedy=env-pass") - passes, plugins = _effective_llvm_codegen_config({"llvm_codegen_passes": ["h"], "llvm_codegen_plugins": ["/h.so"]}) + passes, plugins, insert_after = _effective_llvm_codegen_config( + {"llvm_codegen_passes": ["h"], "llvm_codegen_plugins": ["/h.so"], "llvm_codegen_insert_after": ["a=h"]} + ) assert passes == ["h"] assert plugins == ["/h.so"] + assert insert_after == ["a=h"] - passes, plugins = _effective_llvm_codegen_config({}) + passes, plugins, insert_after = _effective_llvm_codegen_config({}) assert passes == ["env-pass"] assert plugins == ["/env/a.so", "/env/b.so"] + assert insert_after == ["greedy=env-pass"] def test_codegen_fingerprint_changes_with_passes_and_plugins(tmp_path): assert fly_llc_codegen_fingerprint(["a"]) != fly_llc_codegen_fingerprint(["b"]) + # insert-after specs participate in the fingerprint too + assert fly_llc_codegen_fingerprint([], [], ["m=a"]) != fly_llc_codegen_fingerprint([], [], ["m=b"]) so = tmp_path / "libP.so" so.write_bytes(b"v1") diff --git a/tools/fly-llc/fly-llc.cpp b/tools/fly-llc/fly-llc.cpp index cc206f3f4..4ed6ab98a 100644 --- a/tools/fly-llc/fly-llc.cpp +++ b/tools/fly-llc/fly-llc.cpp @@ -44,6 +44,17 @@ static cl::list LoadLib("load", cl::desc("dlopen a legacy MIR pass plugin .so (repeatable)")); static cl::list PreEmitPass("pre-emit-pass", cl::desc("named MIR pass to insert pre-emit (repeatable)")); +static cl::list + InsertAfter("insert-after", + cl::desc("ANCHOR=PASS: insert MIR pass PASS right after codegen pass ANCHOR " + "(both are registered pass arg-names, e.g. greedy=my-pass); repeatable. " + "This reaches earlier pipeline stages (pre/post-RA, pre-sched2, ...) that " + "the pre-emit slot cannot.")); + +// Look up a registered (legacy) pass by its command-line arg name. +static const PassInfo *findPass(StringRef name) { + return PassRegistry::getPassRegistry()->getPassInfo(name); +} int main(int argc, char **argv) { InitLLVM X(argc, argv); @@ -81,13 +92,37 @@ int main(int argc, char **argv) { M->setDataLayout(TM->createDataLayout()); M->setTargetTriple(TT); - // Replicate addPassesToEmitFile so a custom MIR pass can be injected after - // the standard machine passes and before the asm printer. + // Replicate addPassesToEmitFile so custom MIR passes can be injected into the + // codegen pipeline: --insert-after schedules passes relative to named anchor + // passes (reaching earlier stages), and --pre-emit-pass appends after the + // whole machine pipeline (just before the asm printer). auto &CG = static_cast(*TM); legacy::PassManager PM; auto *MMIWP = new MachineModuleInfoWrapperPass(&CG); TargetPassConfig *PC = CG.createPassConfig(PM); PC->setDisableVerify(true); + + // Schedule --insert-after injections BEFORE the pipeline is built; each fires + // when the pipeline adds its anchor pass (TargetPassConfig::insertPass). + for (auto &spec : InsertAfter) { + auto eq = spec.find('='); + if (eq == StringRef::npos) { + errs() << "fly-llc: --insert-after expects ANCHOR=PASS, got: " << spec << "\n"; + return 1; + } + const PassInfo *anchor = findPass(StringRef(spec).substr(0, eq)); + const PassInfo *pass = findPass(StringRef(spec).substr(eq + 1)); + if (!anchor) { + errs() << "fly-llc: unknown anchor pass: " << spec.substr(0, eq) << "\n"; + return 1; + } + if (!pass || !pass->getNormalCtor()) { + errs() << "fly-llc: unknown MIR pass: " << spec.substr(eq + 1) << "\n"; + return 1; + } + PC->insertPass(anchor->getTypeInfo(), pass->getTypeInfo()); + } + PM.add(PC); PM.add(MMIWP); TargetLibraryInfoImpl TLII(TT); @@ -98,7 +133,7 @@ int main(int argc, char **argv) { } PC->addMachinePasses(); for (auto &name : PreEmitPass) { - const PassInfo *PI = PassRegistry::getPassRegistry()->getPassInfo(StringRef(name)); + const PassInfo *PI = findPass(name); if (!PI || !PI->getNormalCtor()) { errs() << "fly-llc: unknown MIR pass: " << name << "\n"; return 1; From f5a6a63b2270a6cf283f536a36dd6d8191af4342 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Thu, 4 Jun 2026 08:48:29 +0000 Subject: [PATCH 5/5] test(compiler): Extract custom pass plugins into standalone .cpp files Signed-off-by: fsx950223 --- .../llvm_pass_plugins/insert_nop_mir_pass.cpp | 42 ++++++ .../llvm_pass_plugins/print_mir_pass.cpp | 29 ++++ .../llvm_pass_plugins/print_tid_ir_pass.cpp | 72 +++++++++ .../llvm_pass_plugins/reorder_mir_pass.cpp | 65 ++++++++ tests/kernels/test_llvm_codegen_pass_e2e.py | 140 ++---------------- tests/kernels/test_llvm_pass_plugin_e2e.py | 74 +-------- 6 files changed, 231 insertions(+), 191 deletions(-) create mode 100644 tests/kernels/llvm_pass_plugins/insert_nop_mir_pass.cpp create mode 100644 tests/kernels/llvm_pass_plugins/print_mir_pass.cpp create mode 100644 tests/kernels/llvm_pass_plugins/print_tid_ir_pass.cpp create mode 100644 tests/kernels/llvm_pass_plugins/reorder_mir_pass.cpp diff --git a/tests/kernels/llvm_pass_plugins/insert_nop_mir_pass.cpp b/tests/kernels/llvm_pass_plugins/insert_nop_mir_pass.cpp new file mode 100644 index 000000000..410aae3a0 --- /dev/null +++ b/tests/kernels/llvm_pass_plugins/insert_nop_mir_pass.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// +// Test fixture for tests/kernels/test_llvm_codegen_pass_e2e.py (compiled at test +// time, not by CMake). +// +// Legacy MachineFunctionPass plugin registering "fly-insert-nop": inserts 8 +// `s_nop` at the entry of every kernel (count must match NOP_PER_FUNC in the +// test). The opcode is found by name so no AMDGPU target headers are needed. + +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/Pass.h" +using namespace llvm; +namespace { +struct FlyInsertNopPass : public MachineFunctionPass { + static char ID; + FlyInsertNopPass() : MachineFunctionPass(ID) {} + bool runOnMachineFunction(MachineFunction &MF) override { + const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); + unsigned NopOpc = ~0u; + for (unsigned i = 0, e = TII->getNumOpcodes(); i < e; ++i) + if (TII->getName(i) == "S_NOP") { + NopOpc = i; + break; + } + if (NopOpc == ~0u || MF.empty()) + return false; + MachineBasicBlock &MBB = MF.front(); + auto It = MBB.begin(); + for (int k = 0; k < 8; ++k) + BuildMI(MBB, It, DebugLoc(), TII->get(NopOpc)).addImm(0); + return true; + } + StringRef getPassName() const override { return "Fly insert NOP MIR pass"; } +}; +char FlyInsertNopPass::ID = 0; +} // namespace +static RegisterPass X("fly-insert-nop", "Fly insert NOP", false, false); diff --git a/tests/kernels/llvm_pass_plugins/print_mir_pass.cpp b/tests/kernels/llvm_pass_plugins/print_mir_pass.cpp new file mode 100644 index 000000000..795e5f65a --- /dev/null +++ b/tests/kernels/llvm_pass_plugins/print_mir_pass.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// +// Test fixture for tests/kernels/test_llvm_codegen_pass_e2e.py (compiled at test +// time, not by CMake). +// +// Legacy MachineFunctionPass plugin registering "fly-mir-pass" via RegisterPass. +// Runs pre-emit during codegen and prints the machine-function name (observable +// under `pytest -s`); a no-op otherwise, so it serves as the "same codegen +// driver" baseline for the reorder test. + +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; +namespace { +struct FlyMirPass : public MachineFunctionPass { + static char ID; + FlyMirPass() : MachineFunctionPass(ID) {} + bool runOnMachineFunction(MachineFunction &MF) override { + errs() << "fly-mir-pass: ran on " << MF.getName() << "\n"; + return false; + } + StringRef getPassName() const override { return "Fly demo MIR pass"; } +}; +char FlyMirPass::ID = 0; +} // namespace +static RegisterPass X("fly-mir-pass", "Fly demo MIR pass", false, false); diff --git a/tests/kernels/llvm_pass_plugins/print_tid_ir_pass.cpp b/tests/kernels/llvm_pass_plugins/print_tid_ir_pass.cpp new file mode 100644 index 000000000..7e47484f1 --- /dev/null +++ b/tests/kernels/llvm_pass_plugins/print_tid_ir_pass.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// +// Test fixture for tests/kernels/test_llvm_pass_plugin_e2e.py (compiled at test +// time, not by CMake). +// +// An LLVM new-PM IR pass plugin registering "flydsl-print-tid". At the entry of +// every amdgpu_kernel it emits FlyDSL's exact hostcall device-printf sequence +// (__ockl_printf_begin / append_string_n / append_args) printing threadIdx.x. +// Using the same ockl ABI as fx.printf means the ROCm runtime FlyDSL already +// sets up services it, and ockl is linked during the O=0 re-codegen. (The C +// printf + amdgpu-printf-runtime-binding route instead emits the buffered +// __printf_alloc path, which FlyDSL's runtime does not service.) + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Plugins/PassPlugin.h" +using namespace llvm; +namespace { +struct PrintTidPass : PassInfoMixin { + PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { + LLVMContext &C = M.getContext(); + auto *i64 = Type::getInt64Ty(C); + auto *i32 = Type::getInt32Ty(C); + auto *ptr = PointerType::get(C, 0); + FunctionCallee beginF = + M.getOrInsertFunction("__ockl_printf_begin", FunctionType::get(i64, {i64}, false)); + FunctionCallee strF = M.getOrInsertFunction( + "__ockl_printf_append_string_n", FunctionType::get(i64, {i64, ptr, i64, i32}, false)); + FunctionCallee argsF = M.getOrInsertFunction( + "__ockl_printf_append_args", + FunctionType::get(i64, {i64, i32, i64, i64, i64, i64, i64, i64, i64, i32}, false)); + Function *widx = Intrinsic::getOrInsertDeclaration(&M, Intrinsic::amdgcn_workitem_id_x); + bool changed = false; + for (Function &F : M) { + if (F.isDeclaration() || F.getCallingConv() != CallingConv::AMDGPU_KERNEL) + continue; + IRBuilder<> B(&*F.getEntryBlock().getFirstInsertionPt()); + Constant *str = ConstantDataArray::getString(C, "flydsl-pass: threadIdx.x=%d\n", true); + // Format string must live in addrspace 0 (matches the ockl append ABI). + auto *gv = new GlobalVariable(M, str->getType(), true, GlobalValue::InternalLinkage, str, + "flydsl_tid_fmt", nullptr, GlobalValue::NotThreadLocal, 0); + uint64_t len = cast(str->getType())->getNumElements(); + Value *tid = B.CreateZExt(B.CreateCall(widx, {}), i64); + Value *z = ConstantInt::get(i64, 0); + Value *h0 = B.CreateCall(beginF, {z}); + Value *h1 = + B.CreateCall(strF, {h0, gv, ConstantInt::get(i64, len), ConstantInt::get(i32, 0)}); + B.CreateCall(argsF, + {h1, ConstantInt::get(i32, 1), tid, z, z, z, z, z, z, ConstantInt::get(i32, 1)}); + changed = true; + } + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); + } +}; +} // namespace +extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo() { + return {LLVM_PLUGIN_API_VERSION, "FlydslPrintTid", LLVM_VERSION_STRING, [](PassBuilder &PB) { + PB.registerPipelineParsingCallback( + [](StringRef N, ModulePassManager &MPM, ArrayRef) { + if (N == "flydsl-print-tid") { + MPM.addPass(PrintTidPass()); + return true; + } + return false; + }); + }}; +} diff --git a/tests/kernels/llvm_pass_plugins/reorder_mir_pass.cpp b/tests/kernels/llvm_pass_plugins/reorder_mir_pass.cpp new file mode 100644 index 000000000..5f92246e9 --- /dev/null +++ b/tests/kernels/llvm_pass_plugins/reorder_mir_pass.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// +// Test fixture for tests/kernels/test_llvm_codegen_pass_e2e.py (compiled at test +// time, not by CMake). +// +// Legacy MachineFunctionPass plugin registering "fly-reorder": a tiny scheduler +// that, within each block, swaps adjacent instructions whenever provably safe — +// neither has memory/side effects and no def of one overlaps any register +// operand (def or use, explicit or implicit) of the other. Semantics are +// preserved (results stay correct) but the emitted instruction order changes. + +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/Pass.h" +using namespace llvm; +namespace { +static bool unsafe(const MachineInstr &MI) { + return MI.mayLoadOrStore() || MI.hasUnmodeledSideEffects() || MI.isCall() || MI.isTerminator() || + MI.isBranch() || MI.isInlineAsm() || MI.isMetaInstruction(); +} +static bool canSwap(const MachineInstr &A, const MachineInstr &B, const TargetRegisterInfo *TRI) { + if (unsafe(A) || unsafe(B)) + return false; + auto defConflicts = [&](const MachineInstr &X, const MachineInstr &Y) { + for (const MachineOperand &dx : X.operands()) { + if (!dx.isReg() || !dx.getReg() || !dx.isDef()) + continue; + for (const MachineOperand &oy : Y.operands()) + if (oy.isReg() && oy.getReg() && TRI->regsOverlap(dx.getReg(), oy.getReg())) + return true; + } + return false; + }; + return !defConflicts(A, B) && !defConflicts(B, A); +} +struct FlyReorderPass : public MachineFunctionPass { + static char ID; + FlyReorderPass() : MachineFunctionPass(ID) {} + bool runOnMachineFunction(MachineFunction &MF) override { + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + bool changed = false; + for (MachineBasicBlock &MBB : MF) { + for (auto it = MBB.begin(); it != MBB.end();) { + auto nxt = std::next(it); + if (nxt != MBB.end() && canSwap(*it, *nxt, TRI)) { + MBB.splice(it, &MBB, nxt); // move B before A + changed = true; + it = std::next(it); // it still == A; skip past the swapped pair + } else { + ++it; + } + } + } + return changed; + } + StringRef getPassName() const override { return "Fly reorder MIR pass"; } +}; +char FlyReorderPass::ID = 0; +} // namespace +static RegisterPass X("fly-reorder", "Fly reorder", false, false); diff --git a/tests/kernels/test_llvm_codegen_pass_e2e.py b/tests/kernels/test_llvm_codegen_pass_e2e.py index 19fd2d31c..778a75240 100644 --- a/tests/kernels/test_llvm_codegen_pass_e2e.py +++ b/tests/kernels/test_llvm_codegen_pass_e2e.py @@ -31,121 +31,14 @@ import flydsl.expr as fx # noqa: E402 from flydsl.compiler.external_llvm import ExternalLLVMError # noqa: E402 -# Legacy MachineFunctionPass plugin: registers "fly-mir-pass" via RegisterPass, -# runs pre-emit during codegen, prints the MF name (observable under `pytest -s`). -PLUGIN_SRC = r""" -#include "llvm/CodeGen/MachineFunction.h" -#include "llvm/CodeGen/MachineFunctionPass.h" -#include "llvm/Pass.h" -#include "llvm/Support/raw_ostream.h" -using namespace llvm; -namespace { -struct FlyMirPass : public MachineFunctionPass { - static char ID; - FlyMirPass() : MachineFunctionPass(ID) {} - bool runOnMachineFunction(MachineFunction &MF) override { - errs() << "fly-mir-pass: ran on " << MF.getName() << "\n"; - return false; - } - StringRef getPassName() const override { return "Fly demo MIR pass"; } -}; -char FlyMirPass::ID = 0; -} // namespace -static RegisterPass X("fly-mir-pass", "Fly demo MIR pass", false, false); -""" - -# A MIR pass that *modifies* the machine code: inserts NOP_PER_FUNC ``s_nop`` at -# the entry of every kernel. The opcode is found by name so no AMDGPU target -# headers are needed; the count is a distinctive, measurable change in the ASM. +# Plugin sources live as standalone .cpp files (compiled at test time) under +# llvm_pass_plugins/. Each registers a legacy MachineFunctionPass: +# print_mir_pass.cpp -> 'fly-mir-pass' (no-op; prints the MF name) +# insert_nop_mir_pass.cpp -> 'fly-insert-nop' (inserts NOP_PER_FUNC s_nop) +# reorder_mir_pass.cpp -> 'fly-reorder' (safe adjacent-instr swap) +_PLUGINS_DIR = Path(__file__).parent / "llvm_pass_plugins" +# Must match the count inserted by insert_nop_mir_pass.cpp. NOP_PER_FUNC = 8 -PLUGIN_SRC_NOP = r""" -#include "llvm/CodeGen/MachineFunction.h" -#include "llvm/CodeGen/MachineFunctionPass.h" -#include "llvm/CodeGen/MachineInstrBuilder.h" -#include "llvm/CodeGen/TargetInstrInfo.h" -#include "llvm/CodeGen/TargetSubtargetInfo.h" -#include "llvm/Pass.h" -using namespace llvm; -namespace { -struct FlyInsertNopPass : public MachineFunctionPass { - static char ID; - FlyInsertNopPass() : MachineFunctionPass(ID) {} - bool runOnMachineFunction(MachineFunction &MF) override { - const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo(); - unsigned NopOpc = ~0u; - for (unsigned i = 0, e = TII->getNumOpcodes(); i < e; ++i) - if (TII->getName(i) == "S_NOP") { NopOpc = i; break; } - if (NopOpc == ~0u || MF.empty()) return false; - MachineBasicBlock &MBB = MF.front(); - auto It = MBB.begin(); - for (int k = 0; k < 8; ++k) - BuildMI(MBB, It, DebugLoc(), TII->get(NopOpc)).addImm(0); - return true; - } - StringRef getPassName() const override { return "Fly insert NOP MIR pass"; } -}; -char FlyInsertNopPass::ID = 0; -} // namespace -static RegisterPass X("fly-insert-nop", "Fly insert NOP", false, false); -""" - -# A MIR pass that *schedules* (reorders) instructions: within each block it swaps -# adjacent instructions whenever that is provably safe — neither has memory/side -# effects and no def of one overlaps any register operand (def or use, explicit -# or implicit) of the other. Semantics are preserved (results stay correct) but -# the emitted instruction order changes. -PLUGIN_SRC_REORDER = r""" -#include "llvm/CodeGen/MachineBasicBlock.h" -#include "llvm/CodeGen/MachineFunction.h" -#include "llvm/CodeGen/MachineFunctionPass.h" -#include "llvm/CodeGen/TargetInstrInfo.h" -#include "llvm/CodeGen/TargetRegisterInfo.h" -#include "llvm/CodeGen/TargetSubtargetInfo.h" -#include "llvm/Pass.h" -using namespace llvm; -namespace { -static bool unsafe(const MachineInstr &MI) { - return MI.mayLoadOrStore() || MI.hasUnmodeledSideEffects() || MI.isCall() || - MI.isTerminator() || MI.isBranch() || MI.isInlineAsm() || MI.isMetaInstruction(); -} -static bool canSwap(const MachineInstr &A, const MachineInstr &B, const TargetRegisterInfo *TRI) { - if (unsafe(A) || unsafe(B)) return false; - auto defConflicts = [&](const MachineInstr &X, const MachineInstr &Y) { - for (const MachineOperand &dx : X.operands()) { - if (!dx.isReg() || !dx.getReg() || !dx.isDef()) continue; - for (const MachineOperand &oy : Y.operands()) - if (oy.isReg() && oy.getReg() && TRI->regsOverlap(dx.getReg(), oy.getReg())) return true; - } - return false; - }; - return !defConflicts(A, B) && !defConflicts(B, A); -} -struct FlyReorderPass : public MachineFunctionPass { - static char ID; - FlyReorderPass() : MachineFunctionPass(ID) {} - bool runOnMachineFunction(MachineFunction &MF) override { - const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); - bool changed = false; - for (MachineBasicBlock &MBB : MF) { - for (auto it = MBB.begin(); it != MBB.end();) { - auto nxt = std::next(it); - if (nxt != MBB.end() && canSwap(*it, *nxt, TRI)) { - MBB.splice(it, &MBB, nxt); // move B before A - changed = true; - it = std::next(it); // it still == A; skip past the swapped pair - } else { - ++it; - } - } - } - return changed; - } - StringRef getPassName() const override { return "Fly reorder MIR pass"; } -}; -char FlyReorderPass::ID = 0; -} // namespace -static RegisterPass X("fly-reorder", "Fly reorder", false, false); -""" def _gpu_available() -> bool: @@ -169,8 +62,9 @@ def _resolve_tool(env_var: str, name: str): return None -def _build_codegen_plugin(tmp_path_factory, *, src: str, name: str) -> str: - """Skip unless the codegen toolchain is present, then compile a plugin .so.""" +def _build_codegen_plugin(tmp_path_factory, *, cpp_name: str, lib_name: str) -> str: + """Skip unless the codegen toolchain is present, then compile a plugin .cpp + from llvm_pass_plugins/ into a .so.""" raw = os.environ.get("FLYDSL_COMPILE_LLVM_DIR", "").strip() if not raw: pytest.skip("FLYDSL_COMPILE_LLVM_DIR not set; required to build/load a codegen pass plugin") @@ -185,11 +79,9 @@ def _build_codegen_plugin(tmp_path_factory, *, src: str, name: str) -> str: pytest.skip("ld.lld not found; set FLYDSL_COMPILE_LLD or place ld.lld in /bin") cxxflags = subprocess.check_output([str(llvm_config), "--cxxflags"], text=True).split() - work = tmp_path_factory.mktemp("codegen_plugin") - cpp = work / (name + ".cpp") - cpp.write_text(src, encoding="utf-8") - so = work / ("lib" + name + ".so") - subprocess.run([cxx, "-shared", "-fPIC", *cxxflags, str(cpp), "-o", str(so)], check=True) + src = _PLUGINS_DIR / cpp_name + so = tmp_path_factory.mktemp("codegen_plugin") / ("lib" + lib_name + ".so") + subprocess.run([cxx, "-shared", "-fPIC", *cxxflags, str(src), "-o", str(so)], check=True) assert so.is_file() return str(so) @@ -197,19 +89,19 @@ def _build_codegen_plugin(tmp_path_factory, *, src: str, name: str) -> str: @pytest.fixture(scope="module") def mir_pass_plugin(tmp_path_factory) -> str: """Compile the print-only MIR pass plugin (and ensure fly-llc + ld.lld exist).""" - return _build_codegen_plugin(tmp_path_factory, src=PLUGIN_SRC, name="FlyMir") + return _build_codegen_plugin(tmp_path_factory, cpp_name="print_mir_pass.cpp", lib_name="FlyMir") @pytest.fixture(scope="module") def nop_pass_plugin(tmp_path_factory) -> str: """Compile the s_nop-inserting MIR pass plugin (modifies the machine code).""" - return _build_codegen_plugin(tmp_path_factory, src=PLUGIN_SRC_NOP, name="FlyNop") + return _build_codegen_plugin(tmp_path_factory, cpp_name="insert_nop_mir_pass.cpp", lib_name="FlyNop") @pytest.fixture(scope="module") def reorder_pass_plugin(tmp_path_factory) -> str: """Compile the instruction-reordering (scheduling) MIR pass plugin.""" - return _build_codegen_plugin(tmp_path_factory, src=PLUGIN_SRC_REORDER, name="FlyReorder") + return _build_codegen_plugin(tmp_path_factory, cpp_name="reorder_mir_pass.cpp", lib_name="FlyReorder") @flyc.kernel diff --git a/tests/kernels/test_llvm_pass_plugin_e2e.py b/tests/kernels/test_llvm_pass_plugin_e2e.py index adf87b570..ee1f33359 100644 --- a/tests/kernels/test_llvm_pass_plugin_e2e.py +++ b/tests/kernels/test_llvm_pass_plugin_e2e.py @@ -30,69 +30,11 @@ import flydsl.expr as fx # noqa: E402 from flydsl.compiler.external_llvm import ExternalLLVMError # noqa: E402 -# An LLVM new-PM module pass registered under ``flydsl-print-tid`` via the -# pass-plugin C API. At the entry of every amdgpu_kernel it emits FlyDSL's exact -# hostcall device-printf sequence (``__ockl_printf_begin`` / ``append_string_n`` / -# ``append_args``) printing ``threadIdx.x``. Using the same ockl ABI as -# ``fx.printf`` means the ROCm runtime FlyDSL already sets up services it, and -# ``ockl`` is linked during the O=0 re-codegen. (Note: the C ``printf`` + -# ``amdgpu-printf-runtime-binding`` route instead emits the buffered -# ``__printf_alloc`` path, which FlyDSL's runtime does not service.) -PLUGIN_SRC = r""" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/IntrinsicsAMDGPU.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/PassManager.h" -#include "llvm/Passes/PassBuilder.h" -#include "llvm/Plugins/PassPlugin.h" -using namespace llvm; -namespace { -struct PrintTidPass : PassInfoMixin { - PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { - LLVMContext &C = M.getContext(); - auto *i64 = Type::getInt64Ty(C); - auto *i32 = Type::getInt32Ty(C); - auto *ptr = PointerType::get(C, 0); - FunctionCallee beginF = - M.getOrInsertFunction("__ockl_printf_begin", FunctionType::get(i64, {i64}, false)); - FunctionCallee strF = M.getOrInsertFunction( - "__ockl_printf_append_string_n", FunctionType::get(i64, {i64, ptr, i64, i32}, false)); - FunctionCallee argsF = M.getOrInsertFunction( - "__ockl_printf_append_args", - FunctionType::get(i64, {i64, i32, i64, i64, i64, i64, i64, i64, i64, i32}, false)); - Function *widx = Intrinsic::getOrInsertDeclaration(&M, Intrinsic::amdgcn_workitem_id_x); - bool changed = false; - for (Function &F : M) { - if (F.isDeclaration() || F.getCallingConv() != CallingConv::AMDGPU_KERNEL) - continue; - IRBuilder<> B(&*F.getEntryBlock().getFirstInsertionPt()); - Constant *str = ConstantDataArray::getString(C, "flydsl-pass: threadIdx.x=%d\n", true); - // Format string must live in addrspace 0 (matches the ockl append ABI). - auto *gv = new GlobalVariable(M, str->getType(), true, GlobalValue::InternalLinkage, str, - "flydsl_tid_fmt", nullptr, GlobalValue::NotThreadLocal, 0); - uint64_t len = cast(str->getType())->getNumElements(); - Value *tid = B.CreateZExt(B.CreateCall(widx, {}), i64); - Value *z = ConstantInt::get(i64, 0); - Value *h0 = B.CreateCall(beginF, {z}); - Value *h1 = B.CreateCall(strF, {h0, gv, ConstantInt::get(i64, len), ConstantInt::get(i32, 0)}); - B.CreateCall(argsF, {h1, ConstantInt::get(i32, 1), tid, z, z, z, z, z, z, ConstantInt::get(i32, 1)}); - changed = true; - } - return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); - } -}; -} // namespace -extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo() { - return {LLVM_PLUGIN_API_VERSION, "FlydslPrintTid", LLVM_VERSION_STRING, [](PassBuilder &PB) { - PB.registerPipelineParsingCallback( - [](StringRef N, ModulePassManager &MPM, ArrayRef) { - if (N == "flydsl-print-tid") { MPM.addPass(PrintTidPass()); return true; } - return false; - }); - }}; -} -""" +# Plugin sources live as standalone .cpp files (compiled at test time, see +# llvm_pass_plugins/). print_tid_ir_pass.cpp registers "flydsl-print-tid": +# at each amdgpu_kernel entry it emits FlyDSL's hostcall device-printf +# sequence (__ockl_printf_*) printing threadIdx.x, matching the fx.printf ABI. +_PLUGINS_DIR = Path(__file__).parent / "llvm_pass_plugins" def _gpu_available() -> bool: @@ -119,10 +61,8 @@ def print_tid_plugin(tmp_path_factory) -> str: pytest.skip("LLVM headers/llvm-config or a C++ compiler not available for plugin build") cxxflags = subprocess.check_output([str(llvm_config), "--cxxflags"], text=True).split() - work = tmp_path_factory.mktemp("llvm_plugin") - src = work / "flydsl_print_tid.cpp" - src.write_text(PLUGIN_SRC, encoding="utf-8") - so = work / "libFlydslPrintTid.so" + src = _PLUGINS_DIR / "print_tid_ir_pass.cpp" + so = tmp_path_factory.mktemp("llvm_plugin") / "libFlydslPrintTid.so" subprocess.run([cxx, "-shared", "-fPIC", *cxxflags, str(src), "-o", str(so)], check=True) assert so.is_file() return str(so)