CK JIT integration #582

Open
ipanfilo wants to merge 4 commits into dev from ipanfilo/ck_jit

@ipanfilo
Collaborator

Description

Add CK JIT demo integration that builds slim MHA libraries accompanied by blob sources to compile on demand

Fixes https://github.com/ROCm/frameworks-internal/issues/16406

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ipanfilo ipanfilo added the ci-level 1 (CI test level 1) label May 11, 2026
@ipanfilo ipanfilo requested a review from Micky774 May 13, 2026 15:59
@ipanfilo ipanfilo marked this pull request as ready for review May 13, 2026 15:59
@github-actions

Claude Walkthrough

Intent. Wire up an on-demand JIT path for the AITER/CK fused-attention kernels: rather than always building the full kernel matrix from source (or downloading a monolithic prebuilt), the build now produces slim MHA libraries paired with per-kernel blob sources that get compiled lazily at runtime. A new 3rdparty/ck_jit submodule drives the build, and CI is taught to warm the JIT cache before tests. Fixes frameworks-internal#16406.

Key changes.

  • New submodule 3rdparty/ck_jit (ipanfilo/ck_jit) added in .gitmodules and pinned via 3rdparty/ck_jit:1; the wheel-build workflow now checks it out (.github/workflows/rocm-wheels-build.yml:90).
  • NVTE_CK_JIT env-var gate (default on) selects the JIT build path in transformer_engine/common/ck_fused_attn/CMakeLists.txt:74; NVTE_CK_JIT_DIR and AITER_MHA_PATH provide overrides.
  • New ck_jit_prebuild shell helper in ci/_utils.sh:278 invokes lib/ck_jit/ck_jit_prebuild.py (shipped from the submodule into the install tree) against ci/ck_jit_prebuild.txt — a 531-line allowlist of fmha_* blob filenames used to warm the cache.
  • ci/jax.sh and ci/pytorch.sh call ck_jit_prebuild build before the test loop and ck_jit_prebuild list after, so per-config runs hit a populated cache; pytorch.sh also adds a check_flash_attn_installed guard so the flash config is skipped when FA isn't present.
  • setup.py:78 drops the old NVTE_AOTRITON_PATH / NVTE_CK_FUSED_ATTN_PATH cmake-flag forwarding; those overrides now live entirely inside the CMake script as environment variables.
  • aiter_prebuilt.cmake now returns the cache root (not …/lib) so the same variable can carry both libs and headers (transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake:51).

Walkthrough.

transformer_engine/common/ck_fused_attn/CMakeLists.txt — the install prefix is reframed: AITER_MHA_INSTALL_DIR is now a concrete path (${CMAKE_INSTALL_PREFIX}/transformer_engine/lib) rather than a cache string. A three-way decision tree picks the source of te_libmha_{fwd,bwd}.so and the QoLA headers: (1) $AITER_MHA_PATH user override, (2) the pre-built download cache (skipped when NVTE_CK_JIT is on), or (3) build from source. The source-build path forks again: if NVTE_CK_JIT is set, it invokes ${__CK_JIT_SOURCE_DIR}/ck_jit_build.py full --with-qola … (which drives QoLA internally and installs straight into AITER_MHA_INSTALL_DIR); otherwise the previous direct qola.cli build path runs and copies artifacts into the prebuilt cache. The final install(FILES …) is now conditional — when JIT builds directly into the install dir, the explicit copy is skipped to avoid a self-overwrite.
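
The selection order described above can be sketched in Python as follows. This is only an illustration of the decision tree, not the CMake code itself: the function name `pick_mha_source`, the `prebuilt_available` flag, and the returned labels are hypothetical, and the `NVTE_CK_JIT` test mirrors the `NOT ... STREQUAL "0"` check quoted in the inline review thread on this file.

```python
def pick_mha_source(env, prebuilt_available):
    """Return a label for where te_libmha_{fwd,bwd}.so would come from.

    env: mapping of environment variables (for example os.environ).
    prebuilt_available: assumed stand-in for "the prebuilt cache lookup
    succeeded"; the real lookup lives in aiter_prebuilt.cmake.
    """
    # An unset variable reads as "" in CMake, so anything except the
    # literal string "0" enables the JIT path.
    use_ck_jit = env.get("NVTE_CK_JIT", "") != "0"

    # (1) An explicit user override wins.
    if "AITER_MHA_PATH" in env:
        return "user-override"

    # (2) The prebuilt cache is only consulted when JIT is off.
    if not use_ck_jit and prebuilt_available:
        return "prebuilt-cache"

    # (3) Build from source, via ck_jit_build.py or the direct QoLA path.
    return "source-ck-jit" if use_ck_jit else "source-qola"
```

With no environment variables set, this sketch returns `"source-ck-jit"` even when a prebuilt cache exists, which is the default-on behavior the reviewers question below.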

ci/_utils.sh: configure_omp_threads is refactored: the CPU-count math is split out into a reusable get_cpu_count helper. The new ck_jit_prebuild function detects the GPU arch via rocminfo, resolves the installed transformer_engine package directory to locate lib/ck_jit/ck_jit_prebuild.py, picks --jobs $((cpus/2)), and runs either build (when called as ck_jit_prebuild build) or just cache (to report cache status). A PYTHON_TE_IMPORT snippet sanitises sys.path before importing TE, defending against an in-repo shadow when CI runs from the source tree.

ci/ck_jit_prebuild.txt — flat allowlist of 531 kernel artifact filenames (fmha_fwd_*, fmha_bwd_*, fmha_bwd_convert_dq_*, etc.) spanning bf16/fp16, head dims 32/64/128/192/256, batch/group, deterministic/non-deterministic, padded/non-padded variants, for gfx9 and gfx950. Acts as the static "what to pre-compile in CI" manifest consumed by ck_jit_prebuild.py.

ci/jax.sh, ci/pytorch.sh — bracket the existing per-attn-backend test loops with ck_jit_prebuild build (warm cache once for all configs) and ck_jit_prebuild list (print final cache state). In pytorch.sh the new check_flash_attn_installed runs the TE-installed FlashAttentionUtils.is_installed probe and continues past the flash iteration when false, replacing what was presumably an opaque test failure.

setup.py — drops the Python-side propagation of NVTE_AOTRITON_PATH / NVTE_CK_FUSED_ATTN_PATH into -DAOTRITON_PATH / -DAITER_MHA_PATH. The CMake script now reads $AITER_MHA_PATH directly from the environment, so the indirection is unneeded; CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT is kept (and its default coerced to a string).

Testing. No new tests. The only PR-side test-surface change is the CI shell wrappers (ck_jit_prebuild warmup, check_flash_attn_installed guard) in ci/jax.sh and ci/pytorch.sh. Verification depends entirely on the existing pytest suites picking up populated CK-JIT caches when those scripts run.

Notes for reviewers.

  • This is labeled a breaking change: the build-time contract has shifted. Anyone passing NVTE_AOTRITON_PATH or NVTE_CK_FUSED_ATTN_PATH to setup.py will no longer see those propagate — AITER_MHA_PATH is the supported env override now, and NVTE_CK_JIT=0 is needed to opt back into the prebuilt-cache path.
  • The aiter_prebuilt.cmake return-value change (CACHE_DIR vs CACHE_DIR/lib) is load-bearing for every call site — both the JIT and non-JIT branches now compute ${__AITER_CACHE_DIR}/lib and ${__AITER_CACHE_DIR}/include themselves; any out-of-tree consumer of get_prebuilt_aiter will need the same fix-up.
  • ck_jit_prebuild needs the wheel already installed (it imports transformer_engine to find the prebuild script), so call order in CI matters — install_prerequisites must run first, which it does in both jax.sh and pytorch.sh.
  • The 531-entry blob list contains a handful of suspicious-looking duplicates with random suffixes (e.g. …ndeterministic_gfx9.so.0PNcSK, .HqtomN, .PHbI9W) — worth confirming with the author that these are intentional artifacts of how the list was harvested and not stray tempfile names that snuck into the manifest.
  • The submodule URL points at a personal fork (github.com/ipanfilo/ck_jit); presumably a ROCm-org mirror is the eventual home.

Generated by Claude. To request a code review, comment /claude review.

Comment thread .gitmodules
url = https://github.com/Micky774/QoLA.git
[submodule "3rdparty/ck_jit"]
path = 3rdparty/ck_jit
url = https://github.com/ipanfilo/ck_jit.git

The ck_jit submodule points at a personal repo (github.com/ipanfilo/ck_jit.git). With CK_JIT being enabled by default (see transformer_engine/common/ck_fused_attn/CMakeLists.txt:73), this becomes a hard dependency on ipanfilo/ck_jit for every default ROCm build — including wheels published from rocm-wheels-build.yml. Recommend moving the repo under the ROCm org (or another organizational account) before this lands, so availability/ownership isn't tied to a single contributor account. The existing personal-repo submodules in this file (HaiShaw/minGPT, floraamd/nanoGPTwTE) are example-only and don't gate the build.

Comment thread ci/ck_jit_prebuild.txt
Comment on lines +47 to +49
fmha_bwd_convert_dq_d32_fp16_b64x0_batch_o2_npad_ndeterministic_gfx9.so.0PNcSK
fmha_bwd_convert_dq_d32_fp16_b64x0_batch_o2_npad_ndeterministic_gfx9.so.HqtomN
fmha_bwd_convert_dq_d32_fp16_b64x0_batch_o2_npad_ndeterministic_gfx9.so.PHbI9W

These three entries (and several more in the file) are clearly transient build-tool tempfiles, not real kernel libraries — the .so.<6-random-chars> suffix is what tools like mv -T / atomic-rename writers leave behind. They got captured into the blob list by accident and will never match an actual prebuild target.

All offending lines I see: 47, 48, 49, 75, 77, 323, 324, 325, 355, 358, 506. Regex \.so\.[A-Za-z0-9]{6}$ flags them all. Suggest regenerating the list with the temp files filtered out (e.g. only accept names matching \.so$).
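
A minimal sketch of the suggested filter, assuming the manifest is a plain list of one filename per line. The regex is the one quoted above; the helper name `split_manifest` is hypothetical and not part of ck_jit.

```python
import re

# Regex from the review comment: ".so." followed by exactly six
# alphanumeric characters at the end of the name.
TEMP_SUFFIX = re.compile(r"\.so\.[A-Za-z0-9]{6}$")


def split_manifest(lines):
    """Split manifest lines into (kept, dropped) filename lists."""
    kept, dropped = [], []
    for raw in lines:
        name = raw.strip()
        if not name:
            continue  # ignore blank lines
        if TEMP_SUFFIX.search(name):
            dropped.append(name)
        else:
            kept.append(name)
    return kept, dropped
```

Running it over ci/ck_jit_prebuild.txt and rewriting the file from `kept` would remove the flagged entries while leaving names that end in a plain `.so` untouched.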

Comment on lines +73 to +86
if(NOT "$ENV{NVTE_CK_JIT}" STREQUAL "0")
set(__USE_CK_JIT TRUE)
else()
set(__AITER_MHA_PATH "")
set(__USE_CK_JIT FALSE)
endif()
if(DEFINED ENV{AITER_MHA_PATH})
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=$ENV{AITER_MHA_PATH}")
# use pre-built libraries and includes from a location specified by the user
set(__AITER_CACHE_DIR $ENV{AITER_MHA_PATH})
elseif(NOT __USE_CK_JIT) #disable for CK_JIT for now
# use pre-built cache
include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")
get_prebuilt_aiter(__AITER_MHA_PATH)
get_prebuilt_aiter(__AITER_CACHE_DIR)
endif()

Two concerns about the default-on behavior:

  1. Default-on with a fail-closed signal. __USE_CK_JIT is TRUE whenever NVTE_CK_JIT is anything except literal "0" — unset, empty, "off", "false", etc. all enable CK_JIT. If you want default-on, fine, but the check NOT "$ENV{NVTE_CK_JIT}" STREQUAL "0" makes "opt out" only work for NVTE_CK_JIT=0. A clearer pattern is to use a CMake option() plus an env-var override so the choice surfaces in the configure summary and is documented.

  2. Default-on skips the prebuilt-cache path entirely. The elseif(NOT __USE_CK_JIT) on line 82 means that with the default settings (no env vars set) we never call get_prebuilt_aiter(), so the existing on-disk cache and NVTE_AITER_PREBUILT_BASE_URL download path are bypassed and we always rebuild via CK_JIT. That's a non-trivial CI/dev-loop regression vs. the previous default. If CK_JIT is meant to also honor that cache, this branch should still call into aiter_prebuilt.cmake first; if not, the PR description should call this out as an intentional behavior change. Also: NVTE_CK_JIT / NVTE_CK_JIT_DIR aren't documented anywhere — please add a short note (README or env-var table) since this is the new default.
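
To illustrate point 1, here is a hedged Python sketch of a stricter on/off parse in which unrecognized values are rejected instead of silently enabling the feature. It is not the CMake `option()` pattern suggested above, only an illustration of the parsing semantics; the helper name `parse_bool_env` and the accepted spellings are assumptions.

```python
_TRUE_VALUES = {"1", "on", "true", "yes"}
_FALSE_VALUES = {"0", "off", "false", "no"}


def parse_bool_env(name, default, env):
    """Parse env[name] as a boolean, falling back to default when unset.

    Raises ValueError for values that are neither clearly on nor
    clearly off, so a typo cannot silently pick a build path.
    """
    raw = env.get(name)
    if raw is None or raw.strip() == "":
        return default
    value = raw.strip().lower()
    if value in _TRUE_VALUES:
        return True
    if value in _FALSE_VALUES:
        return False
    raise ValueError(f"{name}={raw!r} is not a recognized on/off value")
```

Under this scheme `NVTE_CK_JIT=off` disables the JIT path, whereas the current string comparison treats it as enabled.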

Contributor

Ditto on point 2, I think we should still use the pre-built cache if available as first priority.

Comment thread setup.py

if rocm_build():
    cmake_flags.append("-DUSE_ROCM=ON")

This hunk silently drops the NVTE_AOTRITON_PATH -> -DAOTRITON_PATH=… and NVTE_CK_FUSED_ATTN_PATH -> -DAITER_MHA_PATH=… translations.

Downstream the CMake files now read $ENV{AOTRITON_PATH} and $ENV{AITER_MHA_PATH} directly (see aotriton/CMakeLists.txt:11 and ck_fused_attn/CMakeLists.txt:78), so users have to rename their env vars from NVTE_AOTRITON_PATH / NVTE_CK_FUSED_ATTN_PATH to the bare names. Anyone with existing scripts setting the NVTE_* form will get a silent "ignored" failure — the build will quietly fall back to the default behavior instead of using their prebuilt path.

PR is marked as breaking, but the description doesn't call this rename out. Two reasonable options:

  • keep the NVTE_* aliases here as a thin shim (a couple of lines) so old usage still works; or
  • explicitly document the env-var rename in the PR description / release notes.
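
The first option could look roughly like the sketch below. It assumes the CMake scripts read `AOTRITON_PATH` and `AITER_MHA_PATH` from the environment, as described above, and that the cmake subprocess inherits the environment that setup.py modifies; the helper name `apply_legacy_env_aliases` is hypothetical and does not exist in setup.py.

```python
import os

# Legacy name -> name the CMake scripts now read, per the comment above.
LEGACY_ENV_ALIASES = {
    "NVTE_AOTRITON_PATH": "AOTRITON_PATH",
    "NVTE_CK_FUSED_ATTN_PATH": "AITER_MHA_PATH",
}


def apply_legacy_env_aliases(env=os.environ):
    """Copy legacy NVTE_* values to the new names when those are unset.

    Returns the list of legacy names that were applied, so the caller
    can print a deprecation notice for each one.
    """
    applied = []
    for old_name, new_name in LEGACY_ENV_ALIASES.items():
        value = env.get(old_name)
        # Never override a value the user set under the new name.
        if value and new_name not in env:
            env[new_name] = value
            applied.append(old_name)
    return applied
```

Calling this once before the CMake invocation would keep existing scripts working while the rename is communicated.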

Contributor

Is this an intentional part of this change? Seems unrelated?

Comment thread ci/jax.sh
pip list | egrep "flax|fidle|jax|ml_dtypes|numpy|transformer_e|typing_ext"
#check_test_jobs_requested
#test $? -eq 0 && init_test_jobs `python -c "import jax; print(len([d for d in jax.devices() if 'rocm' in d.client.platform_version]))"`
ck_jit_prebuild build

ck_jit_prebuild can return non-zero from several paths (missing blob list, missing prebuild script, can't resolve TE install dir — see ci/_utils.sh:280-300), but neither this caller nor the matching one in ci/pytorch.sh:151 checks the return. A silent prebuild failure here means every test config that follows still runs, but JIT-compiles each kernel on the first call — looks like the CI succeeded while actually undoing the whole point of the prebuild step.

Suggest making this fail-fast, e.g. ck_jit_prebuild build || return 1 (or || exit 1 depending on shell context), so a broken prebuild surfaces as a CI failure rather than as a quiet slowdown.



set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")
set(AITER_MHA_INSTALL_DIR "${CMAKE_INSTALL_PREFIX}/transformer_engine/lib")

Minor: this replaces the previous AITER_MHA_INSTALL_PREFIX CACHE STRING with a hardcoded path. The old form let a downstream consumer override the install layout with -DAITER_MHA_INSTALL_PREFIX=…; the new form bakes transformer_engine/lib in. Probably no one was actually using the override, but even if you don't need the configurability, the CACHE STRING doc-string was worth keeping for grep-ability. Not blocking.

@github-actions

Review summary

Reviewed the CK JIT integration changes (10 files, ~666/-36 LoC; merge-base 6855218).

Scope: new 3rdparty/ck_jit submodule, CMake wiring for an optional CK_JIT path in ck_fused_attn, a CI prebuild helper (ck_jit_prebuild in ci/_utils.sh invoked from ci/jax.sh and ci/pytorch.sh), a 531-entry blob list in ci/ck_jit_prebuild.txt, and a setup.py cleanup that removes two env-var aliases.

Verdict: changes look generally on-track for adding CK_JIT as an additional build path, but a few items should be addressed before merge — see inline comments. Highlights:

  • 3rdparty/ck_jit submodule URL is a personal repo (ipanfilo/ck_jit); with CK_JIT default-on this becomes a hard build dependency.
  • ci/ck_jit_prebuild.txt contains ~11 entries that are clearly atomic-rename tempfiles (.so.<6 random chars>); they should be filtered out before checking in.
  • NVTE_CK_JIT default-on behavior in ck_fused_attn/CMakeLists.txt skips the existing prebuilt-cache / download path entirely, which is a behavior change worth either reverting or documenting. The env var itself is undocumented.
  • setup.py silently drops NVTE_AOTRITON_PATH / NVTE_CK_FUSED_ATTN_PATH aliasing — existing users get a quiet fallback unless they rename their env vars.
  • ck_jit_prebuild build return value is ignored in both CI driver scripts, so prebuild failures degrade silently into per-test JIT compiles.

Copyright headers: OK — all modified/added files carry an AMD copyright line with year ending in 2026; no NVIDIA copyright was altered.

This is the first Claude review on this PR; no prior Claude or human review comments to dedupe against.

Collaborator

@wangye805 wangye805 left a comment

LGTM. Actually, this is a very clever and non-trivial design. Do you think we can merge this into the CK repo?

Two small issues:
1) In multi-process scenarios such as PyTorch, your mktemp + mv -n cache-write pattern relies on rename(2) being atomic, which is only guaranteed within a single local filesystem. If $CK_JIT_CACHE_DIR ever lands on NFS or another shared FS, readers can dlopen a partially written .so. Worth documenting that the cache dir must be local.
2) Cache invalidation: stems are content-agnostic, so a CK/aiter source change without a template-signature change silently reuses the stale .so. Worth a README line telling users to run ck_jit_prebuild.py clean --all after any CK/aiter version bump. Or should ck_jit_prebuild.py clean run during each pip install?
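
For reference, the write-then-rename pattern discussed in the first point can be sketched in Python as below. This is not the ck_jit implementation, which is not shown in this PR; the function name `publish_atomically` is hypothetical, and unlike `mv -n` this version overwrites an existing destination.

```python
import os
import tempfile


def publish_atomically(data: bytes, dest: str) -> None:
    """Write data to a temp file next to dest, then rename it into place.

    The temp file is created in the destination directory so that the
    final rename stays within one filesystem.
    """
    dest_dir = os.path.dirname(os.path.abspath(dest))
    fd, tmp_path = tempfile.mkstemp(
        dir=dest_dir, prefix=os.path.basename(dest) + "."
    )
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(data)
            tmp.flush()
            os.fsync(tmp.fileno())  # push bytes to disk before publishing
        os.replace(tmp_path, dest)  # readers see the old file or the new one
    except BaseException:
        # Do not leave a stray temp file behind on failure.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
```

The guarantee that readers never observe a half-written file depends on the rename being atomic, which is the reason for the reviewer's request to keep the cache directory on a local filesystem.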

Contributor

@Micky774 Micky774 left a comment

Overall looks good, just a couple points

Comment thread setup.py

if rocm_build():
    cmake_flags.append("-DUSE_ROCM=ON")
Contributor

Is this an intentional part of this change? Seems unrelated?

Comment on lines +73 to +86
if(NOT "$ENV{NVTE_CK_JIT}" STREQUAL "0")
set(__USE_CK_JIT TRUE)
else()
set(__AITER_MHA_PATH "")
set(__USE_CK_JIT FALSE)
endif()
if(DEFINED ENV{AITER_MHA_PATH})
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=$ENV{AITER_MHA_PATH}")
# use pre-built libraries and includes from a location specified by the user
set(__AITER_CACHE_DIR $ENV{AITER_MHA_PATH})
elseif(NOT __USE_CK_JIT) #disable for CK_JIT for now
# use pre-built cache
include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")
get_prebuilt_aiter(__AITER_MHA_PATH)
get_prebuilt_aiter(__AITER_CACHE_DIR)
endif()
Contributor

Ditto on point 2, I think we should still use the pre-built cache if available as first priority.
