triton_bwd is a wrapper around triton that allows you to use Triton kernels in PyTorch autograd.
+from triton_bwd import triton_bwd, autotune
-@triton.autotune(...)
-@triton.jit
+@autotune(...)
+@triton_bwd(in_args=["a", "b"], out_args=["c"])
def my_triton_kernel(a, stride_a, b, stride_b, c, stride_c):
    ...

def compute_something(a, b):
    c = torch.zeros_like(a)
- my_triton_kernel[grid](
+ c, = my_triton_kernel.forward(
+ grid,
a, a.stride(0),
b, b.stride(0),
c, c.stride(0),
)
    return c  # is now differentiable!

git clone <this repo>
cd triton_bwd
pip install .