Skip to content

Commit c0b01aa

Browse files
NicolasHug authored and pytorchmergebot committed
Add deleter support to torch::stable::from_blob (pytorch#173371)
This PR adds `deleter` support to `torch::stable::from_blob` by adding a new `aoti_torch_create_tensor_from_blob_v3`. We need it to cleanly port TorchCodec to the stable ABI in meta-pytorch/torchcodec#1188 There's a bit of scaffolding, especially for the tests where I had to create a new `test/cpp_extensions/libtorch_agn_2_11_extension` folder. Most of it is just copy/paste from `test/cpp_extensions/libtorch_agn_2_11_extension/libtorch_agn_2_10/__init__.py` Pull Request resolved: pytorch#173371 Approved by: https://github.com/janeyx99
1 parent 280457c commit c0b01aa

10 files changed

Lines changed: 465 additions & 3 deletions

File tree

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>

#include <array>
#include <cstdint>
#include <stdexcept>

#ifdef LAE_USE_CUDA
#include <cuda_runtime.h>
#endif
10+
using torch::stable::Tensor;
11+
12+
// Counter recording how many times the test deleter has fired. It is
// exposed to the Python test harness through get_deleter_call_count()
// and cleared through reset_deleter_call_count().
static int64_t g_deleter_call_count = 0;

// Deleter handed to from_blob in tests: it ignores the data pointer and
// only bumps the counter so callers can assert that it ran.
static void test_deleter(void* /*data*/) {
  ++g_deleter_call_count;
}
18+
19+
// Wrapper for from_blob with deleter - uses a test deleter that increments
20+
// a global counter
21+
Tensor my_from_blob_with_deleter(
22+
int64_t data_ptr,
23+
torch::headeronly::HeaderOnlyArrayRef<int64_t> sizes,
24+
torch::headeronly::HeaderOnlyArrayRef<int64_t> strides,
25+
torch::stable::Device device,
26+
torch::headeronly::ScalarType dtype) {
27+
void* data = reinterpret_cast<void*>(data_ptr);
28+
return torch::stable::from_blob(
29+
data, sizes, strides, device, dtype, test_deleter);
30+
}
31+
32+
int64_t get_deleter_call_count() {
33+
return g_deleter_call_count;
34+
}
35+
36+
// Clears the deleter-invocation counter so each test starts from zero.
void reset_deleter_call_count() {
  g_deleter_call_count = 0;
}
39+
40+
// Schema registrations for the 2.11 libtorch-agnostic test operators.
STABLE_TORCH_LIBRARY(libtorch_agn_2_11, m) {
  m.def(
      "my_from_blob_with_deleter(int data_ptr, int[] sizes, int[] strides, Device device, ScalarType dtype) -> Tensor");
  m.def("get_deleter_call_count() -> int");
  m.def("reset_deleter_call_count() -> ()");
}
46+
47+
// Kernel registrations; these ops are bound under
// CompositeExplicitAutograd.
STABLE_TORCH_LIBRARY_IMPL(libtorch_agn_2_11, CompositeExplicitAutograd, m) {
  m.impl("my_from_blob_with_deleter", TORCH_BOX(&my_from_blob_with_deleter));
  m.impl("get_deleter_call_count", TORCH_BOX(&get_deleter_call_count));
  m.impl("reset_deleter_call_count", TORCH_BOX(&reset_deleter_call_count));
}
55+
56+
#ifdef LAE_USE_CUDA
57+
58+
// Wrapper for cudaFree since it returns cudaError_t, not void
59+
static void cuda_deleter(void* data) {
60+
cudaFree(data);
61+
}
62+
63+
// Creates a tensor that owns its CUDA memory via cudaMalloc.
64+
// When the tensor is destroyed, the deleter will call cudaFree.
65+
// This tests that from_blob's deleter properly frees memory.
66+
Tensor my_from_blob_with_cuda_deleter(
67+
int64_t numel,
68+
torch::stable::Device device) {
69+
size_t size_bytes = numel * sizeof(float);
70+
71+
void* data = nullptr;
72+
cudaError_t err = cudaMalloc(&data, size_bytes);
73+
if (err != cudaSuccess) {
74+
throw std::runtime_error("cudaMalloc failed");
75+
}
76+
77+
// Zero the memory
78+
cudaMemset(data, 0, size_bytes);
79+
80+
std::array<int64_t, 1> sizes = {numel};
81+
std::array<int64_t, 1> strides = {1};
82+
83+
return torch::stable::from_blob(
84+
data,
85+
torch::headeronly::HeaderOnlyArrayRef<int64_t>(sizes.data(), sizes.size()),
86+
torch::headeronly::HeaderOnlyArrayRef<int64_t>(strides.data(), strides.size()),
87+
device,
88+
torch::headeronly::ScalarType::Float,
89+
cuda_deleter);
90+
}
91+
92+
// Schema registration for the CUDA-only test operator.
STABLE_TORCH_LIBRARY(libtorch_agn_2_11_cuda, m) {
  m.def("my_from_blob_with_cuda_deleter(int numel, Device device) -> Tensor");
}
95+
96+
// Kernel registration for the CUDA-only test operator.
STABLE_TORCH_LIBRARY_IMPL(
    libtorch_agn_2_11_cuda,
    CompositeExplicitAutograd,
    m) {
  m.impl(
      "my_from_blob_with_cuda_deleter",
      TORCH_BOX(&my_from_blob_with_cuda_deleter));
}
99+
100+
#endif // LAE_USE_CUDA
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import ctypes
import sys
from pathlib import Path

import torch


# Locate the single compiled extension module (_C*.pyd on Windows,
# _C*.so elsewhere) that the build placed next to this file.
ext_suffix = ".pyd" if sys.platform == "win32" else ".so"
so_files = list(Path(__file__).parent.glob("_C*" + ext_suffix))
assert len(so_files) == 1, f"Expected one _C*.{{so,pyd}} file, found {len(so_files)}"

# use ctypes.CDLL instead of load_library to be able to test the unload logic
# below code is reduced from the load_library code
with torch._ops.dl_open_guard():
    loaded_lib = ctypes.CDLL(str(so_files[0]))

from . import ops


__all__ = [
    "loaded_lib",
    "ops",
]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import torch
2+
from torch import Tensor
3+
4+
5+
def my_from_blob_with_deleter(data_ptr, sizes, strides, device, dtype) -> Tensor:
    """
    Creates a Tensor from existing memory with a deleter callback.

    The deleter is invoked when the tensor's storage is deallocated. For
    this test it only increments a global counter, which lets callers
    assert via get_deleter_call_count() that it actually ran.

    Args:
        data_ptr: int - pointer to the data buffer
        sizes: tuple[int] - size of the tensor
        strides: tuple[int] - strides of the tensor
        device: Device - device on which the tensor resides
        dtype: ScalarType - data type of the tensor

    Returns: Tensor - tensor wrapping the existing memory
    """
    op = torch.ops.libtorch_agn_2_11.my_from_blob_with_deleter.default
    return op(data_ptr, sizes, strides, device, dtype)
25+
26+
27+
def get_deleter_call_count() -> int:
    """
    Returns how many times the test deleter has been invoked.
    """
    return torch.ops.libtorch_agn_2_11.get_deleter_call_count.default()
32+
33+
34+
def reset_deleter_call_count() -> None:
    """
    Resets the deleter invocation counter back to zero.
    """
    torch.ops.libtorch_agn_2_11.reset_deleter_call_count.default()
39+
40+
41+
def my_from_blob_with_cuda_deleter(numel: int, device) -> Tensor:
    """
    Creates a CUDA tensor that owns its memory via cudaMalloc.

    The buffer is allocated with cudaMalloc and released with cudaFree by
    from_blob's deleter when the tensor is destroyed, which lets tests
    verify that the deleter really frees device memory.

    Args:
        numel: int - number of elements in the tensor
        device: Device - CUDA device

    Returns: Tensor - a 1D float32 tensor of zeros
    """
    op = torch.ops.libtorch_agn_2_11_cuda.my_from_blob_with_cuda_deleter.default
    return op(numel, device)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import distutils.command.clean
2+
import shutil
3+
from pathlib import Path
4+
5+
from setuptools import find_packages, setup
6+
7+
import torch
8+
from torch.utils.cpp_extension import (
9+
BuildExtension,
10+
CppExtension,
11+
CUDAExtension,
12+
IS_WINDOWS,
13+
)
14+
15+
16+
ROOT_DIR = Path(__file__).parent
17+
CSRC_DIR = ROOT_DIR / "csrc"
18+
19+
20+
class clean(distutils.command.clean.clean):
    """`setup.py clean` command that also removes built extension output."""

    def run(self):
        # Run the stock distutils clean first.
        distutils.command.clean.clean.run(self)

        # Remove compiled extension modules from the package directory.
        for so_path in (ROOT_DIR / "libtorch_agn_2_11").glob("**/*.so"):
            so_path.unlink()

        # Remove build, dist, and egg-info output directories.
        for out_dir in (
            ROOT_DIR / "build",
            ROOT_DIR / "dist",
            ROOT_DIR / "libtorch_agn_2_11.egg-info",
        ):
            if out_dir.exists():
                shutil.rmtree(str(out_dir), ignore_errors=True)
37+
38+
39+
def get_extension():
    """Builds the single `_C` extension module for this test package.

    Uses CppExtension by default and switches to CUDAExtension (adding
    the LAE_USE_CUDA define and any .cu sources) when CUDA is available.
    """
    cxx_flags = [
        "-DTORCH_TARGET_VERSION=0x020b000000000000",
    ]
    if not IS_WINDOWS:
        cxx_flags.append("-fdiagnostics-color=always")

    extra_compile_args = {"cxx": cxx_flags}
    sources = list(CSRC_DIR.glob("**/*.cpp"))
    extension = CppExtension

    if torch.cuda.is_available():
        # allow including <cuda_runtime.h>
        cxx_flags.append("-DLAE_USE_CUDA")
        # NOTE(review): nvcc receives -DUSE_CUDA while cxx receives
        # -DLAE_USE_CUDA — confirm the asymmetry is intended for any
        # future .cu sources.
        extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"]
        extension = CUDAExtension
        sources.extend(CSRC_DIR.glob("**/*.cu"))

    return [
        extension(
            "libtorch_agn_2_11._C",
            sources=sorted(str(s) for s in sources),
            py_limited_api=True,
            extra_compile_args=extra_compile_args,
            extra_link_args=[],
        )
    ]
67+
68+
69+
# Package metadata and build wiring for the 2.11 libtorch-agnostic
# extension exercised by test_libtorch_agnostic.py.
setup(
    name="libtorch_agn_2_11",
    version="0.0",
    author="PyTorch Core Team",
    description="Example of libtorch agnostic extension for PyTorch 2.11+",
    packages=find_packages(exclude=("test",)),
    package_data={"libtorch_agn_2_11": ["*.dll", "*.dylib", "*.so"]},
    install_requires=["torch"],
    ext_modules=get_extension(),
    cmdclass={
        # no_python_abi_suffix keeps the module name a plain _C.so
        "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
        "clean": clean,
    },
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

test/cpp_extensions/test_libtorch_agnostic.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["module: cpp"]
22

3+
import gc
34
import math
45
import sysconfig
56
import unittest
@@ -80,18 +81,19 @@ class TestLibtorchAgnostic(TestCase):
8081
"""
8182
Tests for versioned libtorch_agnostic extensions.
8283
83-
This test class supports testing both:
84+
This test class supports testing:
8485
8586
- libtorch_agn_2_9: Extension built with TORCH_TARGET_VERSION=2.9.0
8687
- libtorch_agn_2_10: Extension built with TORCH_TARGET_VERSION=2.10.0
88+
- libtorch_agn_2_11: Extension built with TORCH_TARGET_VERSION=2.11.0
8789
8890
Tests should be decorated with @skipIfTorchVersionLessThan to indicate the
8991
version that they target.
9092
"""
9193

9294
@classmethod
9395
def setUpClass(cls):
94-
# Build both 2.9 and 2.10 extensions
96+
# Build versioned extensions
9597
base_dir = Path(__file__).parent
9698

9799
try:
@@ -101,7 +103,7 @@ def setUpClass(cls):
101103
extension_root=base_dir / "libtorch_agn_2_9_extension"
102104
)
103105

104-
# Only build 2.10 extension if running on PyTorch 2.10+
106+
# Only build 2.X extension if running on PyTorch 2.X+
105107
import re
106108

107109
version_parts = torch.__version__.split(".")
@@ -119,6 +121,16 @@ def setUpClass(cls):
119121
else:
120122
print(f"Skipping 2.10 extension (running on PyTorch {torch.__version__})")
121123

124+
if (current_major > 2) or (current_major == 2 and current_minor >= 11):
125+
try:
126+
import libtorch_agn_2_11 # noqa: F401
127+
except Exception:
128+
install_cpp_extension(
129+
extension_root=base_dir / "libtorch_agn_2_11_extension"
130+
)
131+
else:
132+
print(f"Skipping 2.11 extension (running on PyTorch {torch.__version__})")
133+
122134
@onlyCPU
123135
def test_slow_sgd(self, device):
124136
import libtorch_agn_2_9 as libtorch_agnostic
@@ -1660,6 +1672,60 @@ def test_my_subtract(self, device):
16601672
expected_broadcast = torch.subtract(a, c)
16611673
self.assertEqual(result_broadcast, expected_broadcast)
16621674

1675+
@skipIfTorchVersionLessThan(2, 11)
@skipIfTorchDynamo("no data pointer defined for FakeTensor, FunctionalTensor")
def test_my_from_blob_with_deleter(self, device):
    """Test for from_blob with custom deleter (2.11 feature)."""
    import libtorch_agn_2_11 as libtorch_agnostic

    libtorch_agnostic.ops.reset_deleter_call_count()
    self.assertEqual(libtorch_agnostic.ops.get_deleter_call_count(), 0)

    # Build the blob tensor on top of an existing tensor's storage.
    original = torch.rand(2, 3, device=device, dtype=torch.float32)
    blob_tensor = libtorch_agnostic.ops.my_from_blob_with_deleter(
        original.data_ptr(),
        original.size(),
        original.stride(),
        device,
        torch.float32,
    )

    # Same values and the very same underlying memory.
    self.assertEqual(blob_tensor, original)
    self.assertEqual(blob_tensor.data_ptr(), original.data_ptr())

    # The deleter must not fire while the blob tensor is alive.
    self.assertEqual(libtorch_agnostic.ops.get_deleter_call_count(), 0)

    del blob_tensor
    gc.collect()

    # Dropping the blob tensor triggers the deleter exactly once; the
    # original tensor still owns its memory and remains usable.
    self.assertEqual(libtorch_agnostic.ops.get_deleter_call_count(), 1)
    original += 1
1706+
1707+
@onlyCUDA
@skipIfTorchVersionLessThan(2, 11)
def test_my_from_blob_with_cuda_deleter_no_leak(self, device):
    """Test that from_blob's deleter properly frees cudaMalloc'd memory."""
    import libtorch_agn_2_11 as libtorch_agnostic

    # The extension allocates with raw cudaMalloc, which bypasses
    # PyTorch's caching allocator, so torch.cuda.memory_allocated()
    # cannot observe it (it only tracks tensors PyTorch allocates).
    # Use driver-level free memory (mem_get_info) to detect leaks.
    torch.cuda.synchronize(device)
    init_free, _ = torch.cuda.mem_get_info(device)
    numel = 1024 * 1024  # 4 MB per tensor

    for _ in range(10):
        tensor = libtorch_agnostic.ops.my_from_blob_with_cuda_deleter(numel, device)
        # Verify tensor was created correctly
        self.assertEqual(tensor.numel(), numel)
        self.assertEqual(tensor.device, torch.device(device))
        del tensor
        gc.collect()
        torch.cuda.synchronize(device)

    curr_free, _ = torch.cuda.mem_get_info(device)
    # If any of the ten 4 MB buffers had leaked, free device memory
    # would have dropped by at least one allocation's worth.
    self.assertGreaterEqual(curr_free, init_free - numel * 4)
1728+
16631729

16641730
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
16651731

0 commit comments

Comments
 (0)