Skip to content

Commit b6de5bf

Browse files
fix: fix weight tying, embedding/transform on cuda
1 parent cc97ea2 commit b6de5bf

6 files changed

Lines changed: 113 additions & 62 deletions

File tree

example/gpt2/net.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,10 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) {
162162
// don't init this one, we will tie weights
163163
modules_[kLMHeadLayerName] = std::make_unique<GPT2Linear>(config.n_embd, config.vocab_size, false, true);
164164
// https://paperswithcode.com/method/weight-tying
165-
mutable_module(kTransformerLayerName)
165+
*mutable_module(kTransformerLayerName)
166166
->mutable_module(kWTELayerName)
167167
->mutable_parameter(GPT2Linear::kParamWeightName)
168-
->reset(module(kLMHeadLayerName).parameter(GPT2Linear::kParamWeightName).get());
168+
= module(kLMHeadLayerName).parameter(GPT2Linear::kParamWeightName);
169169

170170
// init all weights
171171
Apply([&](Module *module) {

infini_train/src/autograd/transform.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
#include "infini_train/include/autograd/transform.h"
22
#include "infini_train/include/kernels/cpu/transform.h"
3+
4+
#ifdef USE_CUDA
5+
#include "infini_train/include/kernels/cuda/transform.h"
6+
#endif
7+
38
#include <vector>
49

510
namespace infini_train::autograd {
@@ -13,6 +18,12 @@ std::vector<std::shared_ptr<Tensor>> Tril::Forward(const std::vector<std::shared
1318
output = kernels::cpu::TrilForward(input, diagonal_);
1419
break;
1520
}
21+
#ifdef USE_CUDA
22+
case DeviceType::kCUDA: {
23+
output = kernels::cuda::TrilForward(input, diagonal_);
24+
break;
25+
}
26+
#endif
1627
default:
1728
LOG(FATAL) << "Unsupported device type: " << static_cast<int>(input->GetDevice().Type());
1829
break;
@@ -29,6 +40,12 @@ std::vector<std::shared_ptr<Tensor>> Tril::Backward(const std::vector<std::share
2940
grad_input = kernels::cpu::TrilBackward(grad_output, diagonal_);
3041
break;
3142
}
43+
#ifdef USE_CUDA
44+
case DeviceType::kCUDA: {
45+
grad_input = kernels::cuda::TrilBackward(grad_output, diagonal_);
46+
break;
47+
}
48+
#endif
3249
default:
3350
LOG(FATAL) << "Unsupported device type: " << static_cast<int>(grad_output->GetDevice().Type());
3451
break;
@@ -46,6 +63,12 @@ std::vector<std::shared_ptr<Tensor>> Transpose::Forward(const std::vector<std::s
4663
output = kernels::cpu::TransposeForward(input, dim0_, dim1_);
4764
break;
4865
}
66+
#ifdef USE_CUDA
67+
case DeviceType::kCUDA: {
68+
output = kernels::cuda::TransposeForward(input, dim0_, dim1_);
69+
break;
70+
}
71+
#endif
4972
default:
5073
LOG(FATAL) << "Unsupported device type: " << static_cast<int>(input->GetDevice().Type());
5174
break;
@@ -62,6 +85,12 @@ std::vector<std::shared_ptr<Tensor>> Transpose::Backward(const std::vector<std::
6285
grad_input = kernels::cpu::TransposeBackward(grad_output, dim0_, dim1_);
6386
break;
6487
}
88+
#ifdef USE_CUDA
89+
case DeviceType::kCUDA: {
90+
grad_input = kernels::cuda::TransposeBackward(grad_output, dim0_, dim1_);
91+
break;
92+
}
93+
#endif
6594
default:
6695
LOG(FATAL) << "Unsupported device type: " << static_cast<int>(grad_output->GetDevice().Type());
6796
break;
@@ -79,6 +108,12 @@ std::vector<std::shared_ptr<Tensor>> Mask::Forward(const std::vector<std::shared
79108
output = kernels::cpu::MaskForward(input, mask_, value_);
80109
break;
81110
}
111+
#ifdef USE_CUDA
112+
case DeviceType::kCUDA: {
113+
output = kernels::cuda::MaskForward(input, mask_, value_);
114+
break;
115+
}
116+
#endif
82117
default:
83118
LOG(FATAL) << "Unsupported device type: " << static_cast<int>(input->GetDevice().Type());
84119
break;
@@ -95,6 +130,12 @@ std::vector<std::shared_ptr<Tensor>> Mask::Backward(const std::vector<std::share
95130
grad_input = kernels::cpu::MaskBackward(grad_output, mask_);
96131
break;
97132
}
133+
#ifdef USE_CUDA
134+
case DeviceType::kCUDA: {
135+
grad_input = kernels::cuda::MaskBackward(grad_output, mask_);
136+
break;
137+
}
138+
#endif
98139
default:
99140
LOG(FATAL) << "Unsupported device type: " << static_cast<int>(grad_output->GetDevice().Type());
100141
break;

infini_train/src/kernels/cuda/embedding.cu

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@ __global__ void EmbeddingForwardKernel(const uint16_t *input, float *output, con
3636
}
3737

3838
std::shared_ptr<Tensor> EmbeddingForward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight) {
39-
CHECK_EQ(input->Dims().size(), 2);
4039
CHECK_EQ(weight->Dims().size(), 2);
4140

42-
const int batch_size = input->Dims()[0];
43-
const int max_seqlen = input->Dims()[1];
41+
const int batch_size = input->Dims().size() == 2 ? input->Dims()[0] : 1;
42+
const int max_seqlen = input->Dims().size() == 2 ? input->Dims()[1] : input->Dims()[0];
4443
const int embed_dim = weight->Dims()[1];
4544

4645
auto output = std::make_shared<Tensor>(std::vector<int64_t>{batch_size, max_seqlen, embed_dim}, DataType::kFLOAT32,
@@ -75,11 +74,10 @@ __global__ void WeightBackwardKernel(float *grad_weight, const float *grad_outpu
7574

7675
std::shared_ptr<Tensor> EmbeddingBackward(const std::shared_ptr<Tensor> &input, const std::vector<int64_t> &weight_dims,
7776
const std::shared_ptr<Tensor> &grad_output) {
78-
CHECK_EQ(input->Dims().size(), 2);
7977
CHECK_EQ(weight_dims.size(), 2);
8078

81-
const int batch_size = input->Dims()[0];
82-
const int max_seqlen = input->Dims()[1];
79+
const int batch_size = input->Dims().size() == 2 ? input->Dims()[0] : 1;
80+
const int max_seqlen = input->Dims().size() == 2 ? input->Dims()[1] : input->Dims()[0];
8381
const int embed_dim = weight_dims[1];
8482

8583
auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32, Device(DeviceType::kCUDA, 0));

infini_train/src/kernels/cuda/linear.cu

Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
5151

5252
std::vector<int64_t> output_dims = input_dims;
5353
output_dims[output_dims.size() - 1] = n;
54-
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32, input->GetDevice());
54+
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32, Device(DeviceType::kCUDA, 0));
5555

5656
const float alpha = 1.0f, beta = 0.0f;
5757
cublasHandle_t handle;
58-
cublasCreate(&handle);
58+
CUBLAS_CHECK(cublasCreate(&handle));
5959

6060
// cuBLAS is column-major
6161
// output = input * other --> output.T = other.T * input.T
@@ -69,11 +69,16 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
6969
int64_t stride_a = n * k;
7070
int64_t stride_b = k * m;
7171
int64_t stride_c = m * n;
72-
cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(), CUDA_R_32F, lda,
73-
stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta, output->DataPtr(),
74-
CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
75-
76-
cublasDestroy(handle);
72+
// TODO(zbl): check GEMM algo
73+
// CUBLAS_GEMM_DEFAULT might require TensorCore
74+
// Use CUBLAS_GEMM_ALGO0 to disable TensorCore algos
75+
76+
CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(),
77+
CUDA_R_32F, lda, stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b,
78+
&beta, output->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F,
79+
CUBLAS_GEMM_DEFAULT));
80+
CUDA_CHECK(cudaDeviceSynchronize());
81+
CUBLAS_CHECK(cublasDestroy(handle));
7782
return output;
7883
}
7984

@@ -112,7 +117,7 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
112117

113118
float alpha = 1.0f, beta = 0.0f;
114119
cublasHandle_t handle;
115-
cublasCreate(&handle);
120+
CUBLAS_CHECK(cublasCreate(&handle));
116121

117122
{
118123
// cuBLAS is column-major
@@ -125,10 +130,10 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
125130
const int64_t stride_a = k * n;
126131
const int64_t stride_b = n * m;
127132
const int64_t stride_c = m * k;
128-
cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(), CUDA_R_32F, lda,
129-
stride_a, grad_output->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta,
130-
grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F,
131-
CUBLAS_GEMM_DEFAULT);
133+
CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr(),
134+
CUDA_R_32F, lda, stride_a, grad_output->DataPtr(), CUDA_R_32F, ldb,
135+
stride_b, &beta, grad_input->DataPtr(), CUDA_R_32F, ldc, stride_c, bs,
136+
CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
132137
}
133138

