-
Notifications
You must be signed in to change notification settings - Fork 248
Expand file tree
/
Copy pathsetup.py
More file actions
141 lines (124 loc) · 4.74 KB
/
setup.py
File metadata and controls
141 lines (124 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import os
import re

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import (
    CUDA_HOME,
    BuildExtension,
    CppExtension,
    CUDAExtension,
)
# Conditional import for SyclExtension
try:
from torch.utils.cpp_extension import SyclExtension
except ImportError:
SyclExtension = None
library_name = "extension_cpp"

# NOTE: PyTorch versions < 2.6 use torch.extension.h which depends on pybind11,
# and pybind11 requires full access to Python's C API (including internal
# structures like PyObject). This makes it incompatible with Py_LIMITED_API
# which restricts access to only stable Python C API symbols.
# For Py_LIMITED_API compatibility, use torch.library.h instead (PyTorch 2.6+).
#
# Compare numeric (major, minor) components rather than raw strings: a plain
# string comparison (`torch.__version__ >= "2.6.0"`) wrongly treats e.g.
# "2.10.0" as older than "2.6.0", and torch.__version__ may carry suffixes
# such as "2.6.0+cu124" or "2.7.0a0" that break naive parsing.
_version_match = re.match(r"(\d+)\.(\d+)", torch.__version__)
if _version_match is not None:
    py_limited_api = tuple(map(int, _version_match.groups())) >= (2, 6)
else:
    # Unparseable version string: fall back to the non-limited-API build.
    py_limited_api = False
def get_extensions():
    """Build the list of setuptools extension modules for this package.

    Selects a backend (CUDA, SYCL, or plain C++) from the ``USE_CUDA`` and
    ``USE_SYCL`` environment variables (each defaults to ``"auto"``, meaning
    auto-detect), collects the matching sources under
    ``<library_name>/csrc``, and returns a single extension module named
    ``<library_name>._C``.

    Environment variables:
        DEBUG: "1" compiles with -O0 -g instead of -O3.
        USE_CUDA / USE_SYCL: "auto" (detect), "true"/"1" (force on),
            anything else forces the backend off.

    Returns:
        list: one configured CppExtension/CUDAExtension/SyclExtension.

    Raises:
        RuntimeError: if both CUDA and SYCL backends are requested.
    """
    debug_mode = os.getenv("DEBUG", "0") == "1"

    # Determine backend (CUDA, SYCL, or C++)
    use_cuda = os.getenv("USE_CUDA", "auto")
    use_sycl = os.getenv("USE_SYCL", "auto")

    # Auto-detect CUDA: need both a usable runtime and a CUDA toolkit.
    if use_cuda == "auto":
        use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
    else:
        use_cuda = use_cuda.lower() == "true" or use_cuda == "1"

    # Auto-detect SYCL. Guard the torch.xpu lookup with hasattr: older torch
    # builds may not expose the xpu module at all, which would raise
    # AttributeError here instead of falling back to the C++ backend.
    if use_sycl == "auto":
        use_sycl = (
            SyclExtension is not None
            and hasattr(torch, "xpu")
            and torch.xpu.is_available()
        )
    else:
        use_sycl = use_sycl.lower() == "true" or use_sycl == "1"

    if use_cuda and use_sycl:
        raise RuntimeError("Cannot enable both CUDA and SYCL backends simultaneously.")
    print("use cuda & use sycl",use_cuda, use_sycl)

    if use_cuda:
        extension = CUDAExtension
        print("Building with CUDA backend")
    elif use_sycl and SyclExtension is not None:
        extension = SyclExtension
        print("Building with SYCL backend")
    else:
        extension = CppExtension
        print("Building with C++ backend")

    # Compilation arguments.  Py_LIMITED_API=0x03090000 pins the stable ABI
    # to CPython 3.9 — this must stay in sync with the "cp39" bdist_wheel tag
    # used by setup() below.
    extra_link_args = []
    extra_compile_args = {"cxx": []}
    if extension == CUDAExtension:
        print("CUDA is available, compile using CUDAExtension")
        extra_compile_args = {
            "cxx": ["-O3" if not debug_mode else "-O0",
                    "-fdiagnostics-color=always",
                    "-DPy_LIMITED_API=0x03090000"],
            "nvcc": ["-O3" if not debug_mode else "-O0"]
        }
    elif extension == SyclExtension:
        print("XPU is available, compile using SyclExtension")
        extra_compile_args = {
            "cxx": ["-O3" if not debug_mode else "-O0",
                    "-fdiagnostics-color=always",
                    "-DPy_LIMITED_API=0x03090000"],
            "sycl": ["-O3" if not debug_mode else "-O0"]
        }
    else:
        extra_compile_args["cxx"] = [
            "-O3" if not debug_mode else "-O0",
            "-DPy_LIMITED_API=0x03090000"]

    if debug_mode:
        extra_compile_args["cxx"].append("-g")
        if extension == CUDAExtension:
            extra_compile_args["nvcc"].append("-g")
            extra_link_args.extend(["-O0", "-g"])
        elif extension == SyclExtension:
            extra_compile_args["sycl"].append("-g")
            extra_link_args.extend(["-O0", "-g"])

    # Source files collection.
    # NOTE: the previous code used os.path.dirname(os.path.curdir), which is
    # always "" — globs then resolved against the current working directory
    # and found nothing when pip built from a different CWD.  Anchor to the
    # directory containing this setup.py instead.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, library_name, "csrc")
    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))

    backend_sources = []
    if extension == CUDAExtension:
        backend_dir = os.path.join(extensions_dir, "cuda")
        backend_sources = glob.glob(os.path.join(backend_dir, "*.cu"))
    elif extension == SyclExtension:
        backend_dir = os.path.join(extensions_dir, "sycl")
        backend_sources = glob.glob(os.path.join(backend_dir, "*.sycl"))
    sources += backend_sources

    # Construct extension
    ext_modules = [
        extension(
            f"{library_name}._C",
            sources,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
            py_limited_api=py_limited_api,
        )
    ]
    return ext_modules
# Read the long description up front so the file handle is closed promptly
# (a bare open(...).read() inside the setup() call leaks the handle until GC
# and emits a ResourceWarning), with an explicit encoding.  Resolve the path
# relative to this file so the build works from any working directory.
with open(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
    encoding="utf-8",
) as _readme:
    _long_description = _readme.read()

setup(
    name=library_name,
    version="0.0.1",
    packages=find_packages(),
    ext_modules=get_extensions(),
    install_requires=["torch"],
    description="Example of PyTorch C++ and CUDA/Sycl extensions",
    long_description=_long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/pytorch/extension-cpp",
    cmdclass={"build_ext": BuildExtension},
    # "cp39" must match the -DPy_LIMITED_API=0x03090000 compile flag used in
    # get_extensions(); only tag the wheel as stable-ABI when the extension
    # was actually built with the limited API.
    options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
)