diff --git a/3rdparty/QoLA b/3rdparty/QoLA index 549844d77..9c13e77ef 160000 --- a/3rdparty/QoLA +++ b/3rdparty/QoLA @@ -1 +1 @@ -Subproject commit 549844d771ed3155dd75a6bf2c714cb3f710bada +Subproject commit 9c13e77ef3cf89053aad61ed3a0f27470f123ee5 diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 4640374aa..3a1913a96 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -8,48 +8,58 @@ project(ck_fused_attn LANGUAGES HIP CXX) set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") -#Corresponding runtime check is in nvte_get_fused_attn_backend() -list(FIND CMAKE_HIP_ARCHITECTURES "gfx1250" _gfx1250_idx) -if(NOT _gfx1250_idx EQUAL -1) - message(WARNING - "Removing unsupported gfx1250 from CMAKE_HIP_ARCHITECTURES for ck_fused_attn build.") - list(REMOVE_ITEM CMAKE_HIP_ARCHITECTURES "gfx1250") - list(LENGTH CMAKE_HIP_ARCHITECTURES _hip_arch_count) - if(_hip_arch_count EQUAL 0) - message(FATAL_ERROR - "No supported architectures remain for the ck_fused_attn build. " - "Re-run the build with FUSED_ATTN_CK backend disabled.") - endif() - set(GPU_TARGETS ${CMAKE_HIP_ARCHITECTURES}) -endif() -set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA/3rdparty/aiter") +# gfx1250 carries AITER V3 bwd kernels only (hd128, bf16, batch mode). The +# runtime envelope is enforced in nvte_get_fused_attn_backend(). +set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") +set(__AITER_SOURCE_DIR "${__QOLA_DIR}/build/third_party/aiter") set(__CK_SOURCE_DIR "${__AITER_SOURCE_DIR}/3rdparty/composable_kernel") - set(CK_INCLUDE_DIR "${__CK_SOURCE_DIR}/include") -message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") -if(NOT EXISTS "${CK_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find CK API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() - set(AITER_INCLUDE_DIR "${__AITER_SOURCE_DIR}/csrc/include") -message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") -if(NOT EXISTS "${AITER_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find AITER API. " - "Try running 'git submodule update --init --recursive' " - "within the Transformer Engine source.") -endif() if(NOT Python_EXECUTABLE) find_package(Python COMPONENTS Interpreter QUIET) endif() +# Resolve the manifest-pinned AITER commit (defines AITER_SHA) and bring the +# QoLA-managed AITER source tree to that commit before any consumer reads it +# (header validation below, header includes for the .cpp build later, and +# QoLA's own kernel build if the prebuilt cache misses). +include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") + if(Python_EXECUTABLE) + set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") + # Redirect GIT_CONFIG_GLOBAL to a tempfile carrying `safe.directory = *` so + # git operations inside the QoLA-managed AITER tree (and its recursive + # submodules) work in containerized builds where the bind-mounted .git is + # owned by a different UID than the build process. Mirrors the pattern in + # transformer_engine/common/CMakeLists.txt:get_git_commit(). execute_process( - COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py --mode both --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + COMMAND sh -c + "tmp=$(mktemp /tmp/gitconfig.XXXXXX) || exit 1; \ +GIT_CONFIG_GLOBAL=$tmp git config --global --add safe.directory '*' >/dev/null 2>&1; \ +GIT_CONFIG_GLOBAL=$tmp PYTHONPATH=\"${__QOLA_DIR}:$PYTHONPATH\" '${Python_EXECUTABLE}' -m qola.cli checkout \ +--manifest '${__QOLA_MANIFEST}' \ +--aiter-root '${__AITER_SOURCE_DIR}'; \ +rc=$?; rm -f \"$tmp\"; exit $rc" + RESULT_VARIABLE AITER_CHECKOUT_RESULT + OUTPUT_VARIABLE AITER_CHECKOUT_OUTPUT + ERROR_VARIABLE AITER_CHECKOUT_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_STRIP_TRAILING_WHITESPACE + ) + if(NOT AITER_CHECKOUT_RESULT EQUAL 0) + message(FATAL_ERROR + "Failed to sync AITER source tree at ${__AITER_SOURCE_DIR} to " + "manifest-pinned commit ${AITER_SHA}.\n" + "${AITER_CHECKOUT_OUTPUT}\n${AITER_CHECKOUT_ERROR}") + endif() + message(STATUS "[AITER] Synced ${__AITER_SOURCE_DIR} to ${AITER_SHA}") + + execute_process( + COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/check_aiter_mha_args.py + --mode both + --te-dir "${CMAKE_CURRENT_LIST_DIR}/../../.." + --aiter-root "${__AITER_SOURCE_DIR}" RESULT_VARIABLE AITER_ARG_CHECK_RESULT OUTPUT_VARIABLE AITER_ARG_CHECK_OUTPUT ERROR_VARIABLE AITER_ARG_CHECK_ERROR @@ -64,50 +74,125 @@ if(Python_EXECUTABLE) endif() message(STATUS "AITER API validation passed via check_aiter_mha_args.py") else() - message(WARNING "Python interpreter not found; skipping AITER API validation.") + message(WARNING "Python interpreter not found; skipping AITER source-tree sync and API validation.") endif() -if(DEFINED AITER_MHA_PATH) - message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}") - # use pre-built te_libmha_fwd.so te_libmha_bwd.so - set(__AITER_MHA_PATH ${AITER_MHA_PATH}) -else() - set(__AITER_MHA_PATH "") - include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") - get_prebuilt_aiter(__AITER_MHA_PATH) +# Sanity-check the resolved include directories now that `qola checkout` has +# materialized the AITER tree. +message(STATUS "ck_include_dir: ${CK_INCLUDE_DIR}") +if(NOT EXISTS "${CK_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find CK API at ${CK_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER and its " + "composable_kernel submodule.") +endif() - if(__AITER_MHA_PATH STREQUAL "") - # If not available, fallback: Build from source via QoLA - list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR) - message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.") - set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") - set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") - set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") - execute_process( - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" - ${Python_EXECUTABLE} -m qola.cli build - --manifest ${__QOLA_MANIFEST} - --aiter-root ${__AITER_SOURCE_DIR} - --output-dir ${__QOLA_BUILD_DIR} - --arch "${GPU_ARCHS_STR}" - RESULT_VARIABLE QOLA_BUILD_RESULT - ) - if(NOT QOLA_BUILD_RESULT EQUAL 0) - message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.") +message(STATUS "aiter_include_dir: ${AITER_INCLUDE_DIR}") +if(NOT EXISTS "${AITER_INCLUDE_DIR}") + message(FATAL_ERROR + "Could not find AITER API at ${AITER_INCLUDE_DIR}. " + "Re-run the build to let `qola checkout` clone AITER.") +endif() + +# Partition the requested HIP architectures into the CK-full set (CDNA, where +# the AITER CK FMHA template headers compile) and the V3-asm-only set. gfx1250 +# (RDNA4) has AITER V3 *backward* asm kernels but no CK FMHA support and no +# forward kernels, so it is built as a separate CK-free library (namespace +# te_v3, manifest qola_manifest_gfx1250.toml) and dispatched at runtime in +# ck_attn_bwd. The two tiers coexist via distinct QoLA namespaces. +set(__CK_FULL_ARCHS ${CMAKE_HIP_ARCHITECTURES}) +set(__HAS_GFX1250 FALSE) +list(FIND __CK_FULL_ARCHS "gfx1250" __GFX1250_IDX) +if(NOT __GFX1250_IDX EQUAL -1) + set(__HAS_GFX1250 TRUE) + list(REMOVE_ITEM __CK_FULL_ARCHS "gfx1250") +endif() +list(LENGTH __CK_FULL_ARCHS __CK_FULL_ARCH_COUNT) +if(__CK_FULL_ARCH_COUNT EQUAL 0 AND NOT __HAS_GFX1250) + message(FATAL_ERROR "ck_fused_attn: no target architectures requested.") +endif() + +set(__AITER_MHA_PATH "") +set(__HAVE_CK_FULL FALSE) + +# --- CK-full libraries (CDNA): te_libmha_fwd.so / te_libmha_bwd.so --- +if(__CK_FULL_ARCH_COUNT GREATER 0) + set(__HAVE_CK_FULL TRUE) + if(DEFINED AITER_MHA_PATH) + message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}") + # use pre-built te_libmha_fwd.so te_libmha_bwd.so + set(__AITER_MHA_PATH ${AITER_MHA_PATH}) + else() + get_prebuilt_aiter(__AITER_MHA_PATH) + + if(__AITER_MHA_PATH STREQUAL "") + # If not available, fallback: Build from source via QoLA + list(JOIN __CK_FULL_ARCHS ";" GPU_ARCHS_STR) + message(STATUS "[AITER-BUILD] Building CK-full AITER kernels for ${GPU_ARCHS_STR} via QoLA.") + set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") + execute_process( + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" + ${Python_EXECUTABLE} -m qola.cli build + --manifest ${__QOLA_MANIFEST} + --aiter-root ${__AITER_SOURCE_DIR} + --output-dir ${__QOLA_BUILD_DIR} + --arch "${GPU_ARCHS_STR}" + RESULT_VARIABLE QOLA_BUILD_RESULT + ) + if(NOT QOLA_BUILD_RESULT EQUAL 0) + message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.") + endif() + + # Copy the final .so libs and exported public headers into the aiter + # prebuilt cache so downstream consumers see a self-contained tree. + get_default_aiter_cache_dir(__QOLA_CACHE_DIR) + set(__QOLA_CACHE_LIB "${__QOLA_CACHE_DIR}/lib") + file(MAKE_DIRECTORY ${__QOLA_CACHE_LIB}) + file(GLOB __QOLA_BUILT_LIBS "${__QOLA_BUILD_DIR}/lib/*.so") + file(COPY ${__QOLA_BUILT_LIBS} DESTINATION ${__QOLA_CACHE_LIB}) + file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__QOLA_CACHE_DIR}") + set(__AITER_MHA_PATH "${__QOLA_CACHE_LIB}") + else() + message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}") endif() + endif() +endif() + +# --- V3-asm-only backward library (gfx1250): te_v3_libmha_bwd.so --- +# There is no prebuilt cache path for gfx1250 (no public prebuilt, and a CK-free +# asm build is cheap), so always build it from source via QoLA. Both manifests +# pin the same AITER commit and share the already-checked-out source tree. +if(__HAS_GFX1250) + set(__QOLA_MANIFEST_V3 "${CMAKE_CURRENT_LIST_DIR}/qola_manifest_gfx1250.toml") + set(__QOLA_BUILD_DIR_V3 "${__QOLA_DIR}/build_gfx1250") + message(STATUS "[AITER-BUILD] Building CK-free V3 backward (gfx1250) via QoLA.") + # The asm-only / CK-free flags (ONLY_FAV3=1, ENABLE_CK=0) are carried by the + # gfx1250 manifest's libmha_bwd module, so no special env is needed here. + execute_process( + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" + ${Python_EXECUTABLE} -m qola.cli build + --manifest ${__QOLA_MANIFEST_V3} + --aiter-root ${__AITER_SOURCE_DIR} + --output-dir ${__QOLA_BUILD_DIR_V3} + --arch "gfx1250" + RESULT_VARIABLE QOLA_V3_BUILD_RESULT + ) + if(NOT QOLA_V3_BUILD_RESULT EQUAL 0) + message(FATAL_ERROR "[AITER-BUILD] QoLA gfx1250 V3 build failed.") + endif() - # Copy the final .so libs and exported public headers into the aiter - # prebuilt cache so downstream consumers see a self-contained tree. + # Stage the v3 lib next to the CK-full libs so a single link/-L/install path + # covers both. For a gfx1250-only build there are no CK-full libs, so set up + # the cache lib dir here and stage the v3 public headers too. + if(__AITER_MHA_PATH STREQUAL "") get_default_aiter_cache_dir(__QOLA_CACHE_DIR) set(__QOLA_CACHE_LIB "${__QOLA_CACHE_DIR}/lib") file(MAKE_DIRECTORY ${__QOLA_CACHE_LIB}) - file(GLOB __QOLA_BUILT_LIBS "${__QOLA_BUILD_DIR}/lib/*.so") - file(COPY ${__QOLA_BUILT_LIBS} DESTINATION ${__QOLA_CACHE_LIB}) - file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__QOLA_CACHE_DIR}") + file(COPY "${__QOLA_BUILD_DIR_V3}/include" DESTINATION "${__QOLA_CACHE_DIR}") set(__AITER_MHA_PATH "${__QOLA_CACHE_LIB}") - else() - message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}") endif() + file(GLOB __QOLA_V3_LIBS "${__QOLA_BUILD_DIR_V3}/lib/te_v3_*.so") + file(COPY ${__QOLA_V3_LIBS} DESTINATION ${__AITER_MHA_PATH}) endif() set(ck_fused_attn_SOURCES) @@ -124,7 +209,18 @@ endforeach() add_library(ck_fused_attn SHARED ${ck_fused_attn_SOURCES}) set(CK_FUSED_ATTN_COMPILE_OPTIONS) list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS - -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT}) + -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT} + -DENABLE_CK=1) +# Tier guards consumed by src/ck_fused_attn_{fwd,bwd}.cpp: +# NVTE_AITER_CK_FULL -> qola::te::{mha_fwd,mha_bwd} (CDNA) are linked +# NVTE_AITER_V3_BWD_GFX1250 -> qola::te_v3::mha_bwd (gfx1250) is linked + +# runtime-dispatched in ck_attn_bwd +if(__HAVE_CK_FULL) + list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DNVTE_AITER_CK_FULL) +endif() +if(__HAS_GFX1250) + list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS -DNVTE_AITER_V3_BWD_GFX1250) +endif() # Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include # (emitted by qola.cli build, or copied from the QoLA build dir above for the @@ -141,10 +237,22 @@ target_include_directories(ck_fused_attn PRIVATE ${__QOLA_INCLUDE_DIR}) find_package(hip) target_link_directories(ck_fused_attn PUBLIC ${__AITER_MHA_PATH}) -list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 -l:te_libmha_fwd.so -l:te_libmha_bwd.so) +list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64) +set(__INSTALL_AITER_LIBS) +if(__HAVE_CK_FULL) + list(APPEND ck_fused_attn_LINKER_LIBS -l:te_libmha_fwd.so -l:te_libmha_bwd.so) + list(APPEND __INSTALL_AITER_LIBS + ${__AITER_MHA_PATH}/te_libmha_fwd.so + ${__AITER_MHA_PATH}/te_libmha_bwd.so) +endif() +if(__HAS_GFX1250) + list(APPEND ck_fused_attn_LINKER_LIBS -l:te_v3_libmha_bwd.so) + list(APPEND __INSTALL_AITER_LIBS + ${__AITER_MHA_PATH}/te_v3_libmha_bwd.so) +endif() target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS}) target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS}) set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN") -install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) +install(FILES ${__INSTALL_AITER_LIBS} DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake index ea0396116..65ee3ec81 100644 --- a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake +++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake @@ -18,8 +18,22 @@ string(STRIP "${ROCM_VER_CONTENT}" ROCM_VER_CONTENT) string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER_CONTENT}") string(REGEX MATCH "^[0-9]+" ROCM_VER_MAJOR "${ROCM_VER}") -# AITER commit -get_git_commit("${__AITER_SOURCE_DIR}" AITER_SHA) +# AITER commit — read from the QoLA manifest so the cache key tracks the +# commit QoLA will actually check out and build, not whatever happens to be +# the submodule's current HEAD at configure time. +set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") +set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${__QOLA_MANIFEST}") +file(STRINGS "${__QOLA_MANIFEST}" __AITER_COMMIT_LINES + REGEX "^[ \t]*aiter_commit[ \t]*=[ \t]*\"[^\"]+\"") +list(LENGTH __AITER_COMMIT_LINES __AITER_COMMIT_COUNT) +if(NOT __AITER_COMMIT_COUNT EQUAL 1) + message(FATAL_ERROR + "Expected exactly one 'aiter_commit = \"...\"' line in " + "${__QOLA_MANIFEST}, found ${__AITER_COMMIT_COUNT}.") +endif() +list(GET __AITER_COMMIT_LINES 0 __AITER_COMMIT_LINE) +string(REGEX MATCH "\"([^\"]+)\"" _UNUSED "${__AITER_COMMIT_LINE}") +set(AITER_SHA "${CMAKE_MATCH_1}") # Cache key & local paths set(AITER_CACHE_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../build/aiter-prebuilts") diff --git a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py index 2e9831f1a..6bae3091d 100644 --- a/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py +++ b/transformer_engine/common/ck_fused_attn/check_aiter_mha_args.py @@ -31,7 +31,7 @@ def parse_with_skip_comments(buffer, line, regex, outputs): def extract_fields_from_header(text: str, struct_name: str) -> List[str]: - struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*)?;\s*$") + struct_field_re = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*(?:=[^;]*|\{[^;]*\})?;\s*$") struct_end_re = re.compile(r"^\s*};\s*$") struct_start_re = re.compile(rf"\bstruct\s+{re.escape(struct_name)}\b") @@ -64,11 +64,14 @@ def main() -> int: parser = argparse.ArgumentParser(description="Check aiter args usage vs header definition") parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both", help="Mode: fwd, bwd, or both") parser.add_argument("--te-dir", type=Path, default=Path(__file__).parent.parent.parent.parent, help="Root directory of TransformerEngine") + parser.add_argument("--aiter-root", type=Path, default=None, + help="AITER source tree root. Defaults to /3rdparty/aiter.") args = parser.parse_args() + aiter_root = args.aiter_root if args.aiter_root else args.te_dir / "3rdparty/aiter" modes = ["fwd", "bwd"] if args.mode == "both" else [args.mode] mismatch = 0 for mode in modes: - header_path = args.te_dir / f"3rdparty/aiter/csrc/include/mha_{mode}.h" + header_path = aiter_root / f"csrc/include/mha_{mode}.h" source_path = args.te_dir / f"transformer_engine/common/ck_fused_attn/src/ck_fused_attn_{mode}.cpp" header_text = header_path.read_text(encoding="utf-8") source_text = source_path.read_text(encoding="utf-8") diff --git a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp index 127d75b4c..abd1ec371 100644 --- a/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp +++ b/transformer_engine/common/ck_fused_attn/include/ck_fused_attn/ck_fused_attn.hpp @@ -113,7 +113,6 @@ struct CkAttnBwdArgs : CKAttnCommonArgs { // dQ void* dq_ptr = nullptr; uint64_t stride_b_dq = 0, stride_h_dq = 0, stride_s_dq = 0; - void* dq_acc_ptr = nullptr; // dK / dV expanded (MQA/GQA reduction inputs; null when h==hg) void* dk_expanded_ptr = nullptr; diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest.toml b/transformer_engine/common/ck_fused_attn/qola_manifest.toml index b6877c6c0..b81e47ceb 100644 --- a/transformer_engine/common/ck_fused_attn/qola_manifest.toml +++ b/transformer_engine/common/ck_fused_attn/qola_manifest.toml @@ -1,5 +1,5 @@ [qola] -aiter_commit = "33f2e6af5f39379c739720080ed0033d533f5cb2" # pinned AITER submodule commit +aiter_commit = "f03a4ec572bb3d9e15da3b346763c8f126feec0d" # pinned AITER submodule commit namespace = "te" rocm_versions = ["7.2"] @@ -9,9 +9,11 @@ architectures = ["gfx950", "gfx942"] [[modules]] name = "libmha_fwd" mode = "cpp_itfs" +receipt = 700 drop_srcs = ["mha_fwd_split.cu", "mha_fwd_batch_prefill.cu"] drop_directions = ["fwd_splitkv", "batch_prefill"] [[modules]] name = "libmha_bwd" mode = "cpp_itfs" +receipt = 700 diff --git a/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml new file mode 100644 index 000000000..23b59027e --- /dev/null +++ b/transformer_engine/common/ck_fused_attn/qola_manifest_gfx1250.toml @@ -0,0 +1,34 @@ +# gfx1250 (RDNA4) carries AITER V3 *backward* asm kernels only — there are no +# forward kernels at the pinned commit, and the CK FMHA template headers do not +# compile for gfx1250. This manifest therefore builds a CK-free (ENABLE_CK=0), +# asm-v3-only backward library under a distinct namespace (te_v3) so it can +# coexist with the CK-full te_* libraries in a multi-arch build. TE selects +# between qola::te::mha_bwd and qola::te_v3::mha_bwd at runtime by device arch. +# +# Keep aiter_commit in lockstep with qola_manifest.toml — both consume the same +# checked-out AITER source tree. +[qola] +aiter_commit = "f03a4ec572bb3d9e15da3b346763c8f126feec0d" # pinned AITER submodule commit +namespace = "te_v3" +rocm_versions = ["7.2"] + +[build] +architectures = ["gfx1250"] + +# Reuse the torch-free libmha_bwd module (sources = mha_bwd.cu only; the same +# source the CK-full te_libmha_bwd.so builds from). Do NOT use +# module_fmha_v3_bwd here — it pulls in mha_common.cu, which includes +# and is therefore torch-dependent. To make this build +# CK-free and asm-only, two independent gates in mha_bwd.cu must both be set: +# - ONLY_FAV3=1 selects the asm-only dispatch (`#if ONLY_FAV3` returns the +# fmha_v3_bwd result; the `#else` branch instantiates CK fmha_bwd_traits), +# - ENABLE_CK=0 strips the CK fmha_bwd.hpp include and uses the ck_tile shim, +# mirroring AITER's own module_fmha_v3_bwd. drop_directions=["bwd"] removes the +# CK `generate.py -d bwd` codegen (the HSA codegen has no -d and is kept). +# flags_extra_cc values are eval'd as Python expressions (optCompilerConfig +# convention), hence the inner quotes; this replaces libmha_bwd's empty list. +[[modules]] +name = "libmha_bwd" +mode = "cpp_itfs" +drop_directions = ["bwd"] +flags_extra_cc = ["'-DONLY_FAV3=1'", "'-DENABLE_CK=0'"] diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index b0ed9ae6e..26a2def60 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -6,14 +6,43 @@ #include #include +#include #include #include +#include +#include #include "ck_fused_attn/ck_fused_attn.hpp" +#include "ck_tile/host/pinned_host_releaser.hpp" #include "qola_mha_bwd.h" #include "ck_fused_attn_utils.hpp" +// Staged gfx1250 backward dispatch. When this build includes the CK-free V3 +// backward library (te_v3_libmha_bwd.so, built for gfx1250), declare its +// namespaced entry point so ck_attn_bwd can route to it on gfx1250 devices at +// runtime. The CK-full path (QOLA_NS(mha_bwd) == qola::te::mha_bwd) is used on +// all other archs. +#if defined(NVTE_AITER_V3_BWD_GFX1250) +namespace qola { namespace te_v3 { +float mha_bwd(const aiter::mha_bwd_args& args, const ck_tile::stream_config& stream_config); +}} // namespace qola::te_v3 +#endif + namespace ck_fused_attn{ +#if defined(NVTE_AITER_V3_BWD_GFX1250) +namespace { +// True when the active device is gfx1250 (gcnArchName may carry feature +// suffixes, e.g. "gfx1250:sramecc+", so match on prefix). +bool is_gfx1250_device(){ + int dev = 0; + if(hipGetDevice(&dev) != hipSuccess){ return false; } + hipDeviceProp_t prop{}; + if(hipGetDeviceProperties(&prop, dev) != hipSuccess){ return false; } + return std::string(prop.gcnArchName).rfind("gfx1250", 0) == 0; +} +} // namespace +#endif + // TODO: unify with binary search in TE/common/fused_attn(rocm)/util // no device std::upper_bound // in an increasing array with given size len, search for the index that: @@ -331,9 +360,8 @@ __global__ void dbias_reduce_b1ss( } // print the fmha_traits and args passed into ck apis -void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args){ +void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // fmha_traits debug @@ -365,7 +393,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "dk_ptr", fmha_args.dk_ptr); log_value(log_file, "dv_ptr", fmha_args.dv_ptr); log_value(log_file, "dbias_ptr", fmha_args.dbias_ptr); - log_value(log_file, "dq_acc_ptr", fmha_args.dq_acc_ptr); log_value(log_file, "seqstart_q_ptr", fmha_args.seqstart_q_ptr); log_value(log_file, "seqstart_k_ptr", fmha_args.seqstart_k_ptr); @@ -390,7 +417,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "stride_o", fmha_args.stride_o); log_value(log_file, "stride_randval", fmha_args.stride_randval); log_value(log_file, "stride_do", fmha_args.stride_do); - log_value(log_file, "stride_dq_acc", fmha_args.stride_dq_acc); log_value(log_file, "stride_dq", fmha_args.stride_dq); log_value(log_file, "stride_dk", fmha_args.stride_dk); log_value(log_file, "stride_dv", fmha_args.stride_dv); @@ -403,7 +429,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "nhead_stride_randval", fmha_args.nhead_stride_randval); log_value(log_file, "nhead_stride_do", fmha_args.nhead_stride_do); log_value(log_file, "nhead_stride_lsed", fmha_args.nhead_stride_lsed); - log_value(log_file, "nhead_stride_dq_acc", fmha_args.nhead_stride_dq_acc); log_value(log_file, "nhead_stride_dq", fmha_args.nhead_stride_dq); log_value(log_file, "nhead_stride_dk", fmha_args.nhead_stride_dk); log_value(log_file, "nhead_stride_dv", fmha_args.nhead_stride_dv); @@ -416,7 +441,6 @@ void log_bwd_config(const char* func_name, const aiter::mha_bwd_args& fmha_args) log_value(log_file, "batch_stride_randval", fmha_args.batch_stride_randval); log_value(log_file, "batch_stride_do", fmha_args.batch_stride_do); log_value(log_file, "batch_stride_lsed", fmha_args.batch_stride_lsed); - log_value(log_file, "batch_stride_dq_acc", fmha_args.batch_stride_dq_acc); log_value(log_file, "batch_stride_dq", fmha_args.batch_stride_dq); log_value(log_file, "batch_stride_dk", fmha_args.batch_stride_dk); log_value(log_file, "batch_stride_dv", fmha_args.batch_stride_dv); @@ -447,14 +471,10 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ bool has_dbias = args.dbias_ptr != nullptr; bool is_mqa_gqa = (args.h > args.hg); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, log_file != nullptr}; bias_enum bias_type = bias_enum::no_bias; BiasShape bias_shape = BiasShape::k11SS; @@ -463,8 +483,11 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ } aiter::mha_bwd_args fmha_args{}; + fmha_args.sink_ptr = nullptr; + fmha_args.d_sink_ptr = nullptr; fmha_args.mask_type = static_cast(static_cast(args.attn_mask_type)); - fmha_args.use_asm_v3 = args.uses_bwd_v3; + // Mirrors AITER's small-seqlen guard at aiter/ops/mha.py:1689. + fmha_args.use_asm_v3 = (args.s_q < 16) ? false : args.uses_bwd_v3; fmha_args.v3_atomic_fp32 = args.is_v3_atomic_fp32; fmha_args.v3_bf16_cvt = args.how_v3_bf16_cvt; fmha_args.v3_api_check = false; @@ -495,7 +518,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.dbias_ptr = ((!args.is_group_mode()) && has_dbias) ? (bias_shape==BiasShape::kBHSS ? args.dbias_ptr : args.dbias_expanded_ptr) : nullptr; - fmha_args.dq_acc_ptr = args.dq_acc_ptr; if (args.is_group_mode()) { fmha_args.seqstart_q_ptr = args.cu_seqlen_q_padded_ptr==nullptr? args.cu_seqlen_q_ptr : args.cu_seqlen_q_padded_ptr; @@ -511,8 +533,13 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.seqlen_q_ptr = nullptr; fmha_args.seqlen_k_ptr = nullptr; - fmha_args.seqlen_q = args.s_q; - fmha_args.seqlen_k = args.s_kv; + // Group mode contract (matches aiter asm_mha_varlen_bwd.cu): seqlen_q/k + // carry the total token counts, max_seqlen_q/k the per-sequence maximum. + // aiter sizes dq_acc and related workspaces from seqlen_q; passing the + // per-sequence length in group mode under-sizes them and the kernel writes + // past the end. + fmha_args.seqlen_q = args.is_group_mode() ? args.max_tokens_q : args.s_q; + fmha_args.seqlen_k = args.is_group_mode() ? args.max_tokens_kv : args.s_kv; fmha_args.batch = args.b; fmha_args.max_seqlen_q = args.s_q; fmha_args.max_seqlen_k = args.s_kv; @@ -529,8 +556,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.stride_o = args.stride_s_o; fmha_args.stride_randval = args.s_kv; fmha_args.stride_do = args.stride_s_do; - //dq_acc of shape (nsplits, B, H, S, D) - fmha_args.stride_dq_acc = args.d_qk; fmha_args.stride_dq = args.stride_s_dq; fmha_args.stride_dk = is_mqa_gqa? args.stride_s_dk_expanded : args.stride_s_dk; fmha_args.stride_dv = is_mqa_gqa? args.stride_s_dv_expanded : args.stride_s_dv; @@ -548,7 +573,6 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.nhead_stride_randval = args.is_group_mode() ? 0 : args.s_q * args.s_kv; fmha_args.nhead_stride_do = args.stride_h_do; fmha_args.nhead_stride_lsed = args.is_group_mode() ? args.max_tokens_q : args.s_q; - fmha_args.nhead_stride_dq_acc = static_cast((args.is_group_mode() ? args.max_tokens_q : args.s_q) * args.d_qk); fmha_args.nhead_stride_dq = args.stride_h_dq; fmha_args.nhead_stride_dk = is_mqa_gqa? args.stride_h_dk_expanded : args.stride_h_dk; fmha_args.nhead_stride_dv = is_mqa_gqa? args.stride_h_dv_expanded : args.stride_h_dv; @@ -564,13 +588,11 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ fmha_args.batch_stride_randval = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; fmha_args.batch_stride_do = args.is_group_mode() ? 0 : args.stride_b_do; fmha_args.batch_stride_lsed = args.is_group_mode() ? 0 : args.h * args.s_q; - fmha_args.batch_stride_dq_acc = args.is_group_mode() ? 0 : static_cast(args.h * args.s_q * args.d_qk); fmha_args.batch_stride_dq = args.is_group_mode() ? 0 : args.stride_b_dq; fmha_args.batch_stride_dk = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dk_expanded : args.stride_b_dk); fmha_args.batch_stride_dv = args.is_group_mode() ? 0 : (is_mqa_gqa? args.stride_b_dv_expanded : args.stride_b_dv); // for dbias, use h since h can be different from bias_h fmha_args.batch_stride_dbias = args.is_group_mode() ? 0 : args.h * args.s_q * args.s_kv; - fmha_args.split_stride_dq_acc = static_cast(args.is_group_mode() ? (args.max_tokens_q * args.h * args.d_qk) : (args.b * args.h * args.s_q * args.d_qk)); fmha_args.window_size_left = args.window_size_left; fmha_args.window_size_right = args.window_size_right; @@ -582,19 +604,77 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ // lse_workspace_ptr used as buffer if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) { if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, nullptr, args.lse_workspace_ptr, stream); fmha_args.max_seqlen_k = get_runtime_max_seqlen(args.b, args.cu_seqlen_kv_ptr, nullptr, args.lse_workspace_ptr, stream); } } + // Device-side workspace allocations made inside mha_bwd (launcher metadata + // and the dq_acc accumulator). aiter only contracts that the pointer remain + // valid for the duration of the kernels it enqueues; hipFreeAsync on the + // same stream defers the free until that work completes. + std::vector mha_bwd_workspaces; + fmha_args.workspace_alloc = [&mha_bwd_workspaces, stream](size_t bytes, bool zero_init) -> void* { + if(bytes == 0){ + return nullptr; + } + void* ptr = nullptr; + if(hipMallocAsync(&ptr, bytes, stream) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipMallocAsync failed for AITER workspace."); + } + if(zero_init){ + if(hipMemsetAsync(ptr, 0, bytes, stream) != hipSuccess){ + hipFreeAsync(ptr, stream); + throw std::runtime_error("ck_fused_attn bwd: hipMemsetAsync failed for AITER workspace."); + } + } + mha_bwd_workspaces.push_back(ptr); + return ptr; + }; + // Group mode requires a pinned host buffer for the async D2H seqstart + // pipeline; aiter keeps the shared_ptr alive past kernel completion via a + // stream-tail hipLaunchHostFunc keepalive. The deleter fires from that HIP + // callback thread, which holds runtime locks — calling any HIP API from it + // (including hipHostFree) deadlocks against concurrent main-thread HIP + // calls. Defer the free to ck_tile::pinned_host_releaser's worker thread. + fmha_args.pinned_host_alloc = [](size_t bytes) -> std::shared_ptr { + if(bytes == 0){ + return {}; + } + void* ptr = nullptr; + if(hipHostMalloc(&ptr, bytes, hipHostMallocDefault) != hipSuccess){ + throw std::runtime_error("ck_fused_attn bwd: hipHostMalloc failed for AITER pinned host buffer."); + } + return std::shared_ptr(ptr, [](void* p){ + ck_tile::pinned_host_releaser::instance().enqueue(p); + }); + }; + // print ck traits and args when needed - if(ck_log_config){ - log_bwd_config(__FUNCTION__, fmha_args); + if(log_file){ + log_bwd_config(__FUNCTION__, fmha_args, log_file); + } + float average_runtime; +#if defined(NVTE_AITER_V3_BWD_GFX1250) + if(is_gfx1250_device()){ + average_runtime = qola::te_v3::mha_bwd(fmha_args, stream_config); + } else +#endif + { +#if defined(NVTE_AITER_CK_FULL) + average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); +#else + throw std::runtime_error( + "ck_fused_attn bwd: this build has no CK-full AITER backward library " + "(no CDNA archs built); only the staged gfx1250 V3 path is present."); +#endif + } + for(void* ws_ptr : mha_bwd_workspaces){ + hipFreeAsync(ws_ptr, stream); } - float average_runtime = QOLA_NS(mha_bwd)(fmha_args, stream_config); if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); @@ -610,7 +690,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.max_tokens_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce_thd: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -637,7 +717,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dk: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -660,7 +740,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce_thd on dv: " << "\n"; *log_file << "cu_seqlen_kv_ptr: " << args.cu_seqlen_kv_ptr << "\n"; *log_file << "cu_seqlen_kv_padded_ptr: " << args.cu_seqlen_kv_padded_ptr << "\n"; @@ -686,7 +766,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 grid(args.b, args.s_kv, args.hg); if(args.d_qk == args.d_v){ dim3 block(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_dv_reduce: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; @@ -711,7 +791,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); } else { dim3 block_dk(args.d_qk); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dk: " << "\n"; *log_file << "dk_expanded_ptr: " << args.dk_expanded_ptr << "\n"; *log_file << "stride_b_dk_expanded: " << args.stride_b_dk_expanded << "\n"; @@ -732,7 +812,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ args.stride_b_dk, args.stride_h_dk, args.stride_s_dk);); dim3 block_dv(args.d_v); - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dk_or_dv_reduce on dv: " << "\n"; *log_file << "dv_expanded_ptr: " << args.dv_expanded_ptr << "\n"; *log_file << "stride_b_dv_expanded: " << args.stride_b_dv_expanded << "\n"; @@ -762,7 +842,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ dim3 block(THREADS_PER_BLOCK); dim3 grid(ceil(1.0 * args.s_q * args.s_kv / THREADS_PER_BLOCK)); if(bias_shape==BiasShape::k11SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_11SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -774,7 +854,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::k1HSS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_1HSS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; @@ -786,7 +866,7 @@ hipError_t ck_attn_bwd(const CkAttnBwdArgs& args, hipStream_t stream){ static_cast(args.dbias_expanded_ptr), static_cast(args.dbias_ptr));); }else if(bias_shape==BiasShape::kB1SS){ - if (auto* log_file = get_ck_log_stream()) { + if (log_file) { *log_file << "\n" << "run dbias_reduce_B1SS: " << "\n"; *log_file << "dbias_ptr: " << args.dbias_ptr << "\n"; *log_file << "dbias_expanded_ptr: " << args.dbias_expanded_ptr << "\n"; diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index 0f407230c..074ad0042 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -15,9 +15,8 @@ namespace ck_fused_attn{ // print the fmha traits and fmha_args when calling ck apis -void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args){ +void log_fwd_config(const char* func_name, bool has_dropout, const aiter::mha_fwd_args& fmha_args, std::ostream* log_file){ - std::ostream* log_file = get_ck_log_stream(); (*log_file) << "\n" << func_name << "\n"; // debug fmha_traits @@ -103,11 +102,7 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ bool has_dropout = (args.is_training && args.dropout_probability > 0.f); - bool ck_log_config = false; - if (const char* env_p = std::getenv("CK_FUSED_ATTN_LOG_CONFIG") ) { - if (env_p != nullptr && std::string(env_p) == "1") - ck_log_config = true; - } + auto* log_file = get_ck_log_stream(); const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode ck_tile::stream_config stream_config{stream, dump_path!=nullptr, get_ck_log_stream() != nullptr}; @@ -218,18 +213,29 @@ hipError_t ck_attn_fwd(const CKAttnFwdArgs& args, hipStream_t stream){ if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")){ if(args.is_group_mode() && std::string(env_p) == "1"){ - if(ck_log_config){ - std::cout << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; + if(log_file){ + *log_file << "attn_fwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization."; } fmha_args.max_seqlen_q = get_runtime_max_seqlen(args.b, args.cu_seqlen_q_ptr, args.cu_seqlen_q_padded_ptr, args.lse_ptr, stream); } } // print ck traits and fmha_args when needed - if(ck_log_config){ - log_fwd_config(__FUNCTION__, has_dropout, fmha_args); + if(log_file){ + log_fwd_config(__FUNCTION__, has_dropout, fmha_args, log_file); } +#if defined(NVTE_AITER_CK_FULL) float average_runtime = QOLA_NS(mha_fwd)(fmha_args, stream_config); +#else + // gfx1250-only build: no CK-full forward library exists (gfx1250 has no + // forward kernels). The unified backend selector never picks CK on gfx1250, + // so this path is unreachable at runtime; the guard only keeps the link + // closed when te_libmha_fwd.so is absent. + float average_runtime = -1.0f; + throw std::runtime_error( + "ck_fused_attn fwd: no CK-full AITER forward library in this build " + "(gfx1250 has no forward kernels)."); +#endif if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 1f837be41..93b8ff0c5 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -11,7 +11,6 @@ #include "fused_attn_aotriton.h" #include "fused_attn_ck.h" #include "../common.h" -#include "../util/cuda_runtime.h" //cuda::sm_arch #include "utils.h" // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group @@ -283,12 +282,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; - //gfx1250 is disabled in ck_fused_attn/CMakeLists.txt and is not supported by curretnt aotriton - const int gpu_arch = cuda::sm_arch(cuda::current_device()); - if (gpu_arch == 125) { - return NVTE_Fused_Attn_Backend::NVTE_No_Backend; - } - // TODO: Add return_max_logit support if (return_max_logit) return NVTE_Fused_Attn_Backend::NVTE_No_Backend; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 9a0161ca5..0c78f5a24 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -54,6 +54,11 @@ bool is_aotriton_backend_supported( int64_t window_size_right) { #ifdef USE_FUSED_ATTN_AOTRITON + // AOTriton has no gfx1250 support. + if(cuda::sm_arch(cuda::current_device()) == 125){ + return false; + } + //TODO: release after AOTriton support support Multi-latent attention if(head_dim_qk != head_dim_v){ return false; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 744d0575a..d4b9b005e 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -30,9 +30,9 @@ bool is_ck_backend_supported( float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, - size_t head_dim_v, - int64_t window_size_left, + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right) { #ifdef USE_FUSED_ATTN_CK @@ -154,6 +154,22 @@ bool is_ck_backend_supported( } return false; } + + // gfx1250 (RDNA4) ships AITER V3 *backward* asm kernels only — there are no + // forward kernels at the pinned commit. TE selects one fused-attn backend per + // op and uses it for both directions (backward inherits the forward's choice), + // so selecting CK here would route the forward into a kernel-less path. + // Until the forward is handled (direction-aware backend selection, or gfx1250 + // forward kernels), do not select CK on gfx1250 through this unified path. + // The CK-free V3 backward library (te_v3_module_fmha_v3_bwd.so) and the + // runtime dispatch in ck_attn_bwd are built and staged for that activation. + if(cuda::sm_arch(cuda::current_device()) == 125){ + if(nvte_log_ck_config){ + std::cout<<"gfx1250 CK fused attn is staged (backward-only); not selected via the unified backend yet"< hg); - size_t kN0 = (d_qk <= 128)? 128:64; - size_t nsplits = deterministic? ceil(1.0*s_kv/kN0):1; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout); bool is_ragged = qkv_format==NVTE_QKV_Format::NVTE_THD; bool is_SBHD = qkv_format==NVTE_QKV_Format::NVTE_SBHD || qkv_format==NVTE_QKV_Format::NVTE_SBHD_2BSHD; @@ -770,9 +783,6 @@ void fused_attn_ck_bwd_impl( // First h*max_tokens_q*sizeof(float) is the lse-d buffer (passed as softmax_lsed) void* lse_workspace = planner.allocate(h*max_tokens_q*sizeof(float)); - // CK requires dq_acc ptr; size depends on deterministic mode - void* dq_acc_ptr = planner.allocate(nsplits*h*max_tokens_q*d_qk*sizeof(float)); - void* dk_expanded_ptr = nullptr; void* dv_expanded_ptr = nullptr; std::array dk_expanded_stride; @@ -913,8 +923,6 @@ void fused_attn_ck_bwd_impl( } // Initialize workspace buffers. - // dq_acc is of shape (nsplits, B, S, H, D_qk); CK requires zeroing - NVTE_CHECK_CUDA(cudaMemsetAsync(dq_acc_ptr, 0, sizeof(float)*nsplits*h*max_tokens_q*d_qk, stream)); if(devPtrAlibiSlope){ dim3 block, grid; block.x = 1024; @@ -992,7 +1000,6 @@ void fused_attn_ck_bwd_impl( ck_args.attn_mask_type = set_ck_mask(mask_type, window_size_left, window_size_right); ck_args.window_size_left = window_size_left; ck_args.window_size_right = window_size_right; - ck_args.dq_acc_ptr = dq_acc_ptr; ck_args.dk_expanded_ptr = dk_expanded_ptr; ck_args.dv_expanded_ptr = dv_expanded_ptr; ck_args.lse_workspace_ptr = lse_workspace; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h index 0772609ff..de45cbeb7 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h @@ -26,9 +26,9 @@ bool is_ck_backend_supported( float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, - size_t head_dim_v, - int64_t window_size_left, + size_t head_dim_qk, + size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); } // namespace fused_attn_rocm