134139
{
@@ -142,13 +147,13 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
142147
const int64_t stride_a = n * m;
143148
const int64_t stride_b = m * k;
144149
const int64_t stride_c = n * k;
145-
cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output->DataPtr(),
146-
CUDA_R_32F, lda, stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b, &beta,
147-
grad_other->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F,
148-
CUBLAS_GEMM_DEFAULT);
150+
CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha,
151+
grad_output->DataPtr(), CUDA_R_32F, lda, stride_a, input->DataPtr(),
152+
CUDA_R_32F, ldb, stride_b, &beta, grad_other->DataPtr(), CUDA_R_32F,
153+
ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
149154
}
150155

151-
cublasDestroy(handle);
156+
CUBLAS_CHECK(cublasDestroy(handle));
152157
return {grad_input, grad_other};
153158
}
154159

@@ -163,34 +168,27 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
163168
output[*, out_features] = input[*, in_features] * weight[out_features, in_features]^T + bias[out_features]
164169
*/
165170

166-
CHECK_EQ(input->Dims().size(), 2);
167-
const int64_t bs = input->Dims()[0];
168-
const int64_t in_features = input->Dims()[1];
169-
CHECK_EQ(weight->Dims().size(), 2);
171+
const auto &input_dims = input->Dims();
172+
CHECK_GE(input_dims.size(), 2);
173+
const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies<int64_t>{});
174+
const int64_t in_features = *input_dims.rbegin();
175+
176+
const auto &weight_dims = weight->Dims();
177+
CHECK_EQ(weight_dims.size(), 2);
178+
CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]);
170179

