Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/build_kernel_xpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ jobs:
# For now we only test that there are no regressions in building XPU
# kernels. Also run tests once we have a XPU runner.
- name: Build relu kernel
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-xpu20253-x86_64-linux -L )
run: ( cd examples/kernels/relu && nix build .\#redistributable.torch211-cxx11-xpu20253-x86_64-linux -L )

- name: Build relu tvm-ffi kernel
run: ( cd examples/kernels/relu-tvm-ffi && nix build .\#redistributable.tvm-ffi01-xpu20253-x86_64-linux -L )

- name: Build relu kernel (compiler flags)
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-xpu20253-x86_64-linux )
run: ( cd examples/kernels/relu-compiler-flags && nix build .\#redistributable.torch211-cxx11-xpu20253-x86_64-linux )

- name: Build cutlass-gemm kernel
run: ( cd examples/kernels/cutlass-gemm && nix build .\#redistributable.torch211-xpu20253-x86_64-linux -L )
run: ( cd examples/kernels/cutlass-gemm && nix build .\#redistributable.torch211-cxx11-xpu20253-x86_64-linux -L )
31 changes: 20 additions & 11 deletions examples/kernels/flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
{
name = "cpp20-symbols-kernel";
path = ./cpp20-symbols;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cpu-${sys}"};
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-cpu-${sys}"};
}
{
name = "relu-kernel";
path = ./relu;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
}
{
name = "relu-torch-stable-abi-kernel";
Expand All @@ -53,17 +54,19 @@
{
name = "extra-data";
path = ./extra-data;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
}
{
name = "relu-kernel-cpu";
path = ./relu;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cpu-${sys}"};
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-cpu-${sys}"};
}
{
name = "cutlass-gemm-kernel";
path = ./cutlass-gemm;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
}
{
name = "cutlass-gemm-tvm-ffi-kernel";
Expand All @@ -74,7 +77,8 @@
{
name = "relu-backprop-compile-kernel";
path = ./relu-backprop-compile;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
}
{
name = "silu-and-mul-kernel";
Expand All @@ -101,12 +105,14 @@
{
name = "relu-compiler-flags";
path = ./relu-compiler-flags;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
}
{
name = "relu-invalid-capability";
path = ./relu-invalid-capability;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${cudaVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${cudaVersion}-${sys}"};
assertFail = true;
assertFailLogs = [ "empty set of capabilities" ];
}
Expand Down Expand Up @@ -142,19 +148,22 @@
{
name = "relu-invalid-capability";
path = ./relu-invalid-capability;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${rocmVersion}-${sys}"};
assertFail = true;
assertFailLogs = [ "empty set of architectures" ];
}
{
name = "relu-kernel";
path = ./relu;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${rocmVersion}-${sys}"};
}
{
name = "relu-compiler-flags";
path = ./relu-compiler-flags;
drv = sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-${rocmVersion}-${sys}"};
drv =
sys: out: out.packages.${sys}.redistributable.${"torch${torchVersion}-cxx11-${rocmVersion}-${sys}"};
}
];

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Generate a standardized build variant name following the pattern:
# torch<VERSION>-<COMPUTE>-<ARCH>-<OS>
# torch<VERSION>-[cxx11-]<COMPUTE>-<ARCH>-<OS>
# or, when compiled against the Torch stable ABI:
# torch-stable-abi<VERSION>-<COMPUTE>-<ARCH>-<OS>
#
Expand All @@ -13,7 +13,7 @@
# TORCH_STABLE_ABI - Stable ABI version the extension was compiled against (e.g., "2.11");
# when set, TORCH_VERSION is ignored and the prefix becomes
# torch-stable-abi<VERSION> (e.g., "2.11" -> "torch-stable-abi211")
# Example output: torch27-cu124-x86_64-linux (Linux)
# Example output: torch27-cxx11-cu124-x86_64-linux (Linux)
# torch27-cu124-x86_64-windows (Windows)
# torch27-metal-aarch64-darwin (macOS)
# torch-stable-abi211-cu124-x86_64-linux (Linux, stable ABI)
Expand Down Expand Up @@ -117,7 +117,12 @@ function(generate_build_name OUT_BUILD_NAME TORCH_VERSION COMPUTE_FRAMEWORK COMP
set(ARCH_OS_STRING "${CPU_ARCH}-${OS_NAME}")

# Assemble the final build name
set(BUILD_NAME "${TORCH_PREFIX}-${COMPUTE_STRING}-${ARCH_OS_STRING}")
# For non-stable-ABI Linux builds, include cxx11 ABI indicator for compatibility
if(NOT ARG_TORCH_STABLE_ABI AND ARCH_OS_STRING MATCHES "-linux$")
set(BUILD_NAME "${TORCH_PREFIX}-cxx11-${COMPUTE_STRING}-${ARCH_OS_STRING}")
else()
set(BUILD_NAME "${TORCH_PREFIX}-${COMPUTE_STRING}-${ARCH_OS_STRING}")
endif()

set(${OUT_BUILD_NAME} "${BUILD_NAME}" PARENT_SCOPE)
message(STATUS "Generated build name: ${BUILD_NAME}")
Expand Down
9 changes: 7 additions & 2 deletions nix-builder/lib/variants/torch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@ let
"torch${flattenVersion (lib.versions.majorMinor buildConfig.torchVersion)}"
else
"torch-stable-abi${flattenVersion (lib.versions.majorMinor stableAbi)}";
archString =
if buildConfig.system == "aarch64-darwin" then
"${torchString null}-${computeString}-${buildConfig.system}"
else
"${torchString null}-cxx11-${computeString}-${buildConfig.system}";
in
{
arch = "${torchString null}-${computeString}-${buildConfig.system}";
arch = archString;
noarch = "torch-${buildConfig.backend}";

kernelVariant =
Expand All @@ -36,7 +41,7 @@ in
if archVariant && kernelConfig.isTorchStableAbi then
"torch-stable-abi${flattenVersion (lib.versions.majorMinor kernelConfig.torchStableAbiVersion)}-${computeString}-${buildConfig.system}"
else if archVariant then
"torch${flattenVersion (lib.versions.majorMinor buildConfig.torchVersion)}-${computeString}-${buildConfig.system}"
archString
else
"torch-${buildConfig.backend}";

Expand Down
Loading