|
| 1 | +--- a/setup.py |
| 2 | ++++ b/setup.py |
| 3 | +@@ -5,13 +5,17 @@ |
| 4 | + import os |
| 5 | + import shutil |
| 6 | + from os import path |
| 7 | ++import sys |
| 8 | + from setuptools import find_packages, setup |
| 9 | + from typing import List |
| 10 | +-import torch |
| 11 | +-from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension |
| 12 | + |
| 13 | +-torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] |
| 14 | +-assert torch_ver >= [1, 8], "Requires PyTorch >= 1.8" |
| 15 | ++ |
| 16 | ++def get_torch(): |
| 17 | ++ import torch |
| 18 | ++ |
| 19 | ++ torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] |
| 20 | ++ assert torch_ver >= [1, 8], "Requires PyTorch >= 1.8" |
| 21 | ++ return torch, torch_ver |
| 22 | + |
| 23 | + |
| 24 | + def get_version(): |
| 25 | +@@ -44,6 +48,8 @@ |
| 26 | + main_source = path.join(extensions_dir, "vision.cpp") |
| 27 | + sources = glob.glob(path.join(extensions_dir, "**", "*.cpp")) |
| 28 | + |
| 29 | ++ torch, torch_ver = get_torch() |
| 30 | ++ from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension |
| 31 | + from torch.utils.cpp_extension import ROCM_HOME |
| 32 | + |
| 33 | + is_rocm_pytorch = ( |
| 34 | +@@ -144,6 +150,24 @@ |
| 35 | + "detectron2.projects.panoptic_deeplab": "projects/Panoptic-DeepLab/panoptic_deeplab", |
| 36 | + } |
| 37 | + |
| 38 | ++BUILD_COMMANDS = { |
| 39 | ++ "build", |
| 40 | ++ "build_ext", |
| 41 | ++ "bdist_wheel", |
| 42 | ++ "editable_wheel", |
| 43 | ++ "develop", |
| 44 | ++ "install", |
| 45 | ++} |
| 46 | ++SHOULD_BUILD_EXTENSIONS = any(cmd in BUILD_COMMANDS for cmd in sys.argv[1:]) |
| 47 | ++ |
| 48 | ++if SHOULD_BUILD_EXTENSIONS: |
| 49 | ++ torch, _ = get_torch() |
| 50 | ++ ext_modules = get_extensions() |
| 51 | ++ cmdclass = {"build_ext": torch.utils.cpp_extension.BuildExtension} |
| 52 | ++else: |
| 53 | ++ ext_modules = [] |
| 54 | ++ cmdclass = {} |
| 55 | ++ |
| 56 | + setup( |
| 57 | + name="detectron2", |
| 58 | + version=get_version(), |
| 59 | +@@ -203,6 +227,6 @@ |
| 60 | + "flake8-comprehensions", |
| 61 | + ], |
| 62 | + }, |
| 63 | +- ext_modules=get_extensions(), |
| 64 | +- cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, |
| 65 | ++ ext_modules=ext_modules, |
| 66 | ++ cmdclass=cmdclass, |
| 67 | + ) |
0 commit comments