171180
// As for cublas:
172181
// C = alpha * op(B) * op(A) + beta * C
173182
// Dimensions:
174183
// input: (bs, in_features)
175184
// weight: (in_features, out_features) or (out_features, in_features) if transposed
176185
// output: (bs, out_features)
177-
int64_t out_features = 0;
178-
cublasOperation_t op_weight = CUBLAS_OP_N;
179-
180-
if (transpose) {
181-
// weight: (out_features, in_features)
182-
CHECK_EQ(in_features, weight->Dims()[1]);
183-
out_features = weight->Dims()[0];
184-
op_weight = CUBLAS_OP_T;
185-
} else {
186-
// weight: (in_features, out_features)
187-
CHECK_EQ(in_features, weight->Dims()[0]);
188-
out_features = weight->Dims()[1];
189-
op_weight = CUBLAS_OP_N;
190-
}
186+
const int64_t out_features = weight_dims[transpose ? 0 : 1];
187+
cublasOperation_t op_weight = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
191188

192-
auto output = std::make_shared<Tensor>(std::vector<int64_t>{bs, out_features}, DataType::kFLOAT32,
193-
Device(DeviceType::kCUDA, 0));
189+
auto output_dims = input_dims;
190+
*output_dims.rbegin() = out_features;
191+
auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32, Device(DeviceType::kCUDA, 0));
194192

