Skip to content

Commit da81c9e

Browse files
Ziminlikilin
authored and committed
feat: add forward and backward cuda implementation for softmax
1 parent 28c6844 commit da81c9e

1 file changed

Lines changed: 186 additions & 3 deletions

File tree

Lines changed: 186 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,202 @@
11
#include "infini_train/include/kernels/cuda/softmax.h"
22

33
#include <cmath>
4+
#include <cstddef>
45
#include <cstdint>
5-
#include <memory>
6+
#include <cub/block/block_reduce.cuh>
67

78
#include "glog/logging.h"
89

910
#include "infini_train/include/tensor.h"
1011

1112
namespace infini_train::kernels::cuda {
1213

13-
std::shared_ptr<Tensor> SoftmaxForward(const std::shared_ptr<Tensor> &input, int64_t dim) { return nullptr; }
14+
// Ceiling division: (x + y - 1) / y. NOTE(review): currently unused in this
// file -- either remove it or use it when computing grid sizes.
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
15+
16+
// Computes softmax along the reduction axis for one (outer, inner) pair per
// thread block.
//
// Launch layout: grid = (outer_size, inner_size), block = BLOCK_SIZE threads.
// blockIdx.x selects the outer slice, blockIdx.y the inner offset, and the
// threads stride cooperatively over the `axis_size` softmax axis.
// `outer_size` is not read inside the kernel; the grid already encodes it.
template <size_t BLOCK_SIZE, typename T>
__global__ void SoftmaxForwardKernel(T *output, const T *input, int64_t outer_size, int64_t axis_size,
                                     int64_t inner_size) {
    using BlockReduce = cub::BlockReduce<T, BLOCK_SIZE>;

    // Two separate temp storages so the two reductions never alias the same
    // shared memory (no extra barrier needed between them).
    __shared__ typename BlockReduce::TempStorage temp_storage_max;
    __shared__ typename BlockReduce::TempStorage temp_storage_sum;
    // Broadcast slots: BlockReduce returns its result only to thread 0, so
    // thread 0 publishes it here for the whole block.
    __shared__ T row_max;
    __shared__ T row_sum;

    const int64_t group = blockIdx.x; // row of the grid
    const int64_t inner_idx = blockIdx.y; // column of the grid
    const int tid = threadIdx.x;

    // Phase 1: block-wide maximum of the row (shift for numerical stability).
    // calculate the maximum for each group
    T thread_max = -INFINITY;
    for (int64_t axis = tid; axis < axis_size; axis += BLOCK_SIZE) {
        int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
        thread_max = max(thread_max, input[idx]);
    }
    T block_max = BlockReduce(temp_storage_max).Reduce(thread_max, cub::Max());

    if (tid == 0) {
        row_max = block_max;
    }
    __syncthreads(); // make row_max visible to every thread

    // Phase 2: exponentiate (shifted by row_max) and accumulate the sum.
    // exp(x - max) is stored into `output` so phase 3 can reuse it in place.
    // calculate the sum of exponents
    T thread_sum = 0;
    for (int64_t axis = tid; axis < axis_size; axis += BLOCK_SIZE) {
        int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
        T exp_val = exp(input[idx] - row_max);
        output[idx] = exp_val;
        thread_sum += exp_val;
    }
    T block_sum = BlockReduce(temp_storage_sum).Sum(thread_sum);

    if (tid == 0) {
        row_sum = block_sum;
    }
    __syncthreads(); // make row_sum visible to every thread

    // Phase 3: normalize. Each thread revisits exactly the indices it wrote
    // in phase 2, so no further synchronization is required.
    // normalize
    for (int64_t axis = tid; axis < axis_size; axis += BLOCK_SIZE) {
        int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
        output[idx] /= row_sum;
    }
}
64+
65+
// Host-side launcher for SoftmaxForwardKernel.
//
// Collapses the tensor shape around `dim` into (outer_size, axis_size,
// inner_size) and launches grid = (outer_size, inner_size) blocks of
// BLOCK_SIZE threads. On any invalid configuration the error is logged and
// the launch is skipped, leaving `output` untouched.
template <size_t BLOCK_SIZE, typename T>
void LaunchForward(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input, int64_t dim) {
    const auto &input_dims = input->Dims();
    int64_t outer_size = 1;
    int64_t axis_size = input_dims[dim];
    int64_t inner_size = 1;

    for (int64_t i = 0; i < dim; ++i) { outer_size *= input_dims[i]; }
    for (size_t i = dim + 1; i < input_dims.size(); ++i) { inner_size *= input_dims[i]; }

    if (axis_size == 0) {
        LOG(ERROR) << "CUDA softmax forward: 'input_dims[dim] == 0' at " << __FILE__ << ":" << __LINE__;
        // Fix: previously fell through and launched anyway, leaving `output`
        // uninitialized.
        return;
    }
    if (outer_size == 0 || inner_size == 0) {
        // Empty tensor: nothing to compute (a zero grid dimension would also
        // make the launch itself fail).
        return;
    }

    T *output_ptr = reinterpret_cast<T *>(output->DataPtr());
    const T *input_ptr = reinterpret_cast<const T *>(input->DataPtr());

    cudaDeviceProp prop;
    cudaError_t err = cudaGetDeviceProperties(&prop, input->GetDevice().Index());
    if (err != cudaSuccess) {
        LOG(ERROR) << "CUDA softmax forward: cudaGetDeviceProperties failed: " << cudaGetErrorString(err) << " at "
                   << __FILE__ << ":" << __LINE__;
        return;
    }

    if (BLOCK_SIZE > static_cast<size_t>(prop.maxThreadsPerBlock)) {
        LOG(ERROR) << "CUDA softmax forward: 'BLOCK_SIZE used is larger than the max number of thread per block' at "
                   << __FILE__ << ":" << __LINE__;
        // Fix: launching anyway would fail with cudaErrorInvalidConfiguration.
        return;
    }
    if (inner_size > prop.maxGridSize[1]) {
        // grid.y is far more limited than grid.x; guard against oversized
        // inner dimensions instead of failing silently at launch time.
        LOG(ERROR) << "CUDA softmax forward: 'inner_size exceeds the max grid y-dimension' at " << __FILE__ << ":"
                   << __LINE__;
        return;
    }

    dim3 block_dims(BLOCK_SIZE);
    dim3 grid_dims(outer_size, inner_size);

    SoftmaxForwardKernel<BLOCK_SIZE, T>
        <<<grid_dims, block_dims>>>(output_ptr, input_ptr, outer_size, axis_size, inner_size);

    // Kernel launches do not return errors directly; surface launch-config
    // errors here instead of letting them poison a later CUDA call.
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        LOG(ERROR) << "CUDA softmax forward: kernel launch failed: " << cudaGetErrorString(err) << " at " << __FILE__
                   << ":" << __LINE__;
    }
}
97+
98+
// Softmax forward on CUDA: y = exp(x - max(x)) / sum(exp(x - max(x))) along
// dimension `dim` (negative dims index from the end, Python-style).
//
// Returns a new tensor with the same shape/dtype/device as `input`, or
// nullptr when the dtype is unsupported.
std::shared_ptr<Tensor> SoftmaxForward(const std::shared_ptr<Tensor> &input, int64_t dim) {
    auto dtype = input->Dtype();
    const auto &input_dims = input->Dims();
    const int64_t rank = static_cast<int64_t>(input_dims.size());
    // Normalize negative dims; cast avoids a signed/unsigned comparison in
    // the CHECK below.
    dim = dim < 0 ? dim + rank : dim;
    CHECK(dim >= 0 && dim < rank);
    auto output = std::make_shared<Tensor>(input_dims, dtype, input->GetDevice());

    switch (dtype) {
    case DataType::kFLOAT32:
        LaunchForward<256, float>(output, input, dim);
        break;
    default:
        // Fix: previously returned nullptr silently; log so the unsupported
        // dtype is visible to the caller.
        LOG(ERROR) << "CUDA softmax forward: unsupported data type";
        return nullptr;
    }

    return output;
}
115+
116+
// Backward pass of softmax along the reduction axis.
//
// For y = softmax(x): dL/dx_i = y_i * (dL/dy_i - sum_j(dL/dy_j * y_j)).
// Launch layout mirrors the forward kernel: grid = (outer_size, inner_size),
// BLOCK_SIZE threads striding over the softmax axis. `outer_size` is not
// read inside the kernel; the grid already encodes it.
template <size_t BLOCK_SIZE, typename T>
__global__ void SoftmaxBackwardKernel(T *grad_input, const T *grad_output, const T *output, int64_t outer_size,
                                      int64_t axis_size, int64_t inner_size) {
    using BlockReduce = cub::BlockReduce<T, BLOCK_SIZE>;

    __shared__ typename BlockReduce::TempStorage temp_storage_sum;
    // Broadcast slot: BlockReduce returns its result only to thread 0.
    __shared__ T row_sum;

    const int64_t group = blockIdx.x;
    const int64_t inner_idx = blockIdx.y;
    const int tid = threadIdx.x;

    // calculate the sum of the dot product of gradients: sum_j(dL/dy_j * y_j)
    T thread_sum = 0;
    for (int64_t axis = tid; axis < axis_size; axis += BLOCK_SIZE) {
        const int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
        thread_sum += grad_output[idx] * output[idx];
    }
    T block_sum = BlockReduce(temp_storage_sum).Sum(thread_sum);

    if (tid == 0) {
        row_sum = block_sum;
    }
    __syncthreads(); // make row_sum visible to every thread

    // update the input gradient: dL/dx = y * (dL/dy - row_sum)
    for (int64_t axis = tid; axis < axis_size; axis += BLOCK_SIZE) {
        const int64_t idx = (group * axis_size + axis) * inner_size + inner_idx;
        grad_input[idx] = output[idx] * (grad_output[idx] - row_sum);
    }
}
147+
148+
// Host-side launcher for SoftmaxBackwardKernel.
//
// Collapses the tensor shape around `dim` into (outer_size, axis_size,
// inner_size) and launches grid = (outer_size, inner_size) blocks of
// BLOCK_SIZE threads. On any invalid configuration the error is logged and
// the launch is skipped, leaving `grad_input` untouched.
template <size_t BLOCK_SIZE, typename T>
void LaunchBackward(const std::shared_ptr<Tensor> &grad_input, const std::shared_ptr<Tensor> &grad_output,
                    const std::shared_ptr<Tensor> &output, int64_t dim) {
    const auto &output_dims = output->Dims();
    int64_t outer_size = 1;
    int64_t axis_size = output_dims[dim];
    int64_t inner_size = 1;

    for (int64_t i = 0; i < dim; ++i) { outer_size *= output_dims[i]; }
    for (size_t i = dim + 1; i < output_dims.size(); ++i) { inner_size *= output_dims[i]; }

    if (axis_size == 0) {
        LOG(ERROR) << "CUDA softmax backward: 'output_dims[dim] == 0' at " << __FILE__ << ":" << __LINE__;
        // Fix: previously fell through and launched anyway, leaving
        // `grad_input` uninitialized.
        return;
    }
    if (outer_size == 0 || inner_size == 0) {
        // Empty tensor: nothing to compute (a zero grid dimension would also
        // make the launch itself fail).
        return;
    }

    T *grad_input_ptr = reinterpret_cast<T *>(grad_input->DataPtr());
    const T *grad_output_ptr = reinterpret_cast<const T *>(grad_output->DataPtr());
    const T *output_ptr = reinterpret_cast<const T *>(output->DataPtr());

    cudaDeviceProp prop;
    cudaError_t err = cudaGetDeviceProperties(&prop, output->GetDevice().Index());
    if (err != cudaSuccess) {
        LOG(ERROR) << "CUDA softmax backward: cudaGetDeviceProperties failed: " << cudaGetErrorString(err) << " at "
                   << __FILE__ << ":" << __LINE__;
        return;
    }

    if (BLOCK_SIZE > static_cast<size_t>(prop.maxThreadsPerBlock)) {
        LOG(ERROR) << "CUDA softmax backward: 'BLOCK_SIZE used is larger than the max number of thread per block' at "
                   << __FILE__ << ":" << __LINE__;
        // Fix: launching anyway would fail with cudaErrorInvalidConfiguration.
        return;
    }
    if (inner_size > prop.maxGridSize[1]) {
        // grid.y is far more limited than grid.x; guard against oversized
        // inner dimensions instead of failing silently at launch time.
        LOG(ERROR) << "CUDA softmax backward: 'inner_size exceeds the max grid y-dimension' at " << __FILE__ << ":"
                   << __LINE__;
        return;
    }

    dim3 block(BLOCK_SIZE);
    dim3 grid(outer_size, inner_size);

    SoftmaxBackwardKernel<BLOCK_SIZE, T>
        <<<grid, block>>>(grad_input_ptr, grad_output_ptr, output_ptr, outer_size, axis_size, inner_size);

    // Kernel launches do not return errors directly; surface launch-config
    // errors here instead of letting them poison a later CUDA call.
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        LOG(ERROR) << "CUDA softmax backward: kernel launch failed: " << cudaGetErrorString(err) << " at " << __FILE__
                   << ":" << __LINE__;
    }
}
14182

