From d89f90fd30b0510862869404cdaa38a9b56a7e21 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 17:00:41 +0000 Subject: [PATCH 01/10] Updated QoLA (to port CK receipt patch) and TE manifest --- 3rdparty/QoLA | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index 549844d77..a597de03f 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit 549844d771ed3155dd75a6bf2c714cb3f710bada +Subproject commit a597de03f36bf4ea59fe3681675c45e24c441669 From 312212f03eff1ffdb23c80858ea6af3fe19c023f Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 17:04:34 +0000 Subject: [PATCH 02/10] Updated manifest --- transformer_engine/common/ck_fused_attn/qola_manifest.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index b6877c6c0..9bdce2e60 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "33f2e6af5f39379c739720080ed0033d533f5cb2" # pinned AITER submodule commit +aiter_commit = "8f816a049449f39609ee7daca8c21d63aa4274ed" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] From 89f6983c1a41f6547b6a243921f99f833375e6bf Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 17:53:17 +0000 Subject: [PATCH 03/10] Corrected AITER mha args validation against pinned commit --- .../common/ck_fused_attn/CMakeLists.txt | 33 +++++++++++++++++-- .../common/ck_fused_attn/aiter_prebuilt.cmake | 18 ++++++++-- .../ck_fused_attn/check_aiter_mha_args.py | 5 ++- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 2 ++ 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index b11e848dd..7ae3e3c7a 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -33,9 +33,37 @@ if(NOT Python_EXECUTABLE) find_package(Python COMPONENTS Interpreter QUIET) endif() +# Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the +# QoLA-nested AITER source tree to that commit before any consumer reads it +# (header validation below, header includes for the .cpp build later, and +# QoLA's own kernel build if the prebuilt cache misses). +include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") + if(Python_EXECUTABLE) + set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") + execute_process( + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" + ${Python_EXECUTABLE} -c + "from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r'${__AITER_SOURCE_DIR}', r'${AITER_SHA}')" + RESULT_VARIABLE AITER_CHECKOUT_RESULT + OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT + ERROR_VARIABLE AITER_CHECKOUT_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE + ) + if(NOT AITER_CHECKOUT_RESULT EQUAL 0) + message(FATAL_ERROR + "Failed to sync AITER source tree at ${__AITER_SOURCE_DIR} to " + "manifest-pinned commit ${AITER_SHA}.\n" + "${AITER_CHECKOUT_OUTPUT}\n${AITER_CHECKOUT_ERROR}") + endif() + message(STATUS "[AITER] Synced ${__AITER_SOURCE_DIR} to ${AITER_SHA}") + execute_process( - COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py --mode both --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py + --mode both + --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + --aiter-root "${__AITER_SOURCE_DIR}" RESULT_VARIABLE AITER_ARG_CHECK_RESULT OUTPUT_VARIABLE AITER_ARG_CHECK_OUTPUT ERROR_VARIABLE AITER_ARG_CHECK_ERROR @@ -50,7 +78,7 @@ if(Python_EXECUTABLE) endif() message(STATUS "AITER API validation passed via check_aiter_mha_args.py") else() - message(WARNING "Python interpreter not found; skipping AITER API validation.") + message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.") endif() # so far, there are only gfx942 and gfx950 v3 kernels @@ -78,7 +106,6 @@ if(DEFINED AITER_MHA_PATH) set(__AITER_MHA_PATH ${AITER_MHA_PATH}) else() set(__AITER_MHA_PATH "") - include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") get_prebuilt_aiter(__AITER_MHA_PATH) if(__AITER_MHA_PATH STREQUAL "") diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake index ea0396116..65ee3ec81 100644 --- a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake +++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake @@ -18,8 +18,22 @@ string(STRIP "${ROCM_VER_CONTENT}" ROCM_VER_CONTENT) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER_CONTENT}") string(REGEX MATCH "^[0-9]+" ROCM_VER_MAJOR "${ROCM_VER}") -# AITER commit -get_git_commit("${__AITER_SOURCE_DIR}" AITER_SHA) +# AITER commit — read from the QoLA manifest so the cache key tracks the +# commit QoLA will actually check out and build, not whatever happens to be +# the submodule's current HEAD at configure time. +set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") +set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${__QOLA_MANIFEST}") +file(STRINGS "${__QOLA_MANIFEST}" __AITER_COMMIT_LINES + REGEX "^[ \t]*aiter_commit[ \t]*=[ \t]*\"[^\"]+\"") +list(LENGTH __AITER_COMMIT_LINES __AITER_COMMIT_COUNT) +if(NOT __AITER_COMMIT_COUNT EQUAL 1) + message(FATAL_ERROR + "Expected exactly one 'aiter_commit = \"...\"' line in " + "${__QOLA_MANIFEST}, found ${__AITER_COMMIT_COUNT}.") +endif() +list(GET __AITER_COMMIT_LINES 0 __AITER_COMMIT_LINE) +string(REGEX MATCH "\"([^\"]+)\"" _UNUSED "${__AITER_COMMIT_LINE}") +set(AITER_SHA "${CMAKE_MATCH_1}") # Cache key & local paths set(AITER_CACHE_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../build/aiter-prebuilts") diff --git a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py index 2e9831f1a..2cce9484f 100644 --- a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py +++ b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py @@ -64,11 +64,14 @@ def main() -> int: parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both") parser.add_argument("--te-dir", type=Path, default=Path(__file__).parent.parent.parent.parent, help="Root directory of TransformerEngine") + parser.add_argument("--aiter-root", type=Path, default=None, + help="AITER source tree root. Defaults to /3rdparty/aiter.") args = parser.parse_args() + aiter_root = args.aiter_root if args.aiter_root else args.te_dir / "3rdparty/aiter" modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode] mismatch = 0 for mode in modes: - header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h" + header_path = aiter_root / f"csrc/include/mha_{mode}.h" source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp" header_text = header_path.read_text(encoding="utf-8") source_text = source_path.read_text(encoding="utf-8") diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 5f6af0a41..c68a6a3c7 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -527,6 +527,8 @@ hipError_t _ck_attn_bwd_impl( } aiter::mha_bwd_args fmha_args{}; + fmha_args.sink_ptr=nullptr; + fmha_args.d_sink_ptr=nullptr; fmha_args.mask_type = static_cast(mask_type); fmha_args.use_asm_v3 = uses_bwd_v3; fmha_args.v3_atomic_fp32 = is_v3_atomic_fp32; From c417e4070b9ba22c82da2493ea81e31b6638489e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 28 Apr 2026 20:53:27 +0000 Subject: [PATCH 04/10] Updated cmake w/ dubious ownership protection --- .../common/ck_fused_attn/CMakeLists.txt | 34 +++++++++++++------ .../common/ck_fused_attn/qola_manifest.toml | 2 ++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 7ae3e3c7a..5aea4db7d 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -41,10 +41,17 @@ include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") if(Python_EXECUTABLE) set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") + # Redirect GIT_CONFIG_GLOBAL to a tempfile carrying `safe.directory = *` so + # git operations inside the QoLA-nested AITER tree (and its recursive + # submodules) work in containerized builds where the bind-mounted .git is + # owned by a different UID than the build process. Mirrors the pattern in + # transformer_engine/common/CMakeLists.txt:get_git_commit(). execute_process( - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" - ${Python_EXECUTABLE} -c - "from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r'${__AITER_SOURCE_DIR}', r'${AITER_SHA}')" + COMMAND sh -c + "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ +GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -c 'from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r\"${__AITER_SOURCE_DIR}\", r\"${AITER_SHA}\")'; \ +rc=$?; rm -f \"$tmp\"; exit $rc" RESULT_VARIABLE AITER_CHECKOUT_RESULT OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT ERROR_VARIABLE AITER_CHECKOUT_ERROR @@ -114,13 +121,19 @@ else() set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") + # Same GIT_CONFIG_GLOBAL trick as the early `ensure_aiter_commit` call: + # qola.cli build re-invokes ensure_aiter_commit internally and will hit + # the same dubious-ownership trap without it. execute_process( - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" - ${Python_EXECUTABLE} -m qola.cli build - --manifest ${__QOLA_MANIFEST} - --aiter-root ${__AITER_SOURCE_DIR} - --output-dir ${__QOLA_BUILD_DIR} - --arch "${V3_ASM_ARCHS_STR}" + COMMAND sh -c + "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ +GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli build \ +--manifest '${__QOLA_MANIFEST}' \ +--aiter-root '${__AITER_SOURCE_DIR}' \ +--output-dir '${__QOLA_BUILD_DIR}' \ +--arch '${V3_ASM_ARCHS_STR}'; \ +rc=$?; rm -f \"$tmp\"; exit $rc" RESULT_VARIABLE QOLA_BUILD_RESULT ) if(NOT QOLA_BUILD_RESULT EQUAL 0) @@ -155,7 +168,8 @@ endforeach() add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES}) set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS - -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) + -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} + -DENABLE_CK=1) foreach(ARCH IN LISTS V3_ASM_ARCHS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS --offload-arch=${ARCH}) diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index 9bdce2e60..2b7e41537 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -9,9 +9,11 @@ architectures = ["gfx950", "gfx942"] [[modules]] name = "libmha_fwd" mode = "cpp_itfs" +receipt = 700 drop_srcs = ["mha_fwd_split.cu", "mha_fwd_batch_prefill.cu"] drop_directions = ["fwd_splitkv", "batch_prefill"] [[modules]] name = "libmha_bwd" mode = "cpp_itfs" +receipt = 700 From c7ecaf7a5209cb10cc5c423bc30529f222fca2dc Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Apr 2026 15:41:15 +0000 Subject: [PATCH 05/10] Corrected logging --- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 37 ++++++++----------- .../ck_fused_attn/src/ck_fused_attn_fwd.cpp | 17 +++------ 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 86ea82388..7e5a529b2 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -331,9 +331,8 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args){ +void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // fmha_traits debug @@ -447,14 +446,10 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ bool has_dbias = args.dbias_ptr != nullptr; bool is_mqa_gqa = (args.h > args.hg); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, log_file != nullptr}; bias_enum bias_type = bias_enum::no_bias; BiasShape bias_shape = BiasShape::k11SS; @@ -584,8 +579,8 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ // lse_workspace_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, nullptr, args.lse_workspace_ptr, stream); fmha_args.max_seqlen_k = get_runtime_max_seqlen(args.b, args.cu_seqlen_kv_ptr, nullptr, args.lse_workspace_ptr, stream); @@ -593,8 +588,8 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } // print ck traits and args when needed - if(ck_log_config){ - log_bwd_config(__FUNCTION__, fmha_args); + if(log_file){ + log_bwd_config(__FUNCTION__, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); if(average_runtime < 0){ @@ -612,7 +607,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.max_tokens_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -639,7 +634,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -662,7 +657,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -688,7 +683,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.b, args.s_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; @@ -713,7 +708,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "stride_b_dk_expanded: " << args.stride_b_dk_expanded << "\n"; @@ -734,7 +729,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; *log_file << "stride_b_dv_expanded: " << args.stride_b_dv_expanded << "\n"; @@ -764,7 +759,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 block(THREADS_PER_BLOCK); dim3 grid(ceil(1.0 * args.s_q * args.s_kv / THREADS_PER_BLOCK)); if(bias_shape==BiasShape::k11SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -776,7 +771,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -788,7 +783,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 0f407230c..0f4e9a424 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -15,9 +15,8 @@ namespace ck_fused_attn{ // print the fmha traits and fmha_args when calling ck apis -void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args){ +void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // debug fmha_traits @@ -103,11 +102,7 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ bool has_dropout = (args.is_training && args.dropout_probability > 0.f); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; @@ -218,16 +213,16 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, args.cu_seqlen_q_padded_ptr, args.lse_ptr, stream); } } // print ck traits and fmha_args when needed - if(ck_log_config){ - log_fwd_config(__FUNCTION__, has_dropout, fmha_args); + if(log_file){ + log_fwd_config(__FUNCTION__, has_dropout, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_fwd)(fmha_args, stream_config); if(average_runtime < 0){ From da6e9a63bfa79b379fabe7b525a38e73661ce00c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Apr 2026 21:01:46 +0000 Subject: [PATCH 06/10] Updated qola to build aiter w/ new third_party spec --- 3rdparty/QoLA | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index a597de03f..aac57fec6 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit a597de03f36bf4ea59fe3681675c45e24c441669 +Subproject commit aac57fec69b37a8b51922246a4497275987f9a68 From e7ed124f2e829a0321c2316f590ae6112f90012c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 29 Apr 2026 21:02:09 +0000 Subject: [PATCH 07/10] Added guards against AITER known buggy implementations --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 7e5a529b2..638e4b877 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -461,7 +461,18 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.sink_ptr = nullptr; fmha_args.d_sink_ptr = nullptr; fmha_args.mask_type = static_cast(static_cast(args.attn_mask_type)); - fmha_args.use_asm_v3 = args.uses_bwd_v3; + // Mirrors AITER's small-seqlen guard at aiter/ops/mha.py:1689. + const bool buggy_small_sq = (args.s_q < 16); + // Predicate matches exactly bwd_hd128_bf16_causal_br_a32_psskddv_group.co + // (broken by AITER PR #2189). Other psskddv_group variants are unaffected. + const bool buggy_br_psskddv_group = + args.is_group_mode() && + args.attn_mask_type == ck_fused_attn::MaskType::mask_bottom_right && + args.dtype == ck_fused_attn::DType::kBFloat16 && + args.d_qk == 128 && args.d_v == 128 && + args.is_v3_atomic_fp32; + fmha_args.use_asm_v3 = + (buggy_small_sq || buggy_br_psskddv_group) ? false : args.uses_bwd_v3; fmha_args.v3_atomic_fp32 = args.is_v3_atomic_fp32; fmha_args.v3_bf16_cvt = args.how_v3_bf16_cvt; fmha_args.v3_api_check = false; From 6241f99d2a6c34c14cec4e4169925451aaf08b1e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 12 May 2026 19:58:49 +0000 Subject: [PATCH 08/10] Updated build --- 3rdparty/QoLA | 2 +- .../common/ck_fused_attn/CMakeLists.txt | 56 ++++++++++--------- .../common/ck_fused_attn/qola_manifest.toml | 2 +- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/3rdparty/QoLA b/3rdparty/QoLA index aac57fec6..9c13e77ef 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit aac57fec69b37a8b51922246a4497275987f9a68 +Subproject commit 9c13e77ef3cf89053aad61ed3a0f27470f123ee5 diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 5aea4db7d..0f78311f2 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -8,41 +8,30 @@ project(ck_fused_attn LANGUAGES HIP CXX) set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") -set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA/3rdparty/aiter") +# QoLA no longer vendors AITER as a submodule; it clones on demand into +# build/third_party/aiter (git-ignored) via `qola checkout`. Mirror that +# default here so the source-build path and the header-include paths +# resolve to the same tree. +set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") +set(__AITER_SOURCE_DIR "${__QOLA_DIR}/build/third_party/aiter") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") - set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include") -message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") -if(NOT EXISTS "${CK_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find CK API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() - set(AITER_INCLUDE_DIR "${__AITER_SOURCE_DIR}/csrc/include") -message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") -if(NOT EXISTS "${AITER_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find AITER API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() if(NOT Python_EXECUTABLE) find_package(Python COMPONENTS Interpreter QUIET) endif() # Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the -# QoLA-nested AITER source tree to that commit before any consumer reads it +# QoLA-managed AITER source tree to that commit before any consumer reads it # (header validation below, header includes for the .cpp build later, and # QoLA's own kernel build if the prebuilt cache misses). include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") if(Python_EXECUTABLE) - set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") + set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") # Redirect GIT_CONFIG_GLOBAL to a tempfile carrying `safe.directory = *` so - # git operations inside the QoLA-nested AITER tree (and its recursive + # git operations inside the QoLA-managed AITER tree (and its recursive # submodules) work in containerized builds where the bind-mounted .git is # owned by a different UID than the build process. Mirrors the pattern in # transformer_engine/common/CMakeLists.txt:get_git_commit(). @@ -50,7 +39,9 @@ if(Python_EXECUTABLE) COMMAND sh -c "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ -GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -c 'from qola.build_tools.submodule import ensure_aiter_commit; ensure_aiter_commit(r\"${__AITER_SOURCE_DIR}\", r\"${AITER_SHA}\")'; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli checkout \ +--manifest '${__QOLA_MANIFEST}' \ +--aiter-root '${__AITER_SOURCE_DIR}'; \ rc=$?; rm -f \"$tmp\"; exit $rc" RESULT_VARIABLE AITER_CHECKOUT_RESULT OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT @@ -88,6 +79,23 @@ else() message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.") endif() +# Sanity-check the resolved include directories now that `qola checkout` has +# materialized the AITER tree. +message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") +if(NOT EXISTS "${CK_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find CK API at ${CK_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER and its " + "composable_kernel submodule.") +endif() + +message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") +if(NOT EXISTS "${AITER_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find AITER API at ${AITER_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER.") +endif() + # so far, there are only gfx942 and gfx950 v3 kernels SET(V3_ASM_ARCHS_SUPPORTED "gfx942;gfx950") @@ -118,11 +126,9 @@ else() if(__AITER_MHA_PATH STREQUAL "") # If not available, fallback: Build from source via QoLA message(STATUS "[AITER-BUILD] Building AITER kernels via QoLA.") - set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") - set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") - # Same GIT_CONFIG_GLOBAL trick as the early `ensure_aiter_commit` call: - # qola.cli build re-invokes ensure_aiter_commit internally and will hit + # Same GIT_CONFIG_GLOBAL trick as the earlier `qola.cli checkout` call: + # `qola.cli build` re-invokes ensure_aiter_commit internally and will hit # the same dubious-ownership trap without it. execute_process( COMMAND sh -c diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index 2b7e41537..255d11c18 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "8f816a049449f39609ee7daca8c21d63aa4274ed" # pinned AITER submodule commit +aiter_commit = "4b00d2ea91e88b5381ea7051521956a716485f30" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] From 82ebdb6c97a6adf1b0d1721bc4167329b3e9e926 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 19 May 2026 14:45:31 +0000 Subject: [PATCH 09/10] Drop guard for corrected bug --- .../common/ck_fused_attn/src/ck_fused_attn_bwd.cpp | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 638e4b877..0ff659949 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -462,17 +462,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.d_sink_ptr = nullptr; fmha_args.mask_type = static_cast(static_cast(args.attn_mask_type)); // Mirrors AITER's small-seqlen guard at aiter/ops/mha.py:1689. - const bool buggy_small_sq = (args.s_q < 16); - // Predicate matches exactly bwd_hd128_bf16_causal_br_a32_psskddv_group.co - // (broken by AITER PR #2189). Other psskddv_group variants are unaffected. - const bool buggy_br_psskddv_group = - args.is_group_mode() && - args.attn_mask_type == ck_fused_attn::MaskType::mask_bottom_right && - args.dtype == ck_fused_attn::DType::kBFloat16 && - args.d_qk == 128 && args.d_v == 128 && - args.is_v3_atomic_fp32; - fmha_args.use_asm_v3 = - (buggy_small_sq || buggy_br_psskddv_group) ? false : args.uses_bwd_v3; + fmha_args.use_asm_v3 = (args.s_q < 16) ? false : args.uses_bwd_v3; fmha_args.v3_atomic_fp32 = args.is_v3_atomic_fp32; fmha_args.v3_bf16_cvt = args.how_v3_bf16_cvt; fmha_args.v3_api_check = false; From f9ab59c28fc8222859a6bb8d9ea9859cf49742ac Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 28 May 2026 18:29:26 +0000 Subject: [PATCH 10/10] Update AITER commit, adopt new API --- .../ck_fused_attn/check_aiter_mha_args.py | 2 +- .../include/ck_fused_attn/ck_fused_attn.hpp | 1 - .../common/ck_fused_attn/qola_manifest.toml | 2 +- .../ck_fused_attn/src/ck_fused_attn_bwd.cpp | 67 +++++++++++++++---- .../common/fused_attn_rocm/fused_attn_ck.cpp | 9 --- 5 files changed, 57 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py index 2cce9484f..6bae3091d 100644 --- a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py +++ b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py @@ -31,7 +31,7 @@ def parse_with_skip_comments(buffer, line, regex, outputs): def extract_fields_from_header(text: str, struct_name: str) -> List[str]: - struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") + struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*|\{[^;]*\})?;\s*$") struct_end_re = re.compile(r"^\s*};\s*$") struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 127d75b4c..abd1ec371 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -113,7 +113,6 @@ struct CkAttnBwdArgs : CKAttnCommonArgs { // dQ void* dq_ptr = nullptr; uint64_t stride_b_dq = 0, stride_h_dq = 0, stride_s_dq = 0; - void* dq_acc_ptr = nullptr; // dK / dV expanded (MQA/GQA reduction inputs; null when h==hg) void* dk_expanded_ptr = nullptr; diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index 255d11c18..2b445ae08 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "4b00d2ea91e88b5381ea7051521956a716485f30" # pinned AITER submodule commit +aiter_commit = "e3940660b40f4764cdf09147af96a2a764f264be" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 0ff659949..145d9b139 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -6,9 +6,13 @@ #include #include +#include #include #include +#include +#include #include "ck_fused_attn/ck_fused_attn.hpp" +#include "ck_tile/host/pinned_host_releaser.hpp" #include "qola_mha_bwd.h" #include "ck_fused_attn_utils.hpp" @@ -364,7 +368,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "dk_ptr", fmha_args.dk_ptr); log_value(log_file, "dv_ptr", fmha_args.dv_ptr); log_value(log_file, "dbias_ptr", fmha_args.dbias_ptr); - log_value(log_file, "dq_acc_ptr", fmha_args.dq_acc_ptr); log_value(log_file, "seqstart_q_ptr", fmha_args.seqstart_q_ptr); log_value(log_file, "seqstart_k_ptr", fmha_args.seqstart_k_ptr); @@ -389,7 +392,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "stride_o", fmha_args.stride_o); log_value(log_file, "stride_randval", fmha_args.stride_randval); log_value(log_file, "stride_do", fmha_args.stride_do); - log_value(log_file, "stride_dq_acc", fmha_args.stride_dq_acc); log_value(log_file, "stride_dq", fmha_args.stride_dq); log_value(log_file, "stride_dk", fmha_args.stride_dk); log_value(log_file, "stride_dv", fmha_args.stride_dv); @@ -402,7 +404,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "nhead_stride_randval", fmha_args.nhead_stride_randval); log_value(log_file, "nhead_stride_do", fmha_args.nhead_stride_do); log_value(log_file, "nhead_stride_lsed", fmha_args.nhead_stride_lsed); - log_value(log_file, "nhead_stride_dq_acc", fmha_args.nhead_stride_dq_acc); log_value(log_file, "nhead_stride_dq", fmha_args.nhead_stride_dq); log_value(log_file, "nhead_stride_dk", fmha_args.nhead_stride_dk); log_value(log_file, "nhead_stride_dv", fmha_args.nhead_stride_dv); @@ -415,7 +416,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, log_value(log_file, "batch_stride_randval", fmha_args.batch_stride_randval); log_value(log_file, "batch_stride_do", fmha_args.batch_stride_do); log_value(log_file, "batch_stride_lsed", fmha_args.batch_stride_lsed); - log_value(log_file, "batch_stride_dq_acc", fmha_args.batch_stride_dq_acc); log_value(log_file, "batch_stride_dq", fmha_args.batch_stride_dq); log_value(log_file, "batch_stride_dk", fmha_args.batch_stride_dk); log_value(log_file, "batch_stride_dv", fmha_args.batch_stride_dv); @@ -493,7 +493,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.dbias_ptr = ((!args.is_group_mode()) && has_dbias) ? (bias_shape==BiasShape::kBHSS ? args.dbias_ptr : args.dbias_expanded_ptr) : nullptr; - fmha_args.dq_acc_ptr = args.dq_acc_ptr; if (args.is_group_mode()) { fmha_args.seqstart_q_ptr = args.cu_seqlen_q_padded_ptr==nullptr? args.cu_seqlen_q_ptr : args.cu_seqlen_q_padded_ptr; @@ -509,8 +508,13 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.seqlen_q_ptr = nullptr; fmha_args.seqlen_k_ptr = nullptr; - fmha_args.seqlen_q = args.s_q; - fmha_args.seqlen_k = args.s_kv; + // Group mode contract (matches aiter asm_mha_varlen_bwd.cu): seqlen_q/k + // carry the total token counts, max_seqlen_q/k the per-sequence maximum. + // aiter sizes dq_acc and related workspaces from seqlen_q; passing the + // per-sequence length in group mode under-sizes them and the kernel writes + // past the end. + fmha_args.seqlen_q = args.is_group_mode() ? args.max_tokens_q : args.s_q; + fmha_args.seqlen_k = args.is_group_mode() ? args.max_tokens_kv : args.s_kv; fmha_args.batch = args.b; fmha_args.max_seqlen_q = args.s_q; fmha_args.max_seqlen_k = args.s_kv; @@ -527,8 +531,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.stride_o = args.stride_s_o; fmha_args.stride_randval = args.s_kv; fmha_args.stride_do = args.stride_s_do; - //dq_acc of shape (nsplits, B, H, S, D) - fmha_args.stride_dq_acc = args.d_qk; fmha_args.stride_dq = args.stride_s_dq; fmha_args.stride_dk = is_mqa_gqa? args.stride_s_dk_expanded : args.stride_s_dk; fmha_args.stride_dv = is_mqa_gqa? args.stride_s_dv_expanded : args.stride_s_dv; @@ -546,7 +548,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.nhead_stride_randval = args.is_group_mode() ? 0 : args.s_q * args.s_kv; fmha_args.nhead_stride_do = args.stride_h_do; fmha_args.nhead_stride_lsed = args.is_group_mode() ? args.max_tokens_q : args.s_q; - fmha_args.nhead_stride_dq_acc = static_cast((args.is_group_mode() ? args.max_tokens_q : args.s_q) * args.d_qk); fmha_args.nhead_stride_dq = args.stride_h_dq; fmha_args.nhead_stride_dk = is_mqa_gqa? args.stride_h_dk_expanded : args.stride_h_dk; fmha_args.nhead_stride_dv = is_mqa_gqa? args.stride_h_dv_expanded : args.stride_h_dv; @@ -562,13 +563,11 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.batch_stride_randval = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; fmha_args.batch_stride_do = args.is_group_mode() ? 0 : args.stride_b_do; fmha_args.batch_stride_lsed = args.is_group_mode() ? 0 : args.h * args.s_q; - fmha_args.batch_stride_dq_acc = args.is_group_mode() ? 0 : static_cast(args.h * args.s_q * args.d_qk); fmha_args.batch_stride_dq = args.is_group_mode() ? 0 : args.stride_b_dq; fmha_args.batch_stride_dk = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dk_expanded : args.stride_b_dk); fmha_args.batch_stride_dv = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dv_expanded : args.stride_b_dv); // for dbias, use h since h can be different from bias_h fmha_args.batch_stride_dbias = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; - fmha_args.split_stride_dq_acc = static_cast(args.is_group_mode() ? (args.max_tokens_q * args.h * args.d_qk) : (args.b * args.h * args.s_q * args.d_qk)); fmha_args.window_size_left = args.window_size_left; fmha_args.window_size_right = args.window_size_right; @@ -588,11 +587,55 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } } + // Device-side workspace allocations made inside mha_bwd (launcher metadata + // and the dq_acc accumulator). aiter only contracts that the pointer remain + // valid for the duration of the kernels it enqueues; hipFreeAsync on the + // same stream defers the free until that work completes. + std::vector mha_bwd_workspaces; + fmha_args.workspace_alloc = [&mha_bwd_workspaces, stream](size_t bytes, bool zero_init) -> void* { + if(bytes == 0){ + return nullptr; + } + void* ptr = nullptr; + if(hipMallocAsync(&ptr, bytes, stream) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipMallocAsync failed for AITER workspace."); + } + if(zero_init){ + if(hipMemsetAsync(ptr, 0, bytes, stream) != hipSuccess){ + hipFreeAsync(ptr, stream); + throw std::runtime_error("ck_fused_attn bwd: hipMemsetAsync failed for AITER workspace."); + } + } + mha_bwd_workspaces.push_back(ptr); + return ptr; + }; + // Group mode requires a pinned host buffer for the async D2H seqstart + // pipeline; aiter keeps the shared_ptr alive past kernel completion via a + // stream-tail hipLaunchHostFunc keepalive. The deleter fires from that HIP + // callback thread, which holds runtime locks — calling any HIP API from it + // (including hipHostFree) deadlocks against concurrent main-thread HIP + // calls. Defer the free to ck_tile::pinned_host_releaser's worker thread. + fmha_args.pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + if(bytes == 0){ + return {}; + } + void* ptr = nullptr; + if(hipHostMalloc(&ptr, bytes, hipHostMallocDefault) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipHostMalloc failed for AITER pinned host buffer."); + } + return std::shared_ptr(ptr, [](void* p){ + ck_tile::pinned_host_releaser::instance().enqueue(p); + }); + }; + // print ck traits and args when needed if(log_file){ log_bwd_config(__FUNCTION__, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); + for(void* ws_ptr : mha_bwd_workspaces){ + hipFreeAsync(ws_ptr, stream); + } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 744d0575a..7586a8388 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -743,9 +743,6 @@ void fused_attn_ck_bwd_impl( bool is_mqa_gqa = (h > hg); - size_t kN0 = (d_qk <= 128)? 128:64; - size_t nsplits = deterministic? ceil(1.0*s_kv/kN0):1; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout); bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; @@ -770,9 +767,6 @@ void fused_attn_ck_bwd_impl( // First h*max_tokens_q*sizeof(float) is the lse-d buffer (passed as softmax_lsed) void* lse_workspace = planner.allocate(h*max_tokens_q*sizeof(float)); - // CK requires dq_acc ptr; size depends on deterministic mode - void* dq_acc_ptr = planner.allocate(nsplits*h*max_tokens_q*d_qk*sizeof(float)); - void* dk_expanded_ptr = nullptr; void* dv_expanded_ptr = nullptr; std::array dk_expanded_stride; @@ -913,8 +907,6 @@ void fused_attn_ck_bwd_impl( } // Initialize workspace buffers. - // dq_acc is of shape (nsplits, B, S, H, D_qk); CK requires zeroing - NVTE_CHECK_CUDA(cudaMemsetAsync(dq_acc_ptr, 0, sizeof(float)*nsplits*h*max_tokens_q*d_qk, stream)); if(devPtrAlibiSlope){ dim3 block, grid; block.x = 1024; @@ -992,7 +984,6 @@ void fused_attn_ck_bwd_impl( ck_args.attn_mask_type = set_ck_mask(mask_type, window_size_left, window_size_right); ck_args.window_size_left = window_size_left; ck_args.window_size_right = window_size_right; - ck_args.dq_acc_ptr = dq_acc_ptr; ck_args.dk_expanded_ptr = dk_expanded_ptr; ck_args.dv_expanded_ptr = dv_expanded_ptr; ck_args.lse_workspace_ptr = lse_workspace;