diff --git a/3rdparty/QoLA b/3rdparty/QoLA index 549844d77..9c13e77ef 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit 549844d771ed3155dd75a6bf2c714cb3f710bada +Subproject commit 9c13e77ef3cf89053aad61ed3a0f27470f123ee5 diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 4640374aa..09329ddeb 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -22,34 +22,56 @@ if(NOT _gfx1250_idx EQUAL -1) endif() set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES}) endif() -set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA/3rdparty/aiter") +set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") +set(__AITER_SOURCE_DIR "${__QOLA_DIR}/build/third_party/aiter") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") - set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include") -message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") -if(NOT EXISTS "${CK_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find CK API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() - set(AITER_INCLUDE_DIR "${__AITER_SOURCE_DIR}/csrc/include") -message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") -if(NOT EXISTS "${AITER_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find AITER API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() if(NOT Python_EXECUTABLE) find_package(Python COMPONENTS Interpreter QUIET) endif() +# Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the +# QoLA-managed AITER source tree to that commit before any consumer reads it +# (header validation below, header includes for the .cpp build later, and +# QoLA's own kernel build if the prebuilt cache misses). +include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") + if(Python_EXECUTABLE) + set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") + # Redirect GIT_CONFIG_GLOBAL to a tempfile carrying `safe.directory = *` so + # git operations inside the QoLA-managed AITER tree (and its recursive + # submodules) work in containerized builds where the bind-mounted .git is + # owned by a different UID than the build process. Mirrors the pattern in + # transformer_engine/common/CMakeLists.txt:get_git_commit(). + execute_process( + COMMAND sh -c + "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ +GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli checkout \ +--manifest '${__QOLA_MANIFEST}' \ +--aiter-root '${__AITER_SOURCE_DIR}'; \ +rc=$?; rm -f \"$tmp\"; exit $rc" + RESULT_VARIABLE AITER_CHECKOUT_RESULT + OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT + ERROR_VARIABLE AITER_CHECKOUT_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE + ) + if(NOT AITER_CHECKOUT_RESULT EQUAL 0) + message(FATAL_ERROR + "Failed to sync AITER source tree at ${__AITER_SOURCE_DIR} to " + "manifest-pinned commit ${AITER_SHA}.\n" + "${AITER_CHECKOUT_OUTPUT}\n${AITER_CHECKOUT_ERROR}") + endif() + message(STATUS "[AITER] Synced ${__AITER_SOURCE_DIR} to ${AITER_SHA}") + execute_process( - COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py --mode both --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py + --mode both + --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + --aiter-root "${__AITER_SOURCE_DIR}" RESULT_VARIABLE AITER_ARG_CHECK_RESULT OUTPUT_VARIABLE AITER_ARG_CHECK_OUTPUT ERROR_VARIABLE AITER_ARG_CHECK_ERROR @@ -64,7 +86,24 @@ if(Python_EXECUTABLE) endif() message(STATUS "AITER API validation passed via check_aiter_mha_args.py") else() - message(WARNING "Python interpreter not found; skipping AITER API validation.") + message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.") +endif() + +# Sanity-check the resolved include directories now that `qola checkout` has +# materialized the AITER tree. +message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") +if(NOT EXISTS "${CK_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find CK API at ${CK_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER and its " + "composable_kernel submodule.") +endif() + +message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") +if(NOT EXISTS "${AITER_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find AITER API at ${AITER_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER.") endif() if(DEFINED AITER_MHA_PATH) @@ -73,16 +112,16 @@ if(DEFINED AITER_MHA_PATH) set(__AITER_MHA_PATH ${AITER_MHA_PATH}) else() set(__AITER_MHA_PATH "") - include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") get_prebuilt_aiter(__AITER_MHA_PATH) if(__AITER_MHA_PATH STREQUAL "") # If not available, fallback: Build from source via QoLA list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR) message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.") - set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") - set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") + # Same GIT_CONFIG_GLOBAL trick as the earlier `qola.cli checkout` call: + # `qola.cli build` re-invokes ensure_aiter_commit internally and will hit + # the same dubious-ownership trap without it. execute_process( COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" ${Python_EXECUTABLE} -m qola.cli build @@ -124,7 +163,8 @@ endforeach() add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES}) set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS - -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) + -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} + -DENABLE_CK=1) # Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include # (emitted by qola.cli build, or copied from the QoLA build dir above for the diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake index ea0396116..65ee3ec81 100644 --- a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake +++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake @@ -18,8 +18,22 @@ string(STRIP "${ROCM_VER_CONTENT}" ROCM_VER_CONTENT) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER_CONTENT}") string(REGEX MATCH "^[0-9]+" ROCM_VER_MAJOR "${ROCM_VER}") -# AITER commit -get_git_commit("${__AITER_SOURCE_DIR}" AITER_SHA) +# AITER commit — read from the QoLA manifest so the cache key tracks the +# commit QoLA will actually check out and build, not whatever happens to be +# the submodule's current HEAD at configure time. +set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") +set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${__QOLA_MANIFEST}") +file(STRINGS "${__QOLA_MANIFEST}" __AITER_COMMIT_LINES + REGEX "^[ \t]*aiter_commit[ \t]*=[ \t]*\"[^\"]+\"") +list(LENGTH __AITER_COMMIT_LINES __AITER_COMMIT_COUNT) +if(NOT __AITER_COMMIT_COUNT EQUAL 1) + message(FATAL_ERROR + "Expected exactly one 'aiter_commit = \"...\"' line in " + "${__QOLA_MANIFEST}, found ${__AITER_COMMIT_COUNT}.") +endif() +list(GET __AITER_COMMIT_LINES 0 __AITER_COMMIT_LINE) +string(REGEX MATCH "\"([^\"]+)\"" _UNUSED "${__AITER_COMMIT_LINE}") +set(AITER_SHA "${CMAKE_MATCH_1}") # Cache key & local paths set(AITER_CACHE_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../build/aiter-prebuilts") diff --git a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py index 2e9831f1a..6bae3091d 100644 --- a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py +++ b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py @@ -31,7 +31,7 @@ def parse_with_skip_comments(buffer, line, regex, outputs): def extract_fields_from_header(text: str, struct_name: str) -> List[str]: - struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") + struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*|\{[^;]*\})?;\s*$") struct_end_re = re.compile(r"^\s*};\s*$") struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") @@ -64,11 +64,14 @@ def main() -> int: parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both") parser.add_argument("--te-dir", type=Path, default=Path(__file__).parent.parent.parent.parent, help="Root directory of TransformerEngine") + parser.add_argument("--aiter-root", type=Path, default=None, + help="AITER source tree root. Defaults to /3rdparty/aiter.") args = parser.parse_args() + aiter_root = args.aiter_root if args.aiter_root else args.te_dir / "3rdparty/aiter" modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode] mismatch = 0 for mode in modes: - header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h" + header_path = aiter_root / f"csrc/include/mha_{mode}.h" source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp" header_text = header_path.read_text(encoding="utf-8") source_text = source_path.read_text(encoding="utf-8") diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 127d75b4c..abd1ec371 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -113,7 +113,6 @@ struct CkAttnBwdArgs : CKAttnCommonArgs { // dQ void* dq_ptr = nullptr; uint64_t stride_b_dq = 0, stride_h_dq = 0, stride_s_dq = 0; - void* dq_acc_ptr = nullptr; // dK / dV expanded (MQA/GQA reduction inputs; null when h==hg) void* dk_expanded_ptr = nullptr; diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index b6877c6c0..2b445ae08 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "33f2e6af5f39379c739720080ed0033d533f5cb2" # pinned AITER submodule commit +aiter_commit = "e3940660b40f4764cdf09147af96a2a764f264be" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] @@ -9,9 +9,11 @@ architectures = ["gfx950", "gfx942"] [[modules]] name = "libmha_fwd" mode = "cpp_itfs" +receipt = 700 drop_srcs = ["mha_fwd_split.cu", "mha_fwd_batch_prefill.cu"] drop_directions = ["fwd_splitkv", "batch_prefill"] [[modules]] name = "libmha_bwd" mode = "cpp_itfs" +receipt = 700 diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index b0ed9ae6e..145d9b139 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -6,9 +6,13 @@ #include #include +#include #include #include +#include +#include #include "ck_fused_attn/ck_fused_attn.hpp" +#include "ck_tile/host/pinned_host_releaser.hpp" #include "qola_mha_bwd.h" #include "ck_fused_attn_utils.hpp" @@ -331,9 +335,8 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args){ +void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // fmha_traits debug @@ -365,7 +368,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "dk_ptr", fmha_args.dk_ptr); log_value(log_file, "dv_ptr", fmha_args.dv_ptr); log_value(log_file, "dbias_ptr", fmha_args.dbias_ptr); - log_value(log_file, "dq_acc_ptr", fmha_args.dq_acc_ptr); log_value(log_file, "seqstart_q_ptr", fmha_args.seqstart_q_ptr); log_value(log_file, "seqstart_k_ptr", fmha_args.seqstart_k_ptr); @@ -390,7 +392,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "stride_o", fmha_args.stride_o); log_value(log_file, "stride_randval", fmha_args.stride_randval); log_value(log_file, "stride_do", fmha_args.stride_do); - log_value(log_file, "stride_dq_acc", fmha_args.stride_dq_acc); log_value(log_file, "stride_dq", fmha_args.stride_dq); log_value(log_file, "stride_dk", fmha_args.stride_dk); log_value(log_file, "stride_dv", fmha_args.stride_dv); @@ -403,7 +404,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "nhead_stride_randval", fmha_args.nhead_stride_randval); log_value(log_file, "nhead_stride_do", fmha_args.nhead_stride_do); log_value(log_file, "nhead_stride_lsed", fmha_args.nhead_stride_lsed); - log_value(log_file, "nhead_stride_dq_acc", fmha_args.nhead_stride_dq_acc); log_value(log_file, "nhead_stride_dq", fmha_args.nhead_stride_dq); log_value(log_file, "nhead_stride_dk", fmha_args.nhead_stride_dk); log_value(log_file, "nhead_stride_dv", fmha_args.nhead_stride_dv); @@ -416,7 +416,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "batch_stride_randval", fmha_args.batch_stride_randval); log_value(log_file, "batch_stride_do", fmha_args.batch_stride_do); log_value(log_file, "batch_stride_lsed", fmha_args.batch_stride_lsed); - log_value(log_file, "batch_stride_dq_acc", fmha_args.batch_stride_dq_acc); log_value(log_file, "batch_stride_dq", fmha_args.batch_stride_dq); log_value(log_file, "batch_stride_dk", fmha_args.batch_stride_dk); log_value(log_file, "batch_stride_dv", fmha_args.batch_stride_dv); @@ -447,14 +446,10 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ bool has_dbias = args.dbias_ptr != nullptr; bool is_mqa_gqa = (args.h > args.hg); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, log_file != nullptr}; bias_enum bias_type = bias_enum::no_bias; BiasShape bias_shape = BiasShape::k11SS; @@ -463,8 +458,11 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } aiter::mha_bwd_args fmha_args{}; + fmha_args.sink_ptr = nullptr; + fmha_args.d_sink_ptr = nullptr; fmha_args.mask_type = static_cast(static_cast(args.attn_mask_type)); - fmha_args.use_asm_v3 = args.uses_bwd_v3; + // Mirrors AITER's small-seqlen guard at aiter/ops/mha.py:1689. + fmha_args.use_asm_v3 = (args.s_q < 16) ? false : args.uses_bwd_v3; fmha_args.v3_atomic_fp32 = args.is_v3_atomic_fp32; fmha_args.v3_bf16_cvt = args.how_v3_bf16_cvt; fmha_args.v3_api_check = false; @@ -495,7 +493,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.dbias_ptr = ((!args.is_group_mode()) && has_dbias) ? (bias_shape==BiasShape::kBHSS ? args.dbias_ptr : args.dbias_expanded_ptr) : nullptr; - fmha_args.dq_acc_ptr = args.dq_acc_ptr; if (args.is_group_mode()) { fmha_args.seqstart_q_ptr = args.cu_seqlen_q_padded_ptr==nullptr? args.cu_seqlen_q_ptr : args.cu_seqlen_q_padded_ptr; @@ -511,8 +508,13 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.seqlen_q_ptr = nullptr; fmha_args.seqlen_k_ptr = nullptr; - fmha_args.seqlen_q = args.s_q; - fmha_args.seqlen_k = args.s_kv; + // Group mode contract (matches aiter asm_mha_varlen_bwd.cu): seqlen_q/k + // carry the total token counts, max_seqlen_q/k the per-sequence maximum. + // aiter sizes dq_acc and related workspaces from seqlen_q; passing the + // per-sequence length in group mode under-sizes them and the kernel writes + // past the end. + fmha_args.seqlen_q = args.is_group_mode() ? args.max_tokens_q : args.s_q; + fmha_args.seqlen_k = args.is_group_mode() ? args.max_tokens_kv : args.s_kv; fmha_args.batch = args.b; fmha_args.max_seqlen_q = args.s_q; fmha_args.max_seqlen_k = args.s_kv; @@ -529,8 +531,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.stride_o = args.stride_s_o; fmha_args.stride_randval = args.s_kv; fmha_args.stride_do = args.stride_s_do; - //dq_acc of shape (nsplits, B, H, S, D) - fmha_args.stride_dq_acc = args.d_qk; fmha_args.stride_dq = args.stride_s_dq; fmha_args.stride_dk = is_mqa_gqa? args.stride_s_dk_expanded : args.stride_s_dk; fmha_args.stride_dv = is_mqa_gqa? args.stride_s_dv_expanded : args.stride_s_dv; @@ -548,7 +548,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.nhead_stride_randval = args.is_group_mode() ? 0 : args.s_q * args.s_kv; fmha_args.nhead_stride_do = args.stride_h_do; fmha_args.nhead_stride_lsed = args.is_group_mode() ? args.max_tokens_q : args.s_q; - fmha_args.nhead_stride_dq_acc = static_cast((args.is_group_mode() ? args.max_tokens_q : args.s_q) * args.d_qk); fmha_args.nhead_stride_dq = args.stride_h_dq; fmha_args.nhead_stride_dk = is_mqa_gqa? args.stride_h_dk_expanded : args.stride_h_dk; fmha_args.nhead_stride_dv = is_mqa_gqa? args.stride_h_dv_expanded : args.stride_h_dv; @@ -564,13 +563,11 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.batch_stride_randval = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; fmha_args.batch_stride_do = args.is_group_mode() ? 0 : args.stride_b_do; fmha_args.batch_stride_lsed = args.is_group_mode() ? 0 : args.h * args.s_q; - fmha_args.batch_stride_dq_acc = args.is_group_mode() ? 0 : static_cast(args.h * args.s_q * args.d_qk); fmha_args.batch_stride_dq = args.is_group_mode() ? 0 : args.stride_b_dq; fmha_args.batch_stride_dk = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dk_expanded : args.stride_b_dk); fmha_args.batch_stride_dv = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dv_expanded : args.stride_b_dv); // for dbias, use h since h can be different from bias_h fmha_args.batch_stride_dbias = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; - fmha_args.split_stride_dq_acc = static_cast(args.is_group_mode() ? (args.max_tokens_q * args.h * args.d_qk) : (args.b * args.h * args.s_q * args.d_qk)); fmha_args.window_size_left = args.window_size_left; fmha_args.window_size_right = args.window_size_right; @@ -582,19 +579,63 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ // lse_workspace_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, nullptr, args.lse_workspace_ptr, stream); fmha_args.max_seqlen_k = get_runtime_max_seqlen(args.b, args.cu_seqlen_kv_ptr, nullptr, args.lse_workspace_ptr, stream); } } + // Device-side workspace allocations made inside mha_bwd (launcher metadata + // and the dq_acc accumulator). aiter only contracts that the pointer remain + // valid for the duration of the kernels it enqueues; hipFreeAsync on the + // same stream defers the free until that work completes. + std::vector mha_bwd_workspaces; + fmha_args.workspace_alloc = [&mha_bwd_workspaces, stream](size_t bytes, bool zero_init) -> void* { + if(bytes == 0){ + return nullptr; + } + void* ptr = nullptr; + if(hipMallocAsync(&ptr, bytes, stream) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipMallocAsync failed for AITER workspace."); + } + if(zero_init){ + if(hipMemsetAsync(ptr, 0, bytes, stream) != hipSuccess){ + hipFreeAsync(ptr, stream); + throw std::runtime_error("ck_fused_attn bwd: hipMemsetAsync failed for AITER workspace."); + } + } + mha_bwd_workspaces.push_back(ptr); + return ptr; + }; + // Group mode requires a pinned host buffer for the async D2H seqstart + // pipeline; aiter keeps the shared_ptr alive past kernel completion via a + // stream-tail hipLaunchHostFunc keepalive. The deleter fires from that HIP + // callback thread, which holds runtime locks — calling any HIP API from it + // (including hipHostFree) deadlocks against concurrent main-thread HIP + // calls. Defer the free to ck_tile::pinned_host_releaser's worker thread. + fmha_args.pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + if(bytes == 0){ + return {}; + } + void* ptr = nullptr; + if(hipHostMalloc(&ptr, bytes, hipHostMallocDefault) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipHostMalloc failed for AITER pinned host buffer."); + } + return std::shared_ptr(ptr, [](void* p){ + ck_tile::pinned_host_releaser::instance().enqueue(p); + }); + }; + // print ck traits and args when needed - if(ck_log_config){ - log_bwd_config(__FUNCTION__, fmha_args); + if(log_file){ + log_bwd_config(__FUNCTION__, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); + for(void* ws_ptr : mha_bwd_workspaces){ + hipFreeAsync(ws_ptr, stream); + } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); @@ -610,7 +651,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.max_tokens_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -637,7 +678,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -660,7 +701,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -686,7 +727,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.b, args.s_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; @@ -711,7 +752,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "stride_b_dk_expanded: " << args.stride_b_dk_expanded << "\n"; @@ -732,7 +773,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; *log_file << "stride_b_dv_expanded: " << args.stride_b_dv_expanded << "\n"; @@ -762,7 +803,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 block(THREADS_PER_BLOCK); dim3 grid(ceil(1.0 * args.s_q * args.s_kv / THREADS_PER_BLOCK)); if(bias_shape==BiasShape::k11SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -774,7 +815,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -786,7 +827,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 0f407230c..0f4e9a424 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -15,9 +15,8 @@ namespace ck_fused_attn{ // print the fmha traits and fmha_args when calling ck apis -void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args){ +void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // debug fmha_traits @@ -103,11 +102,7 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ bool has_dropout = (args.is_training && args.dropout_probability > 0.f); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; @@ -218,16 +213,16 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, args.cu_seqlen_q_padded_ptr, args.lse_ptr, stream); } } // print ck traits and fmha_args when needed - if(ck_log_config){ - log_fwd_config(__FUNCTION__, has_dropout, fmha_args); + if(log_file){ + log_fwd_config(__FUNCTION__, has_dropout, fmha_args, log_file); } float average_runtime = QOLA_NS(mha_fwd)(fmha_args, stream_config); if(average_runtime < 0){ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 744d0575a..7586a8388 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -743,9 +743,6 @@ void fused_attn_ck_bwd_impl( bool is_mqa_gqa = (h > hg); - size_t kN0 = (d_qk <= 128)? 128:64; - size_t nsplits = deterministic? ceil(1.0*s_kv/kN0):1; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout); bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; @@ -770,9 +767,6 @@ void fused_attn_ck_bwd_impl( // First h*max_tokens_q*sizeof(float) is the lse-d buffer (passed as softmax_lsed) void* lse_workspace = planner.allocate(h*max_tokens_q*sizeof(float)); - // CK requires dq_acc ptr; size depends on deterministic mode - void* dq_acc_ptr = planner.allocate(nsplits*h*max_tokens_q*d_qk*sizeof(float)); - void* dk_expanded_ptr = nullptr; void* dv_expanded_ptr = nullptr; std::array dk_expanded_stride; @@ -913,8 +907,6 @@ void fused_attn_ck_bwd_impl( } // Initialize workspace buffers. - // dq_acc is of shape (nsplits, B, S, H, D_qk); CK requires zeroing - NVTE_CHECK_CUDA(cudaMemsetAsync(dq_acc_ptr, 0, sizeof(float)*nsplits*h*max_tokens_q*d_qk, stream)); if(devPtrAlibiSlope){ dim3 block, grid; block.x = 1024; @@ -992,7 +984,6 @@ void fused_attn_ck_bwd_impl( ck_args.attn_mask_type = set_ck_mask(mask_type, window_size_left, window_size_right); ck_args.window_size_left = window_size_left; ck_args.window_size_right = window_size_right; - ck_args.dq_acc_ptr = dq_acc_ptr; ck_args.dk_expanded_ptr = dk_expanded_ptr; ck_args.dv_expanded_ptr = dv_expanded_ptr; ck_args.lse_workspace_ptr = lse_workspace;