
[Cute-DSL] Fix runtime shared library loading for TVM FFI Module Loading #3075

Open
Alkaid-Benetnash wants to merge 1 commit into NVIDIA:main from Alkaid-Benetnash:cutedsl-runtime-dynld-global

Conversation

@Alkaid-Benetnash


When loading a pre-compiled CuTe DSL shared library via `cute.runtime.load_module()` with TVM FFI,
the dynamic linker fails to resolve `_cudaLibraryLoadData`:

```
m = cute.runtime.load_module(path, enable_tvm_ffi=True)

# Stack trace:
cutlass/cute/runtime.py:966: in load_module
    return tvm_ffi.load_module(file_path, keep_module_alive=False)
tvm_ffi/module.py:472: in load_module
    mod = _ffi_api.ModuleLoadFromFile(path)
tvm_ffi/cython/function.pxi:923: in tvm_ffi.core.Function.__call__
    ...
# Error: _cudaLibraryLoadData not defined
```

`libcute_dsl_runtime.so` exports `_cudaLibraryLoadData` and other CUDA Runtime API wrappers
that the MLIR-generated code in compiled `.so` files references. The library should be loaded
with `RTLD_GLOBAL` so its symbols are visible to subsequent `dlopen()` calls.

`ctypes.CDLL(path)` defaults to `mode=0`, which is `RTLD_LOCAL | RTLD_LAZY` on Linux.
`RTLD_LOCAL` keeps the loaded symbols **private**.
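
A minimal sketch of what the one-line fix does, using `libm` as a stand-in for `libcute_dsl_runtime.so` (the library name here is illustrative, not the actual runtime library):

```python
import ctypes
import ctypes.util

# With the default mode (RTLD_LOCAL on Linux), the library's exports stay
# private to this handle. Passing RTLD_GLOBAL promotes them into the global
# symbol namespace, so later dlopen() calls (e.g. TVM FFI loading a compiled
# module) can resolve them.
libm_path = ctypes.util.find_library("m")  # stand-in for libcute_dsl_runtime.so
lib = ctypes.CDLL(libm_path, mode=ctypes.RTLD_GLOBAL)

# The handle itself behaves the same either way; the mode only changes
# symbol visibility for libraries loaded afterwards.
lib.cos.restype = ctypes.c_double
lib.cos.argtypes = [ctypes.c_double]
print(lib.cos(0.0))  # → 1.0
```

Note that `RTLD_GLOBAL` only helps libraries opened *after* it; load order still matters.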

Without this PR, the workaround would be "always run `cute.compile` once before running pre-compiled binaries".
@Alkaid-Benetnash
Author

cc @tqchen ?

Alkaid-Benetnash added a commit to Alkaid-Benetnash/flash-attention that referenced this pull request Feb 26, 2026
…d persistent compile cache

**The proposed new two-pass test workflow:**

Step 1: pre-compile tests in parallel with pytest-xdist (e.g., 256 workers)

```
FLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 256 -x tests/cute/test_flash_attn.py
```

Compiled binaries are dumped to `/tmp/${USER}/flash_attention_cute_dsl_cache/` by default.

On a B200 machine, I got:
> 30044 passed, 16128 skipped, 413456 warnings in 699.09s (0:11:39)

Step 2: run tests without `cute.compile`
```
FLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -x tests/cute/test_flash_attn.py -k 'not test_flash_attn_kvcache'
```
Skipping `test_flash_attn_kvcache` because it currently has [GPU hangs](Dao-AILab#2278).
On the same B200 machine, with single-GPU-single-worker, I got:
> 27548 passed, 15552 skipped, 3072 deselected, 553 warnings in 767.08s (0:12:47)

PS: you can use pytest-xdist here too; the ~13 min full sweep is an upper bound.

The newly introduced env var flags are disabled by default.

**How it works:**

- Use [torch fake tensor mode](https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/torch.compiler_fake_tensor.html#api-the-important-bits)
  to run `cute.compile` without allocating GPU memory or launching kernels.
- Changes are made throughout test files and flash_attn `interface.py` to
  bypass certain "data-dependent" operations when inside fake tensor mode.
- Use cutedsl ahead-of-time compilation support to export and load compiled
  modules to/from external storage.
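
The fake-tensor pass can be sketched as follows. The shapes and dtype below are illustrative, not the actual flash-attention tensors, and this assumes PyTorch's `FakeTensorMode` from `torch._subclasses.fake_tensor`:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor

# Inside FakeTensorMode, tensor factories and ops record shapes, dtypes and
# devices but never allocate real storage or launch kernels, so the code that
# drives compilation can run on machines without exercising the GPU.
with FakeTensorMode():
    q = torch.empty(2, 8, 128, 64, dtype=torch.float16)
    k = torch.empty(2, 8, 128, 64, dtype=torch.float16)
    scores = q @ k.transpose(-1, -2)  # shape inference only, no compute

print(type(scores).__name__, tuple(scores.shape))
```

The "data-dependent" operations mentioned above (e.g. anything that reads tensor values) cannot run on fake tensors, which is why the test files and `interface.py` need bypasses in this mode.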

NOTE: To make the above workflow work, I had to patch the following cutedsl bugs:
- leaked MLIR threading pool: [cutlass#3062](NVIDIA/cutlass#3062)
- ahead-of-time compiling not loading shared libraries correctly: [cutlass#3075](NVIDIA/cutlass#3075)

NOTE: As of cutedsl 4.4.0, cutedsl does not support exporting compiled
kernels as a shared library (it only supports object files), but tvm_ffi only supports shared libraries.
So my current workaround is an explicit `ld -shared -o foo.so foo.o`.

cutedsl patches attached below for your convenience (`cd site-packages && patch -p1 < xxx`):
```
--- a/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py 2026-02-25 17:32:03.674296431 -0800
+++ b/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py 2026-02-23 23:42:18.111766576 -0800
@@ -1345,7 +1345,8 @@
         location=None,
     ):
         """Generate MLIR module and compile iself.T_provider."""
-        with ir.Context(), self.get_ir_location(location):
+        with ir.Context() as ctx, self.get_ir_location(location):
+            ctx.enable_multithreading(False)
             try:
                 # Convert input arguments to MLIR arguments
                 exe_args, func_types, adapted_args = self.generate_mlir_function_types(
--- a/nvidia_cutlass_dsl/python_packages/cutlass/cute/runtime.py 2026-02-25 17:32:03.678703799 -0800
+++ b/nvidia_cutlass_dsl/python_packages/cutlass/cute/runtime.py 2026-02-25 17:28:11.608996616 -0800
@@ -951,7 +951,7 @@
         # no need to load tvm_ffi library here since it will be loaded by tvm_ffi package.
         for path in find_runtime_libraries(enable_tvm_ffi=False):
             if Path(path).exists():
-                _LOAD_MODULE_LIBS_CACHE.append(ctypes.CDLL(path))
+                _LOAD_MODULE_LIBS_CACHE.append(ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL))

     if enable_tvm_ffi:
         import tvm_ffi
```

**Other minor changes**

Added statistics printing to break down the number of tests by test name:
```
pytest --co -qq tests/cute/test_flash_attn.py
```
This would print stats like:
```
cute/test_flash_attn.py: 46172
test_counts={'test_flash_attn_output': 2304, 'test_flash_attn_varlen_output': 39168, 'test_flash_attn_kvcache': 3072, 'test_flash_attn_bwd_preallocated_outputs': 8, 'test_flash_attn_combine': 1620}
```

General QoL improvements w.r.t. `pytest -n N`:
- Add per-worker logging at `/tmp/${USER}/flash_attention_tests/tests_gw${worker_id}.log`
- Only let one worker query nvidia-smi and cache the result for the
  other workers. When many workers spawn and initialize PyTorch,
  nvidia-smi becomes extremely slow and frequently times out.
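
The single-querier idea can be sketched with a file lock; `cached_query`, its cache path, and the command are all illustrative, not the actual helper in the change:

```python
import fcntl
import pathlib
import subprocess

def cached_query(cache_path: pathlib.Path, cmd: list[str]) -> str:
    """Let the first worker run `cmd` and persist its output; every other
    worker blocks on the lock, then reads the cached copy instead of
    re-running the (slow) command."""
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with open(str(cache_path) + ".lock", "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)  # serialize the workers
        if cache_path.exists():
            return cache_path.read_text()
        out = subprocess.run(cmd, capture_output=True, text=True, check=True).stdout
        cache_path.write_text(out)
        return out
```

Only the worker that wins the lock pays the query cost; everyone else gets a file read.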
Alkaid-Benetnash added a commit to Alkaid-Benetnash/flash-attention that referenced this pull request Feb 27, 2026
…d persistent compile cache
@tqchen
Contributor

tqchen commented Feb 27, 2026

Thanks @Alkaid-Benetnash. It seems the related symbols need to be available in the global namespace, so this is a reasonable change. Alternatively, I think we could have the .o link against libcute_dsl_runtime.so explicitly in the linking step, which should pull in this module.

There should be some further updates that move cute.runtime.load_module to use CuteDSL's binary executor, which might resolve the issue; needs confirming whether that lands in the release. cc @fengxie

@fengxie
Collaborator

fengxie commented Mar 23, 2026

@brandon-yujie-sun can you confirm this will be included in latest release?
