-
Notifications
You must be signed in to change notification settings - Fork 249
Expand file tree
/
Copy pathmuladd.cu
More file actions
91 lines (79 loc) · 3.5 KB
/
muladd.cu
File metadata and controls
91 lines (79 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cstdint>
namespace extension_cpp {

namespace {

// Threads per block used by every launch in this file.
constexpr int kThreadsPerBlock = 256;

// Largest gridDim.x requested by a launch. The kernels below step through the
// data with a grid-sized stride, so clamping the grid never drops elements.
constexpr int64_t kMaxBlocks = 2147483647;

// Returns the number of blocks that gives one thread per element when each
// block has kThreadsPerBlock threads, clamped to kMaxBlocks.
// Precondition: numel > 0 (callers return early for empty tensors, because a
// grid of zero blocks is an invalid launch configuration).
unsigned int num_blocks(int64_t numel) {
  const int64_t blocks = (numel + kThreadsPerBlock - 1) / kThreadsPerBlock;
  return static_cast<unsigned int>(blocks < kMaxBlocks ? blocks : kMaxBlocks);
}

}  // namespace

// Computes result[i] = a[i] * b[i] + c for every i in [0, numel).
//
// Launch layout: 1-D grid of 1-D blocks, no shared memory. Each thread starts
// at its global index and advances by the total number of threads in the
// grid, so every element is visited regardless of the grid size.
// Indices are 64-bit so element counts above 2^31 - 1 are handled.
// a, b and result must each address at least numel floats.
__global__ void muladd_kernel(int64_t numel, const float* a, const float* b, float c, float* result) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < numel;
       idx += stride) {
    result[idx] = a[idx] * b[idx] + c;
  }
}

// CUDA implementation of mymuladd: returns a new tensor holding a * b + c,
// computed element-wise.
//
// a, b: float32 CUDA tensors of identical shape on the same device; they are
//       made contiguous internally, so any strides are accepted.
// c:    scalar addend; converted to float before the kernel runs.
// Returns a freshly allocated tensor with a's shape and options.
// Throws c10::Error if the shapes differ, a dtype is not float32, the tensors
// are on different devices, or the kernel launch fails.
at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes(),
              "mymuladd: a and b must have the same shape, but got ", a.sizes(), " and ", b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat, "mymuladd: a must be a float32 tensor, but got ", a.dtype());
  TORCH_CHECK(b.dtype() == at::kFloat, "mymuladd: b must be a float32 tensor, but got ", b.dtype());
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  TORCH_CHECK(a.device() == b.device(),
              "mymuladd: a and b must be on the same device, but got ", a.device(), " and ", b.device());

  // Make the tensors' device current so that the stream queried below, and
  // therefore the launch, belongs to the GPU that owns the data.
  const c10::cuda::CUDAGuard device_guard(a.device());

  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());

  // Keep the element count 64-bit: narrowing to int would silently truncate
  // it for very large tensors.
  const int64_t numel = a_contig.numel();
  if (numel == 0) {
    // Nothing to compute, and launching zero blocks is an error.
    return result;
  }

  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  muladd_kernel<<<num_blocks(numel), kThreadsPerBlock, 0, stream>>>(
      numel, a_ptr, b_ptr, static_cast<float>(c), result_ptr);
  // Surface launch-configuration errors here instead of at a later CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return result;
}

// Computes result[i] = a[i] * b[i] for every i in [0, numel).
//
// Launch layout: 1-D grid of 1-D blocks, no shared memory. Each thread starts
// at its global index and advances by the total number of threads in the
// grid, so every element is visited regardless of the grid size.
// a, b and result must each address at least numel floats.
__global__ void mul_kernel(int64_t numel, const float* a, const float* b, float* result) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < numel;
       idx += stride) {
    result[idx] = a[idx] * b[idx];
  }
}

