Commit 06e6939

fix: fix gpt2 runtime errors on cuda, but loss is still NaN
1 parent b6de5bf commit 06e6939

6 files changed

Lines changed: 22 additions & 15 deletions

example/gpt2/net.cc

Lines changed: 8 additions & 2 deletions
```diff
@@ -11,6 +11,7 @@
 
 #include "glog/logging.h"
 
+#include "infini_train/include/device.h"
 #include "infini_train/include/nn/functional.h"
 #include "infini_train/include/nn/init.h"
 #include "infini_train/include/nn/modules/container.h"
@@ -60,6 +61,11 @@ CausalSelfAttention::CausalSelfAttention(const GPT2Config &config)
             ->View({1, 1, config_.block_size, config_.block_size});
 }
 
+void CausalSelfAttention::To(infini_train::Device device) {
+    nn::Module::To(device);
+    bias_ = std::make_shared<infini_train::Tensor>(bias_->To(device));
+}
+
 std::vector<std::shared_ptr<infini_train::Tensor>>
 CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     const auto B = x[0]->Dims()[0]; // bs
@@ -163,8 +169,8 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) {
     modules_[kLMHeadLayerName] = std::make_unique<GPT2Linear>(config.n_embd, config.vocab_size, false, true);
     // https://paperswithcode.com/method/weight-tying
     *mutable_module(kTransformerLayerName)
-        ->mutable_module(kWTELayerName)
-        ->mutable_parameter(GPT2Linear::kParamWeightName)
+         ->mutable_module(kWTELayerName)
+         ->mutable_parameter(GPT2Linear::kParamWeightName)
         = module(kLMHeadLayerName).parameter(GPT2Linear::kParamWeightName);
 
     // init all weights
```
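The substantive change in this file is the new `To` override: the cached causal-mask tensor `bias_` is built in the constructor rather than registered as a parameter, so a base-class `To` that only relocates parameters and submodules would leave it on the host while everything else moves to CUDA. A minimal sketch of the pattern, using toy `Tensor`/`Module` types rather than the project's actual API:

```cpp
// Toy sketch, not the project's API: a module that caches a non-parameter
// tensor must move it explicitly when the module changes device, because a
// generic Module::To typically relocates only registered parameters/submodules.
#include <memory>

struct Tensor {
    int device = 0;  // 0 = CPU, 1 = CUDA (toy stand-in)
    Tensor To(int d) const { Tensor t = *this; t.device = d; return t; }
};

struct Module {
    virtual ~Module() = default;
    virtual void To(int /*device*/) { /* would move registered parameters and submodules */ }
};

struct CausalSelfAttention : Module {
    // Cached causal mask: a plain tensor, not a registered parameter,
    // so the base-class To() never sees it.
    std::shared_ptr<Tensor> bias_ = std::make_shared<Tensor>();

    void To(int device) override {
        Module::To(device);                                   // parameters / submodules
        bias_ = std::make_shared<Tensor>(bias_->To(device));  // move the buffer too
    }
};

int main() {
    CausalSelfAttention attn;
    attn.To(1);  // bias_ now lives on device 1 as well
}
```

The remaining hunk in this file only re-indents the weight-tying chain; it is a whitespace change.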

example/gpt2/net.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -34,6 +34,8 @@ class CausalSelfAttention : public infini_train::nn::Module {
     std::vector<std::shared_ptr<infini_train::Tensor>>
     Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
 
+    void To(infini_train::Device device) override;
+
 private:
     GPT2Config config_;
     int64_t n_head_ = 0;
```

infini_train/include/nn/modules/module.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -31,7 +31,7 @@ class Module {
 
     virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) = 0;
 
-    void To(Device device);
+    virtual void To(Device device);
 
     void Apply(std::function<void(Module *)> fn);
 
```
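Making `Module::To` virtual is what lets the new `CausalSelfAttention::To` run at all: modules are typically held through base-class pointers, and a non-virtual `To` binds statically to `Module::To`, silently skipping the override. A toy illustration of the dispatch difference (not the project's types):

```cpp
// Toy types only: shows why the override is unreachable without `virtual`.
#include <iostream>
#include <memory>

struct Base {
    void ToNonVirtual(int) { std::cout << "Base::ToNonVirtual\n"; }
    virtual void To(int) { std::cout << "Base::To\n"; }
    virtual ~Base() = default;
};

struct Derived : Base {
    void ToNonVirtual(int) { std::cout << "Derived::ToNonVirtual\n"; }  // hides the base version, never reached via Base*
    void To(int) override { std::cout << "Derived::To\n"; }
};

int main() {
    std::unique_ptr<Base> m = std::make_unique<Derived>();
    m->ToNonVirtual(0);  // prints Base::ToNonVirtual -- static binding skips the derived version
    m->To(0);            // prints Derived::To        -- virtual dispatch reaches the override
}
```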

infini_train/src/autograd/elementwise.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -54,7 +54,7 @@ std::vector<std::shared_ptr<Tensor>> Tanh::Backward(const std::vector<std::share
     }
 #ifdef USE_CUDA
     case DeviceType::kCUDA: {
-        grad_input = kernels::cpu::TanhBackward(grad_output, output);
+        grad_input = kernels::cuda::TanhBackward(grad_output, output);
         break;
     }
 #endif
```
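The CUDA branch was calling the CPU kernel, handing host code raw device pointers. The correct backward needs only the saved forward output, since d/dx tanh(x) = 1 - tanh(x)^2. A hedged sketch of what such a kernel boils down to (illustrative names, not the project's actual `kernels::cuda::TanhBackward`):

```cuda
// Sketch of a tanh backward kernel: grad_input = grad_output * (1 - output^2).
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void TanhBackwardKernel(const float *grad_output, const float *output,
                                   float *grad_input, size_t n) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        grad_input[idx] = grad_output[idx] * (1.0f - output[idx] * output[idx]);
    }
}

int main() {
    const size_t n = 4;
    float h_out[n], h_grad_out[n] = {1.f, 1.f, 1.f, 1.f}, h_grad_in[n];
    for (size_t i = 0; i < n; ++i) h_out[i] = std::tanh(0.5f * i);  // pretend forward outputs

    float *d_grad_out, *d_out, *d_grad_in;
    cudaMalloc(&d_grad_out, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMalloc(&d_grad_in, n * sizeof(float));
    cudaMemcpy(d_grad_out, h_grad_out, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_out, h_out, n * sizeof(float), cudaMemcpyHostToDevice);

    TanhBackwardKernel<<<1, 256>>>(d_grad_out, d_out, d_grad_in, n);
    cudaMemcpy(h_grad_in, d_grad_in, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (size_t i = 0; i < n; ++i) printf("grad_input[%zu] = %f\n", i, h_grad_in[i]);

    cudaFree(d_grad_out); cudaFree(d_out); cudaFree(d_grad_in);
}
```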

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 9 additions & 9 deletions
```diff
@@ -38,7 +38,8 @@ __global__ void BinaryForwardKernel(T *output, Func fn, size_t num_elements_a, s
 // launch the given kernel function with the given output and inputs
 template <size_t BLOCK_SIZE, typename T, typename Kernel, typename... Inputs>
 void LaunchKernel(Kernel &&kernel, const std::shared_ptr<Tensor> &output, const Inputs &...inputs) {
-    auto extract_ptrs = [](const auto &...ts) { return std::make_tuple(static_cast<T *>(ts->DataPtr())...); };
+    auto extract_ptrs
+        = [](const auto &...ts) { return std::make_tuple(static_cast<T *>(ts ? ts->DataPtr() : nullptr)...); };
     auto input_ptrs = extract_ptrs(inputs...);
 
     cudaDeviceProp prop;
@@ -135,7 +136,6 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
     T *output_a_ptr = static_cast<T *>(output_a->DataPtr());
     T *output_b_ptr = static_cast<T *>(output_b->DataPtr());
     const T *grad_output_ptr = static_cast<const T *>(grad_output->DataPtr());
-
     LaunchKernel<BLOCK_SIZE, T>(
         [=](dim3 grid, dim3 block, size_t offset, auto... ptrs) {
             BinaryBackwardKernel<<<grid, block>>>(output_a_ptr, output_b_ptr, fun_a, fun_b, a_num_elements,
@@ -201,7 +201,6 @@ std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
 BinaryBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr<Tensor> &a,
                const std::shared_ptr<Tensor> &b, const std::vector<int64_t> &a_dims, const std::vector<int64_t> &b_dims,
                FuncA fn_a, FuncB fn_b) {
-
     const auto a_num_elements = std::accumulate(a_dims.begin(), a_dims.end(), 1, std::multiplies<int64_t>());
     const auto b_num_elements = std::accumulate(b_dims.begin(), b_dims.end(), 1, std::multiplies<int64_t>());
 
@@ -212,14 +211,15 @@ BinaryBackward(const std::shared_ptr<Tensor> &grad_output, const std::shared_ptr
     if (b) {
         CHECK(b_num_elements == b->NumElements());
     }
-
     auto dtype = grad_output->Dtype();
-    auto device = a->GetDevice();
-    // Currently a and b should have the same data type
-    CHECK(dtype == b->Dtype());
-    auto grad_a = std::make_shared<Tensor>(a->Dims(), dtype, device);
-    auto grad_b = std::make_shared<Tensor>(b->Dims(), dtype, device);
+    auto device = grad_output->GetDevice();
 
+    // Currently a and b should have the same data type
+    if (a && b) {
+        CHECK(a->Dtype() == b->Dtype());
+    }
+    auto grad_a = std::make_shared<Tensor>(a_dims, dtype, device);
+    auto grad_b = std::make_shared<Tensor>(b_dims, dtype, device);
     switch (dtype) {
     case DataType::kFLOAT32:
         LaunchBackward<256, float>(fn_a, fn_b, grad_a, grad_b, a_num_elements, b_num_elements, grad_output, a, b);
```
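Two related hardenings here: `extract_ptrs` now maps a null tensor to `nullptr` instead of dereferencing it, and `BinaryBackward` takes the device from `grad_output` and allocates the gradients from `a_dims`/`b_dims` rather than reading shape and device off `a`/`b`, which a backward pass may legitimately receive as null when the forward did not need to save them. A standalone sketch of the guarded pointer-extraction pattern, with a toy `Tensor` type in place of the real one:

```cpp
// Toy Tensor type; shows the fold-expression guard against null inputs that the
// old unconditional ts->DataPtr() was missing.
#include <memory>
#include <tuple>

struct Tensor {
    float *DataPtr() { return data_; }
    float *data_ = nullptr;
};

template <typename... Ts>
auto ExtractPtrs(const Ts &...ts) {
    // Null inputs map to nullptr instead of crashing; kernels can test the pointer.
    return std::make_tuple(static_cast<float *>(ts ? ts->DataPtr() : nullptr)...);
}

int main() {
    auto a = std::make_shared<Tensor>();
    std::shared_ptr<Tensor> b;      // operand not saved by the forward pass
    auto ptrs = ExtractPtrs(a, b);  // {a->DataPtr(), nullptr}
    (void)ptrs;
}
```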

infini_train/src/kernels/cuda/linear.cu

Lines changed: 1 addition & 2 deletions
```diff
@@ -77,7 +77,6 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
                                      CUDA_R_32F, lda, stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b,
                                      &beta, output->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F,
                                      CUBLAS_GEMM_DEFAULT));
-    CUDA_CHECK(cudaDeviceSynchronize());
     CUBLAS_CHECK(cublasDestroy(handle));
     return output;
 }
@@ -259,7 +258,7 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
     if (transpose) {
         // d_input = d_output * weight --> d_input.T = weight * d_output.T
         CUBLAS_CHECK(cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, in_features, bs, out_features, &alpha,
-                                 static_cast<const float *>(weight->DataPtr()), in_features,
+                                 static_cast<const float *>(weight->DataPtr()), out_features,
                                  static_cast<const float *>(grad_output->DataPtr()), out_features, &beta,
                                  static_cast<float *>(grad_input->DataPtr()), in_features));
 
```
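The `lda` fix follows from cuBLAS's column-major convention: the leading dimension always describes the matrix as it is stored in memory, not `op(A)`. With `transa = CUBLAS_OP_T` in an (m, n, k) GEMM, `A` is stored k x m, so `lda` must be at least k, which is `out_features` in this call; passing `in_features` makes cuBLAS walk the weight matrix with the wrong pitch. A self-contained illustration of the rule with small hard-coded shapes (assumed example, not the layer's actual tensors):

```cpp
// C(2x3) = A^T(2x4) * B(4x3). A is stored as a 4x2 column-major matrix,
// so lda = k = 4; passing m = 2 would read A with the wrong pitch.
#include <cstdio>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
    const int m = 2, n = 3, k = 4;
    std::vector<float> A(k * m, 1.0f), B(k * n, 1.0f), C(m * n, 0.0f);

    float *dA, *dB, *dC;
    cudaMalloc(&dA, A.size() * sizeof(float));
    cudaMalloc(&dB, B.size() * sizeof(float));
    cudaMalloc(&dC, C.size() * sizeof(float));
    cudaMemcpy(dA, A.data(), A.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;
    // lda = k because A (pre-transpose) has k rows in memory.
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, dA, k, dB, k, &beta, dC, m);

    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0][0] = %f (expected %d)\n", C[0], k);  // each entry is a dot product of k ones

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
}
```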