195193
if (bias) {
196194
CHECK_EQ(bias->Dims().size(), 1);
@@ -206,17 +204,18 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
206204
const float alpha = 1.0f;
207205
const float beta = 1.0f;
208206
cublasHandle_t handle;
209-
cublasCreate(&handle);
207+
CUBLAS_CHECK(cublasCreate(&handle));
210208

211209
// C = alpha * op(B) * op(A) + beta * C
212210
// output = alpha * (input * weight) + beta * output
213211
// TODO(zbl): use cublasSgemv if possible
214-
cublasSgemm(handle, op_weight, CUBLAS_OP_N, out_features, bs, in_features, &alpha,
215-
static_cast<const float *>(weight->DataPtr()), (op_weight == CUBLAS_OP_N) ? out_features : in_features,
216-
static_cast<const float *>(input->DataPtr()), in_features, &beta,
217-
static_cast<float *>(output->DataPtr()), out_features);
212+
CUBLAS_CHECK(cublasSgemm(handle, op_weight, CUBLAS_OP_N, out_features, bs, in_features, &alpha,
213+
static_cast<const float *>(weight->DataPtr()),
214+
(op_weight == CUBLAS_OP_N) ? out_features : in_features,
215+
static_cast<const float *>(input->DataPtr()), in_features, &beta,
216+
static_cast<float *>(output->DataPtr()), out_features));
218217

219-
cublasDestroy(handle);
218+
CUBLAS_CHECK(cublasDestroy(handle));
220219

221220
return {output};
222221
}
@@ -231,13 +230,18 @@ __global__ void set_ones(float *data, int num_elements) {
231230
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
232231
LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
233232
int64_t out_features, const std::shared_ptr<Tensor> &grad_output, const bool bias) {
234-
CHECK_EQ(input->Dims().size(), 2);
235-
const int bs = input->Dims()[0];
236-
const int in_features = input->Dims()[1];
237-
CHECK_EQ(weight->Dims().size(), 2);
233+
const auto &input_dims = input->Dims();
234+
CHECK_GE(input_dims.size(), 2);
235+
const int64_t bs = std::accumulate(input_dims.rbegin() + 1, input_dims.rend(), 1, std::multiplies<int64_t>{});
236+
const int64_t in_features = *input_dims.rbegin();
237+
238+
const auto &weight_dims = weight->Dims();
239+
CHECK_EQ(weight_dims.size(), 2);
240+
CHECK_EQ(in_features, weight_dims[transpose ? 1 : 0]);
241+
CHECK_EQ(out_features, weight_dims[transpose ? 0 : 1]);
238242

239-
auto grad_input = std::make_shared<Tensor>(input->Dims(), DataType::kFLOAT32, Device(DeviceType::kCUDA, 0));
240-
auto grad_weight = std::make_shared<Tensor>(weight->Dims(), DataType::kFLOAT32, Device(DeviceType::kCUDA, 0));
243+
auto grad_input = std::make_shared<Tensor>(input_dims, DataType::kFLOAT32, Device(DeviceType::kCUDA, 0));
244+
auto grad_weight = std::make_shared<Tensor>(weight_dims, DataType::kFLOAT32, Device(DeviceType::kCUDA, 0));
241245
grad_weight->Fill<float>(0.0f);
242246
std::shared_ptr<Tensor> grad_bias = nullptr;
243247
if (bias) {
@@ -249,7 +253,7 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
249253
float alpha = 1.0f;
250254
float beta = 0.0f;
251255
cublasHandle_t handle;
252-
cublasCreate(&handle);
256+
CUBLAS_CHECK(cublasCreate(&handle));
253257

254258
// TODO(zbl): use cublasSgemv if possible
255259
if (transpose) {
@@ -299,7 +303,7 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
299303
out_features, static_cast<float *>(ones_ptr), 1, &beta, static_cast<float *>(grad_bias->DataPtr()), 1));
300304
}
301305

302-
cublasDestroy(handle);
306+
CUBLAS_CHECK(cublasDestroy(handle));
303307

304308
return {grad_input, grad_weight, grad_bias};
305309
}

infini_train/src/kernels/cuda/no_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ std::shared_ptr<Tensor> NoOpForward(const std::shared_ptr<Tensor> &input, const
1717
}
1818

1919
std::shared_ptr<Tensor> NoOpBackward(const std::vector<int64_t> &dims, const std::shared_ptr<Tensor> &grad_output) {
20-
CHECK_EQ(dims.size(), grad_output->Dims().size());
21-
for (int idx = 0; idx < dims.size(); ++idx) { CHECK_EQ(dims[idx], grad_output->Dims()[idx]); }
20+
auto num_elements = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int64_t>());
21+
CHECK_EQ(num_elements, grad_output->NumElements());
2222

2323
auto grad_input = std::make_shared<Tensor>(*grad_output, 0, dims);
2424
return grad_input;

infini_train/src/kernels/cuda/softmax.cu

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111

1212
namespace infini_train::kernels::cuda {
1313

14+
#define CUDA_CHECK(call) \
15+
do { \
16+
cudaError_t status = call; \
17+
if (status != cudaSuccess) { \
18+
LOG(FATAL) << "CUDA Error: " << cudaGetErrorString(status) << " at " << __FILE__ << ":" << __LINE__; \
19+
} \
20+
} while (0)
21+
1422
template <size_t BLOCK_SIZE, typename T>
1523
__global__ void SoftmaxForwardKernel(T *output, const T *input, int64_t outer_size, int64_t axis_size,
1624
int64_t inner_size) {
@@ -108,7 +116,7 @@ std::shared_ptr<Tensor> SoftmaxForward(const std::shared_ptr<Tensor> &input, int
108116
default:
109117
LOG(FATAL) << "CUDA softmax forward: 'Unsupported data type' at " << __FILE__ << ":" << __LINE__;
110118
}
111-
119+
CUDA_CHECK(cudaDeviceSynchronize());
112120
return output;
113121
}
114122

0 commit comments

Comments
 (0)