// CUDA implementation of mymul: returns a new tensor holding the element-wise
// product a * b.
//
// a, b: float32 CUDA tensors of identical shape on the same device; they are
//       made contiguous internally, so any strides are accepted.
// Returns a freshly allocated tensor with a's shape and options.
// Throws c10::Error if the shapes differ, a dtype is not float32, the tensors
// are on different devices, or the kernel launch fails.
at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.sizes() == b.sizes(),
              "mymul: a and b must have the same shape, but got ", a.sizes(), " and ", b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat, "mymul: a must be a float32 tensor, but got ", a.dtype());
  TORCH_CHECK(b.dtype() == at::kFloat, "mymul: b must be a float32 tensor, but got ", b.dtype());
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  TORCH_CHECK(a.device() == b.device(),
              "mymul: a and b must be on the same device, but got ", a.device(), " and ", b.device());

  // Make the tensors' device current so that the stream queried below, and
  // therefore the launch, belongs to the GPU that owns the data.
  const c10::cuda::CUDAGuard device_guard(a.device());

  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());

  // Keep the element count 64-bit: narrowing to int would silently truncate
  // it for very large tensors.
  const int64_t numel = a_contig.numel();
  if (numel == 0) {
    // Nothing to compute, and launching zero blocks is an error.
    return result;
  }

  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  mul_kernel<<<num_blocks(numel), kThreadsPerBlock, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
  // Surface launch-configuration errors here instead of at a later CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return result;
}

// Computes result[i] = a[i] + b[i] for every i in [0, numel).
//
// Launch layout: 1-D grid of 1-D blocks, no shared memory. Each thread starts
// at its global index and advances by the total number of threads in the
// grid, so every element is visited regardless of the grid size.
// a, b and result must each address at least numel floats. The pointers are
// deliberately not marked __restrict__: the host wrapper writes into a
// caller-supplied tensor, which may be the same tensor as an input.
__global__ void add_kernel(int64_t numel, const float* a, const float* b, float* result) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < numel;
       idx += stride) {
    result[idx] = a[idx] + b[idx];
  }
}

// CUDA implementation of myadd_out: writes the element-wise sum a + b into out.
//
// a, b: float32 CUDA tensors of identical shape on the same device; they are
//       made contiguous internally, so any strides are accepted.
// out:  float32 CUDA tensor with the same shape and device as a and b. It
//       must already be contiguous because the kernel writes through its raw
//       data pointer.
// Throws c10::Error if shapes differ, a dtype is not float32, out is not
// contiguous, the tensors are on different devices, or the launch fails.
void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  TORCH_CHECK(a.sizes() == b.sizes(),
              "myadd_out: a and b must have the same shape, but got ", a.sizes(), " and ", b.sizes());
  TORCH_CHECK(b.sizes() == out.sizes(),
              "myadd_out: out must have the same shape as the inputs, but got ", out.sizes(),
              " and ", b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat, "myadd_out: a must be a float32 tensor, but got ", a.dtype());
  TORCH_CHECK(b.dtype() == at::kFloat, "myadd_out: b must be a float32 tensor, but got ", b.dtype());
  TORCH_CHECK(out.dtype() == at::kFloat, "myadd_out: out must be a float32 tensor, but got ", out.dtype());
  TORCH_CHECK(out.is_contiguous(), "myadd_out: out must be contiguous");
  TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(out.device().type() == at::DeviceType::CUDA);
  TORCH_CHECK(a.device() == b.device() && b.device() == out.device(),
              "myadd_out: a, b and out must be on the same device, but got ", a.device(), ", ",
              b.device(), " and ", out.device());

  // Make the tensors' device current so that the stream queried below, and
  // therefore the launch, belongs to the GPU that owns the data.
  const c10::cuda::CUDAGuard device_guard(a.device());

  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();

  // Keep the element count 64-bit: narrowing to int would silently truncate
  // it for very large tensors.
  const int64_t numel = a_contig.numel();
  if (numel == 0) {
    // Nothing to compute, and launching zero blocks is an error.
    return;
  }

  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = out.data_ptr<float>();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  add_kernel<<<num_blocks(numel), kThreadsPerBlock, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
  // Surface launch-configuration errors here instead of at a later CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Registers CUDA implementations for mymuladd, mymul, myadd_out
TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
  m.impl("mymuladd", &mymuladd_cuda);
  m.impl("mymul", &mymul_cuda);
  m.impl("myadd_out", &myadd_out_cuda);
}

}  // namespace extension_cpp