-
Notifications
You must be signed in to change notification settings - Fork 29
CK JIT integration #582
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
CK JIT integration #582
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,6 +106,7 @@ install_prerequisites | |
| pip list | egrep "flax|fidle|jax|ml_dtypes|numpy|transformer_e|typing_ext" | ||
| #check_test_jobs_requested | ||
| #test $? -eq 0 && init_test_jobs `python -c "import jax; print(len([d for d in jax.devices() if 'rocm' in d.client.platform_version]))"` | ||
| ck_jit_prebuild build | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggest making this fail-fast, e.g. |
||
|
|
||
| for _fus_attn in auto ck aotriton; do | ||
| configure_fused_attn_env $_fus_attn || continue | ||
|
|
@@ -139,4 +140,6 @@ if [ -n "$TEST_JOBS_MODE" -a -n "$TEST_MGPU" ]; then | |
| configure_fused_attn_env $_fus_attn && run_test_config_mgpu | ||
| done | ||
| fi | ||
|
|
||
| ck_jit_prebuild list | ||
| return_run_results | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,16 +76,12 @@ def setup_common_extension() -> CMakeExtension: | |
| os.getenv("MPI_HOME") is not None | ||
| ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" | ||
| cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") | ||
|
|
||
| if rocm_build(): | ||
| cmake_flags.append("-DUSE_ROCM=ON") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This hunk silently drops the Downstream the CMake files now read PR is marked as breaking, but the description doesn't call this rename out. Two reasonable options:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this an intentional part of this change? Seems unrelated? |
||
| if os.getenv("NVTE_AOTRITON_PATH"): | ||
| aotriton_path = Path(os.getenv("NVTE_AOTRITON_PATH")) | ||
| cmake_flags.append(f"-DAOTRITON_PATH={aotriton_path}") | ||
| cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}") | ||
| if os.getenv("NVTE_CK_FUSED_ATTN_PATH"): | ||
| ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH")) | ||
| cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}") | ||
| cmake_flags.append( | ||
| f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', '3')}" | ||
| ) | ||
|
|
||
| if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0: | ||
| cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,7 +6,7 @@ set(CMAKE_CXX_STANDARD 17) | |
| project(ck_fused_attn LANGUAGES HIP CXX) | ||
|
|
||
|
|
||
| set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE") | ||
| set(AITER_MHA_INSTALL_DIR "${CMAKE_INSTALL_PREFIX}/transformer_engine/lib") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor: this replaces the previous |
||
|
|
||
| #Corresponding runtime check is in nvte_get_fused_attn_backend() | ||
| list(FIND CMAKE_HIP_ARCHITECTURES "gfx1250" _gfx1250_idx) | ||
|
|
@@ -67,22 +67,54 @@ else() | |
| message(WARNING "Python interpreter not found; skipping AITER API validation.") | ||
| endif() | ||
|
|
||
| if(DEFINED AITER_MHA_PATH) | ||
| message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}") | ||
| # use pre-built te_libmha_fwd.so te_libmha_bwd.so | ||
| set(__AITER_MHA_PATH ${AITER_MHA_PATH}) | ||
| set(__AITER_CACHE_DIR "") | ||
| set(__AITER_MHA_PATH "") | ||
| set(__QOLA_INCLUDE_DIR "") | ||
| if(NOT "$ENV{NVTE_CK_JIT}" STREQUAL "0") | ||
| set(__USE_CK_JIT TRUE) | ||
| else() | ||
| set(__AITER_MHA_PATH "") | ||
| set(__USE_CK_JIT FALSE) | ||
| endif() | ||
| if(DEFINED ENV{AITER_MHA_PATH}) | ||
| message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=$ENV{AITER_MHA_PATH}") | ||
| # use pre-built libraries and includes from a location specified by the user | ||
| set(__AITER_CACHE_DIR $ENV{AITER_MHA_PATH}) | ||
| elseif(NOT __USE_CK_JIT) #disable for CK_JIT for now | ||
| # use pre-built cache | ||
| include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake") | ||
| get_prebuilt_aiter(__AITER_MHA_PATH) | ||
| get_prebuilt_aiter(__AITER_CACHE_DIR) | ||
| endif() | ||
|
Comment on lines
+73
to
+86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Two concerns about the default-on behavior:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto on point 2, I think we should still use the pre-built cache if available as first priority. |
||
|
|
||
| if(__AITER_MHA_PATH STREQUAL "") | ||
| # If not available, fallback: Build from source via QoLA | ||
| list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR) | ||
| message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.") | ||
| set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") | ||
| if(__AITER_CACHE_DIR STREQUAL "") | ||
| # If not available, fallback: Build from source via QoLA | ||
| list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR) | ||
| message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.") | ||
| set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA") | ||
| set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") | ||
| if(__USE_CK_JIT) | ||
| message(STATUS "[AITER-BUILD] CK_JIT is enabled; will build AITER kernels via CK_JIT.") | ||
| set(__CK_JIT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/ck_jit") | ||
| set(__QOLA_BUILD_DIR "${__CK_JIT_BUILD_DIR}/qola") #Need it under ck_jit to clean on full build | ||
| if(DEFINED ENV{NVTE_CK_JIT_DIR}) | ||
| set(__CK_JIT_SOURCE_DIR $ENV{NVTE_CK_JIT_DIR}) | ||
| else() | ||
| set(__CK_JIT_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/ck_jit") | ||
| endif() | ||
| execute_process( | ||
| COMMAND ${Python_EXECUTABLE} "${__CK_JIT_SOURCE_DIR}/ck_jit_build.py" full | ||
| --with-qola | ||
| --qola-dir ${__QOLA_DIR} | ||
| --qola-manifest ${__QOLA_MANIFEST} | ||
| --qola-output "${__QOLA_BUILD_DIR}" | ||
| --gpu-archs "${GPU_ARCHS_STR}" | ||
| --aiter-dir ${__AITER_SOURCE_DIR} | ||
| --tmp-dir "${__CK_JIT_BUILD_DIR}" | ||
| --install-dir ${AITER_MHA_INSTALL_DIR} | ||
| --jit-name "te_ck_jit" | ||
| RESULT_VARIABLE QOLA_BUILD_RESULT | ||
| ) | ||
| else() | ||
| set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build") | ||
| set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml") | ||
| execute_process( | ||
| COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}" | ||
| ${Python_EXECUTABLE} -m qola.cli build | ||
|
|
@@ -92,22 +124,29 @@ else() | |
| --arch "${GPU_ARCHS_STR}" | ||
| RESULT_VARIABLE QOLA_BUILD_RESULT | ||
| ) | ||
| if(NOT QOLA_BUILD_RESULT EQUAL 0) | ||
| message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.") | ||
| endif() | ||
| endif() | ||
| if(NOT QOLA_BUILD_RESULT EQUAL 0) | ||
| message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.") | ||
| endif() | ||
|
|
||
| if(__USE_CK_JIT) | ||
| set(__AITER_MHA_PATH ${AITER_MHA_INSTALL_DIR}) | ||
| set(__QOLA_INCLUDE_DIR "${__QOLA_BUILD_DIR}/include") | ||
| else() | ||
| # Copy the final .so libs and exported public headers into the aiter | ||
| # prebuilt cache so downstream consumers see a self-contained tree. | ||
| get_default_aiter_cache_dir(__QOLA_CACHE_DIR) | ||
| set(__QOLA_CACHE_LIB "${__QOLA_CACHE_DIR}/lib") | ||
| get_default_aiter_cache_dir(__AITER_CACHE_DIR) | ||
| set(__QOLA_CACHE_LIB "${__AITER_CACHE_DIR}/lib") | ||
| file(MAKE_DIRECTORY ${__QOLA_CACHE_LIB}) | ||
| file(GLOB __QOLA_BUILT_LIBS "${__QOLA_BUILD_DIR}/lib/*.so") | ||
| file(COPY ${__QOLA_BUILT_LIBS} DESTINATION ${__QOLA_CACHE_LIB}) | ||
| file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__QOLA_CACHE_DIR}") | ||
| file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__AITER_CACHE_DIR}") | ||
| set(__AITER_MHA_PATH "${__QOLA_CACHE_LIB}") | ||
| else() | ||
| message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}") | ||
| set(__QOLA_INCLUDE_DIR "${__AITER_CACHE_DIR}/include") | ||
| endif() | ||
| else() | ||
| set(__AITER_MHA_PATH "${__AITER_CACHE_DIR}/lib") | ||
| set(__QOLA_INCLUDE_DIR "${__AITER_CACHE_DIR}/include") | ||
| endif() | ||
|
|
||
| set(ck_fused_attn_SOURCES) | ||
|
|
@@ -129,7 +168,6 @@ list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS | |
| # Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include | ||
| # (emitted by qola.cli build, or copied from the QoLA build dir above for the | ||
| # source-build path). | ||
| set(__QOLA_INCLUDE_DIR "${__AITER_MHA_PATH}/../include") | ||
| if(NOT EXISTS "${__QOLA_INCLUDE_DIR}/qola_config.h") | ||
| message(FATAL_ERROR "Could not find QoLA public headers at ${__QOLA_INCLUDE_DIR}.") | ||
| endif() | ||
|
|
@@ -146,5 +184,7 @@ target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS}) | |
| target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS}) | ||
| set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN") | ||
|
|
||
| install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) | ||
| install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) | ||
| if (NOT "${__AITER_MHA_PATH}" STREQUAL "${AITER_MHA_INSTALL_DIR}") | ||
| install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${AITER_MHA_INSTALL_DIR}) | ||
| endif() | ||
| install(TARGETS ck_fused_attn DESTINATION ${AITER_MHA_INSTALL_DIR}) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
ck_jitsubmodule points at a personal repo (github.com/ipanfilo/ck_jit.git). With CK_JIT being enabled by default (seetransformer_engine/common/ck_fused_attn/CMakeLists.txt:73), this becomes a hard dependency onipanfilo/ck_jitfor every default ROCm build — including wheels published fromrocm-wheels-build.yml. Recommend moving the repo under theROCmorg (or another organizational account) before this lands, so availability/ownership isn't tied to a single contributor account. The existing personal-repo submodules in this file (HaiShaw/minGPT,floraamd/nanoGPTwTE) are example-only and don't gate the build.