15183
// Softmax backward on CUDA: given the upstream gradient `grad_output` and the
// forward result `output`, computes dL/dx = y * (dL/dy - sum(dL/dy * y))
// along dimension `dim` (negative dims index from the end, Python-style).
//
// Returns a new tensor with the same shape/dtype/device as `output`, or
// nullptr when the dtype is unsupported.
std::shared_ptr<Tensor> SoftmaxBackward(const std::shared_ptr<Tensor> &grad_output,
                                        const std::shared_ptr<Tensor> &output, int64_t dim) {
    auto dtype = output->Dtype();
    const auto &output_dims = output->Dims();
    const int64_t rank = static_cast<int64_t>(output_dims.size());
    // Consistency fix: reuse the cached `output_dims`/`rank` instead of
    // re-fetching output->Dims(); cast avoids a signed/unsigned comparison.
    dim = dim < 0 ? dim + rank : dim;
    CHECK(dim >= 0 && dim < rank);
    // The two gradients are combined elementwise, so their shapes must match.
    CHECK(grad_output->Dims() == output_dims);

    auto grad_input = std::make_shared<Tensor>(output_dims, dtype, output->GetDevice());

    switch (dtype) {
    case DataType::kFLOAT32:
        LaunchBackward<256, float>(grad_input, grad_output, output, dim);
        break;
    default:
        // Fix: previously returned nullptr silently; log so the unsupported
        // dtype is visible to the caller.
        LOG(ERROR) << "CUDA softmax backward: unsupported data type";
        return nullptr;
    }

    return grad_input;
}
19202
} // namespace infini_train::kernels::cuda

0 commit comments

Comments
 (0)