Commit 4f4688b

Beef up non-stable description

1 parent a8c5294 commit 4f4688b

1 file changed: advanced_source/cpp_custom_ops.rst

Lines changed: 103 additions & 25 deletions
@@ -37,7 +37,16 @@ the operation are as follows:
     return a * b + c
 
 You can find the end-to-end working example for this tutorial
-`here <https://github.com/pytorch/extension-cpp>`_ .
+in the `extension-cpp <https://github.com/pytorch/extension-cpp>`_ repository,
+which contains two parallel implementations:
+
+- `extension_cpp_stable/ <https://github.com/pytorch/extension-cpp/tree/main/extension_cpp_stable>`_:
+  Uses APIs supported by the LibTorch Stable ABI (recommended for PyTorch 2.10+). The main body of this
+  tutorial uses code snippets from this implementation.
+- `extension_cpp/ <https://github.com/pytorch/extension-cpp/tree/main/extension_cpp>`_:
+  Uses the standard ATen/LibTorch API. Use this if you need APIs not yet available in the
+  stable ABI. Code snippets from this implementation are shown in the
+  :ref:`reverting-to-non-stable-api` section.
 
 Setting up the Build System
 ---------------------------
@@ -162,12 +171,12 @@ LibTorch Stable ABI (PyTorch Agnosticism)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 In addition to CPython agnosticism, there is a second axis of wheel compatibility:
-**LibTorch agnosticism**. While CPython agnosticism allows building a single wheel
+LibTorch agnosticism. While CPython agnosticism allows building a single wheel
 that works across multiple Python versions (3.9, 3.10, 3.11, etc.), LibTorch agnosticism
 allows building a single wheel that works across multiple PyTorch versions (2.10, 2.11, 2.12, etc.).
 These two concepts are orthogonal and can be combined.
 
-To achieve LibTorch agnosticism, you must use the **LibTorch Stable ABI**, which provides
+To achieve LibTorch agnosticism, you must use the LibTorch Stable ABI, which provides
 a stable C API for interacting with PyTorch tensors and operators. For example, instead of
 using ``at::Tensor``, you must use ``torch::stable::Tensor``. For comprehensive
 documentation on the stable ABI, including migration guides, supported types, and
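Since the two axes are orthogonal, a single extension build can opt into both. Below is a minimal illustrative sketch of how the two compile-time flags combine (the flag values appear elsewhere in the tutorial; the surrounding ``setup.py`` boilerplate is omitted, so this fragment alone is not a complete build script):

```python
# Illustrative sketch only: the two orthogonal compatibility flags discussed above.
extra_compile_args = {
    "cxx": [
        "-DPy_LIMITED_API=0x03090000",                # CPython agnosticism: min CPython 3.9
        "-DTORCH_TARGET_VERSION=0x020a000000000000",  # LibTorch agnosticism: min PyTorch 2.10
    ],
}
```

Dropping either flag gives up the corresponding axis of compatibility: without ``Py_LIMITED_API`` you need one wheel per Python version, and without ``TORCH_TARGET_VERSION`` you need one wheel per PyTorch version.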
@@ -178,7 +187,10 @@ The setup.py above already includes ``TORCH_TARGET_VERSION=0x020a000000000000``,
 the extension targets the LibTorch Stable ABI with a minimum supported PyTorch version of 2.10. The version format is:
 ``[MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes]``, so 2.10.0 = ``0x020a000000000000``.
 
-See the section below for examples of code using the LibTorch Stable ABI.
+The sections below contain examples of code using the LibTorch Stable ABI.
+If the stable API/ABI does not contain what you need, see the :ref:`reverting-to-non-stable-api` section
+or the `extension_cpp/ subdirectory <https://github.com/pytorch/extension-cpp/tree/main/extension_cpp>`_
+in the extension-cpp repository for the equivalent examples using the non-stable API.
 
 
 Defining the custom op and adding backend implementations
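The byte layout described in this hunk can be sketched in a few lines of Python. ``encode_torch_target_version`` is a hypothetical helper for illustration, not part of the tutorial or of PyTorch:

```python
# Hypothetical helper illustrating the TORCH_TARGET_VERSION encoding:
# [MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes], big-endian.
def encode_torch_target_version(major: int, minor: int, patch: int = 0) -> int:
    # Each version component occupies one byte of the 8-byte value;
    # the low 5 bytes are reserved for the ABI tag (zero here).
    return (major << 56) | (minor << 48) | (patch << 40)

print(hex(encode_torch_target_version(2, 10)))  # → 0x20a000000000000 (i.e. 0x020a000000000000)
```

For example, targeting PyTorch 2.11.0 instead would give ``0x020b000000000000``.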
@@ -317,27 +329,6 @@ in a separate ``STABLE_TORCH_LIBRARY_IMPL`` block:
   m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
 }
 
-Reverting to the Non-Stable LibTorch API
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-The LibTorch Stable ABI/API is still under active development, and certain APIs may not
-yet be available in ``torch/csrc/stable``, ``torch/headeronly``, or the C shims
-(``torch/csrc/stable/c/shim.h``).
-
-If you need an API that is not yet available in the stable ABI/API, you can revert to
-the regular ATen API by:
-
-1. Removing ``-DTORCH_TARGET_VERSION`` from your ``extra_compile_args``
-2. Using ``TORCH_LIBRARY`` instead of ``STABLE_TORCH_LIBRARY``
-3. Using ``TORCH_LIBRARY_IMPL`` instead of ``STABLE_TORCH_LIBRARY_IMPL``
-4. Reverting to ATen APIs (e.g. using ``at::Tensor`` instead of ``torch::stable::Tensor`` etc.)
-
-Note that doing so means you will need to build separate wheels for each PyTorch
-version you want to support.
-
-For reference, see the `PyTorch 2.9.1 version of this tutorial <https://github.com/pytorch/tutorials/blob/10eefc3b761a5b5407862b2336493b7ab859640f/advanced_source/cpp_custom_ops.rst>`_
-which uses the non-stable API, as well as `this commit of the extension-cpp repository <https://github.com/pytorch/extension-cpp/tree/0ec4969c7bc8e15a8456e5eb9d9ca0a7ec15bc95>`_.
-
 Adding ``torch.compile`` support for an operator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -673,6 +664,93 @@ When defining the operator, we must specify that it mutates the out Tensor in th
 Do not return any mutated Tensors as outputs of the operator as this will
 cause incompatibility with PyTorch subsystems like ``torch.compile``.
 
+.. _reverting-to-non-stable-api:
+
+Reverting to the Non-Stable LibTorch API
+----------------------------------------
+
+The LibTorch Stable ABI/API is still under active development, and certain APIs may not
+yet be available in ``torch/csrc/stable``, ``torch/headeronly``, or the C shims
+(``torch/csrc/stable/c/shim.h``).
+
+If you need an API that is not yet available in the stable ABI/API, you can revert to
+the regular ATen API. Note that doing so means you will need to build separate wheels
+for each PyTorch version you want to support.
+
+We provide code snippets for ``mymuladd`` below to illustrate. The changes for the
+CUDA variant, ``mymul`` and ``myadd_out`` are similar in nature and can be found in the
+`extension_cpp/ <https://github.com/pytorch/extension-cpp/tree/main/extension_cpp>`_
+subdirectory of the extension-cpp repository.
+
+**Setup (setup.py)**
+
+Remove ``-DTORCH_TARGET_VERSION`` from your ``extra_compile_args``:
+
+.. code-block:: python
+
+   extra_compile_args = {
+       "cxx": [
+           "-O3" if not debug_mode else "-O0",
+           "-fdiagnostics-color=always",
+           "-DPy_LIMITED_API=0x03090000",  # min CPython version 3.9
+           # Note: No -DTORCH_TARGET_VERSION flag
+       ],
+       "nvcc": [
+           "-O3" if not debug_mode else "-O0",
+       ],
+   }
+
+**C++ Implementation (muladd.cpp)**
+
+Use ATen headers and types instead of the stable API:
+
+.. code-block:: cpp
+
+   // Use ATen/torch headers instead of torch/csrc/stable headers
+   #include <ATen/Operators.h>
+   #include <torch/all.h>
+   #include <torch/library.h>
+
+   namespace extension_cpp {
+
+   // Use at::Tensor instead of torch::stable::Tensor
+   at::Tensor mymuladd_cpu(const at::Tensor& a, const at::Tensor& b, double c) {
+     // Use TORCH_CHECK instead of STD_TORCH_CHECK
+     TORCH_CHECK(a.sizes() == b.sizes());
+     // Use at::kFloat instead of torch::headeronly::ScalarType::Float
+     TORCH_CHECK(a.dtype() == at::kFloat);
+     TORCH_CHECK(b.dtype() == at::kFloat);
+     // Use at::DeviceType instead of torch::headeronly::DeviceType
+     TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU);
+     TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU);
+     // Use tensor.contiguous() instead of torch::stable::contiguous(tensor)
+     at::Tensor a_contig = a.contiguous();
+     at::Tensor b_contig = b.contiguous();
+     // Use torch::empty() instead of torch::stable::empty_like()
+     at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options());
+     // Use data_ptr<T>() instead of const_data_ptr<T>()
+     const float* a_ptr = a_contig.data_ptr<float>();
+     const float* b_ptr = b_contig.data_ptr<float>();
+     float* result_ptr = result.data_ptr<float>();
+     for (int64_t i = 0; i < result.numel(); i++) {
+       result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
+     }
+     return result;
+   }
+
+   // Use TORCH_LIBRARY instead of STABLE_TORCH_LIBRARY
+   TORCH_LIBRARY(extension_cpp, m) {
+     m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
+   }
+
+   // Use TORCH_LIBRARY_IMPL instead of STABLE_TORCH_LIBRARY_IMPL
+   TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+     // Pass function pointer directly instead of wrapping with TORCH_BOX()
+     m.impl("mymuladd", &mymuladd_cpu);
+   }
+
+   }
+
 Conclusion
 ----------
 In this tutorial, we went over the recommended approach to integrating Custom C++
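As a quick sanity check on the semantics that the C++ kernel above implements (elementwise ``a * b + c``), here is a plain-Python reference sketch; ``mymuladd_ref`` is a hypothetical name used only for illustration and is not part of the extension:

```python
def mymuladd_ref(a, b, c):
    # Mirrors the C++ loop above: result[i] = a[i] * b[i] + c.
    assert len(a) == len(b), "shapes must match, as enforced by TORCH_CHECK in the kernel"
    return [x * y + c for x, y in zip(a, b)]

print(mymuladd_ref([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], 1.0))  # → [5.0, 11.0, 19.0]
```

A reference like this is handy when testing a compiled kernel: compare the op's output against the pure-Python result on small inputs.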
