|
1 | | -"""Custom build: compile CPU kernels (make) during pip install.""" |
| 1 | +"""Custom build: compile CPU and CUDA kernels during pip install.""" |
2 | 2 |
|
3 | 3 | import subprocess |
| 4 | +import sys |
4 | 5 | import os |
5 | 6 | from setuptools import setup |
6 | 7 | from setuptools.command.build_py import build_py |
7 | 8 | from setuptools.command.develop import develop |
8 | 9 |
|
# Absolute path of the directory containing this setup.py; all kernel
# directories are resolved relative to it.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Directories with Makefile-driven CPU kernel builds (one per bit-width).
CPU_KERNEL_DIRS = [
    os.path.join(ROOT, "kernels", "bit_1", "cpu"),
    os.path.join(ROOT, "kernels", "bit_1_58", "cpu"),
]

# Directories holding CUDA kernel sources: compiled via torch JIT and,
# for the BitNet kernel, directly with nvcc (see _build_cuda_kernels).
CUDA_KERNEL_DIR_BIT1 = os.path.join(ROOT, "kernels", "bit_1", "cuda")
CUDA_KERNEL_DIR_BIT158 = os.path.join(ROOT, "kernels", "bit_1_58", "cuda")
|
def _build_cpu_kernels():
    """Run `make` in every CPU kernel directory that ships a Makefile."""
    for kernel_dir in CPU_KERNEL_DIRS:
        makefile = os.path.join(kernel_dir, "Makefile")
        # Skip directories that are absent or have no build recipe.
        if not (os.path.isdir(kernel_dir) and os.path.isfile(makefile)):
            continue
        subprocess.check_call(["make", "-C", kernel_dir])
|
22 | 26 |
|
| 27 | +def _print_cuda_skip_warning(): |
| 28 | + """Print a warning that CUDA kernels were not pre-built.""" |
| 29 | + BOLD_RED = "\033[1;31m" |
| 30 | + RESET = "\033[0m" |
| 31 | + YELLOW = "\033[33m" |
| 32 | + print() |
| 33 | + print(f"{YELLOW}setup.py: CUDA not available — CUDA kernels were not pre-built.{RESET}") |
| 34 | + print(f"{YELLOW} They will be JIT-compiled on the first CUDA run, if available.{RESET}") |
| 35 | + print() |
| 36 | + print(f" {BOLD_RED}FOR BENCHMARKS PAY ATTENTION TO FIRST BUILD TIME{RESET}") |
| 37 | + print() |
| 38 | + |
| 39 | + |
def _jit_compile_kernels(load, kernel_dir, kernels):
    """Best-effort torch-JIT compile of ``(extension_name, source_file)`` pairs.

    Missing sources are skipped silently; compile failures are reported as
    warnings and never abort the install (the kernel simply falls back to
    JIT compilation on first use).
    """
    for name, source in kernels:
        source_path = os.path.join(kernel_dir, source)
        if not os.path.isfile(source_path):
            continue
        print(f"setup.py: JIT compiling {name} ...")
        try:
            load(
                name=name,
                sources=[source_path],
                extra_cuda_cflags=["-O3", "--use_fast_math"],
                verbose=False,
            )
        except Exception as e:  # best-effort: pip install must not fail here
            print(f"setup.py: WARNING: failed to compile {name}: {e}")


def _build_cuda_kernels():
    """JIT-compile all CUDA kernels so first run has zero compilation delay.

    Skips (with a warning) when torch is not importable or no CUDA device
    is available; individual kernel failures are logged but do not abort
    the install.
    """
    try:
        import torch
        if not torch.cuda.is_available():
            _print_cuda_skip_warning()
            return
    except ImportError:
        _print_cuda_skip_warning()
        return

    from torch.utils.cpp_extension import load

    # Build only for the installed GPU's architecture to keep compile time low.
    major, minor = torch.cuda.get_device_capability()
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"

    # Ensure ninja (installed in the interpreter's bin dir) is on PATH,
    # since torch's JIT extension builder shells out to it.
    bindir = os.path.dirname(sys.executable)
    path_entries = os.environ.get("PATH", "").split(os.pathsep)
    if bindir and bindir not in path_entries:
        os.environ["PATH"] = os.pathsep.join([bindir, *path_entries])

    # -- bit_1 CUDA kernels (torch JIT) --
    _jit_compile_kernels(load, CUDA_KERNEL_DIR_BIT1, [
        ("rsr_cuda_v5_9", "rsr_v5_9.cu"),
        ("rsr_cuda_v5_8", "rsr_v5_8.cu"),
        ("rsr_cuda_v5_6", "rsr_v5_6.cu"),
        ("rsr_cuda_v4_10", "rsr_v4_10.cu"),
    ])

    # -- bit_1_58 CUDA kernels (torch JIT) --
    _jit_compile_kernels(load, CUDA_KERNEL_DIR_BIT158, [
        ("rsr_ternary_cuda_v2_0", "rsr_ternary_v2_0.cu"),
    ])

    # -- bit_1_58 BitNet kernel (nvcc direct, produces libbitnet.so) --
    bitnet_source = os.path.join(CUDA_KERNEL_DIR_BIT158, "bitnet_kernels.cu")
    bitnet_lib = os.path.join(CUDA_KERNEL_DIR_BIT158, "libbitnet.so")
    if os.path.isfile(bitnet_source) and not os.path.isfile(bitnet_lib):
        cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")
        nvcc = os.path.join(cuda_home, "bin", "nvcc")
        if os.path.isfile(nvcc):
            arch = f"{major}{minor}"
            cmd = [
                nvcc, "-std=c++17", "--shared", "--compiler-options", "-fPIC",
                "-O3", "--use_fast_math", "-lineinfo",
                f"-gencode=arch=compute_{arch},code=sm_{arch}",
                f"-gencode=arch=compute_{arch},code=compute_{arch}",
                bitnet_source, "-o", bitnet_lib,
            ]
            print("setup.py: compiling libbitnet.so ...")
            try:
                subprocess.run(cmd, cwd=CUDA_KERNEL_DIR_BIT158, check=True,
                               capture_output=True, text=True)
            except subprocess.CalledProcessError as e:
                # capture_output=True hides nvcc's own output; surface its
                # stderr so a failed build is diagnosable from the log.
                print(f"setup.py: WARNING: failed to compile libbitnet.so: {e}\n{e.stderr}")
            except OSError as e:
                print(f"setup.py: WARNING: failed to compile libbitnet.so: {e}")
| 124 | + |
| 125 | + |
def _build_all_kernels():
    """Build everything: CPU (make) kernels, then pre-built CUDA kernels."""
    for build_step in (_build_cpu_kernels, _build_cuda_kernels):
        build_step()
| 130 | + |
class BuildPyWithKernels(build_py):
    """``build_py`` variant that compiles native kernels before the Python build."""

    def run(self):
        # Kernels must exist before package files are collected.
        _build_all_kernels()
        build_py.run(self)
27 | 135 |
|
28 | 136 |
|
class DevelopWithKernels(develop):
    """``develop`` (editable-install) variant that compiles native kernels first."""

    def run(self):
        # Mirror BuildPyWithKernels: build kernels, then the normal command.
        _build_all_kernels()
        develop.run(self)
33 | 141 |
|
34 | 142 |
|
|
0 commit comments