@@ -51,11 +51,11 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
5151
5252 std::vector<int64_t > output_dims = input_dims;
5353 output_dims[output_dims.size () - 1 ] = n;
54- auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32 , input-> GetDevice ( ));
54+ auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32 , Device (DeviceType:: kCUDA , 0 ));
5555
5656 const float alpha = 1 .0f , beta = 0 .0f ;
5757 cublasHandle_t handle;
58- cublasCreate (&handle);
58+ CUBLAS_CHECK ( cublasCreate (&handle) );
5959
6060	 	  // cuBLAS is column-major
6161 // output = input * other --> output.T = other.T * input.T
@@ -69,11 +69,16 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
6969 int64_t stride_a = n * k;
7070 int64_t stride_b = k * m;
7171 int64_t stride_c = m * n;
72- cublasGemmStridedBatchedEx (handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr (), CUDA_R_32F, lda,
73- stride_a, input->DataPtr (), CUDA_R_32F, ldb, stride_b, &beta, output->DataPtr (),
74- CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
75-
76- cublasDestroy (handle);
72+ // TODO(zbl): check GEMM algo
73	+	  // CUBLAS_GEMM_DEFAULT might require TensorCore
74+ // Use CUBLAS_GEMM_ALGO0 to disable TensorCore algos
75+
76+ CUBLAS_CHECK (cublasGemmStridedBatchedEx (handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr (),
77+ CUDA_R_32F, lda, stride_a, input->DataPtr (), CUDA_R_32F, ldb, stride_b,
78+ &beta, output->DataPtr (), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F,
79+ CUBLAS_GEMM_DEFAULT));
80+ CUDA_CHECK (cudaDeviceSynchronize ());
81+ CUBLAS_CHECK (cublasDestroy (handle));
7782 return output;
7883}
7984
@@ -112,7 +117,7 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
112117
113118 float alpha = 1 .0f , beta = 0 .0f ;
114119 cublasHandle_t handle;
115- cublasCreate (&handle);
120+ CUBLAS_CHECK ( cublasCreate (&handle) );
116121
117122 {
118	123	  // cuBLAS is column-major
@@ -125,10 +130,10 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
125130 const int64_t stride_a = k * n;
126131 const int64_t stride_b = n * m;
127132 const int64_t stride_c = m * k;
128- cublasGemmStridedBatchedEx (handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr (), CUDA_R_32F, lda ,
129- stride_a, grad_output->DataPtr (), CUDA_R_32F, ldb, stride_b, &beta ,
130- grad_input->DataPtr (), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F ,
131- CUBLAS_GEMM_DEFAULT);
133+ CUBLAS_CHECK ( cublasGemmStridedBatchedEx (handle, CUBLAS_OP_T, CUBLAS_OP_N, k, m, n, &alpha, other->DataPtr (),
134+ CUDA_R_32F, lda, stride_a, grad_output->DataPtr (), CUDA_R_32F, ldb,
135+ stride_b, &beta, grad_input->DataPtr (), CUDA_R_32F, ldc, stride_c, bs,
136+ CUDA_R_32F, CUBLAS_GEMM_DEFAULT) );
132137 }
133138
134139 {
@@ -142,13 +147,13 @@ MatmulBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
142147 const int64_t stride_a = n * m;
143148 const int64_t stride_b = m * k;
144149 const int64_t stride_c = n * k;
145- cublasGemmStridedBatchedEx (handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha, grad_output-> DataPtr () ,
146- CUDA_R_32F, lda, stride_a, input ->DataPtr (), CUDA_R_32F, ldb, stride_b, &beta ,
147- grad_other-> DataPtr (), CUDA_R_32F, ldc, stride_c, bs , CUDA_R_32F,
148- CUBLAS_GEMM_DEFAULT);
150+ CUBLAS_CHECK ( cublasGemmStridedBatchedEx (handle, CUBLAS_OP_N, CUBLAS_OP_T, n, k, m, &alpha,
151+ grad_output ->DataPtr (), CUDA_R_32F, lda, stride_a, input-> DataPtr () ,
152+ CUDA_R_32F, ldb, stride_b, &beta, grad_other-> DataPtr () , CUDA_R_32F,
153+ ldc, stride_c, bs, CUDA_R_32F, CUBLAS_GEMM_DEFAULT) );
149154 }
150155
151- cublasDestroy (handle);
156+ CUBLAS_CHECK ( cublasDestroy (handle) );
152157 return {grad_input, grad_other};
153158}
154159
@@ -163,34 +168,27 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
163168 output[*, out_features] = input[*, in_features] * weight[out_features, in_features]^T + bias[out_features]
164169 */
165170
166- CHECK_EQ (input->Dims ().size (), 2 );
167- const int64_t bs = input->Dims ()[0 ];
168- const int64_t in_features = input->Dims ()[1 ];
169- CHECK_EQ (weight->Dims ().size (), 2 );
171+ const auto &input_dims = input->Dims ();
172+ CHECK_GE (input_dims.size (), 2 );
173+ const int64_t bs = std::accumulate (input_dims.rbegin () + 1 , input_dims.rend (), 1 , std::multiplies<int64_t >{});
174+ const int64_t in_features = *input_dims.rbegin ();
175+
176+ const auto &weight_dims = weight->Dims ();
177+ CHECK_EQ (weight_dims.size (), 2 );
178+ CHECK_EQ (in_features, weight_dims[transpose ? 1 : 0 ]);
170179
171180 // As for cublas:
172181 // C = alpha * op(B) * op(A) + beta * C
173182 // Dimensions:
174183 // input: (bs, in_features)
175184 // weight: (in_features, out_features) or (out_features, in_features) if transposed
176185 // output: (bs, out_features)
177- int64_t out_features = 0 ;
178- cublasOperation_t op_weight = CUBLAS_OP_N;
179-
180- if (transpose) {
181- // weight: (out_features, in_features)
182- CHECK_EQ (in_features, weight->Dims ()[1 ]);
183- out_features = weight->Dims ()[0 ];
184- op_weight = CUBLAS_OP_T;
185- } else {
186- // weight: (in_features, out_features)
187- CHECK_EQ (in_features, weight->Dims ()[0 ]);
188- out_features = weight->Dims ()[1 ];
189- op_weight = CUBLAS_OP_N;
190- }
186+ const int64_t out_features = weight_dims[transpose ? 0 : 1 ];
187+ cublasOperation_t op_weight = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
191188
192- auto output = std::make_shared<Tensor>(std::vector<int64_t >{bs, out_features}, DataType::kFLOAT32 ,
193- Device (DeviceType::kCUDA , 0 ));
189+ auto output_dims = input_dims;
190+ *output_dims.rbegin () = out_features;
191+ auto output = std::make_shared<Tensor>(output_dims, DataType::kFLOAT32 , Device (DeviceType::kCUDA , 0 ));
194192
195193 if (bias) {
196194 CHECK_EQ (bias->Dims ().size (), 1 );
@@ -206,17 +204,18 @@ std::shared_ptr<Tensor> LinearForward(const std::shared_ptr<Tensor> &input, cons
206204 const float alpha = 1 .0f ;
207205 const float beta = 1 .0f ;
208206 cublasHandle_t handle;
209- cublasCreate (&handle);
207+ CUBLAS_CHECK ( cublasCreate (&handle) );
210208
211209 // C = alpha * op(B) * op(A) + beta * C
212210 // output = alpha * (input * weight) + beta * output
213211 // TODO(zbl): use cublasSgemv if possible
214- cublasSgemm (handle, op_weight, CUBLAS_OP_N, out_features, bs, in_features, &alpha,
215- static_cast <const float *>(weight->DataPtr ()), (op_weight == CUBLAS_OP_N) ? out_features : in_features,
216- static_cast <const float *>(input->DataPtr ()), in_features, &beta,
217- static_cast <float *>(output->DataPtr ()), out_features);
212+ CUBLAS_CHECK (cublasSgemm (handle, op_weight, CUBLAS_OP_N, out_features, bs, in_features, &alpha,
213+ static_cast <const float *>(weight->DataPtr ()),
214+ (op_weight == CUBLAS_OP_N) ? out_features : in_features,
215+ static_cast <const float *>(input->DataPtr ()), in_features, &beta,
216+ static_cast <float *>(output->DataPtr ()), out_features));
218217
219- cublasDestroy (handle);
218+ CUBLAS_CHECK ( cublasDestroy (handle) );
220219
221220 return {output};
222221}
@@ -231,13 +230,18 @@ __global__ void set_ones(float *data, int num_elements) {
231230std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
232231LinearBackward (const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &weight, bool transpose,
233232 int64_t out_features, const std::shared_ptr<Tensor> &grad_output, const bool bias) {
234- CHECK_EQ (input->Dims ().size (), 2 );
235- const int bs = input->Dims ()[0 ];
236- const int in_features = input->Dims ()[1 ];
237- CHECK_EQ (weight->Dims ().size (), 2 );
233+ const auto &input_dims = input->Dims ();
234+ CHECK_GE (input_dims.size (), 2 );
235+ const int64_t bs = std::accumulate (input_dims.rbegin () + 1 , input_dims.rend (), 1 , std::multiplies<int64_t >{});
236+ const int64_t in_features = *input_dims.rbegin ();
237+
238+ const auto &weight_dims = weight->Dims ();
239+ CHECK_EQ (weight_dims.size (), 2 );
240+ CHECK_EQ (in_features, weight_dims[transpose ? 1 : 0 ]);
241+ CHECK_EQ (out_features, weight_dims[transpose ? 0 : 1 ]);
238242
239- auto grad_input = std::make_shared<Tensor>(input-> Dims () , DataType::kFLOAT32 , Device (DeviceType::kCUDA , 0 ));
240- auto grad_weight = std::make_shared<Tensor>(weight-> Dims () , DataType::kFLOAT32 , Device (DeviceType::kCUDA , 0 ));
243+ auto grad_input = std::make_shared<Tensor>(input_dims , DataType::kFLOAT32 , Device (DeviceType::kCUDA , 0 ));
244+ auto grad_weight = std::make_shared<Tensor>(weight_dims , DataType::kFLOAT32 , Device (DeviceType::kCUDA , 0 ));
241245 grad_weight->Fill <float >(0 .0f );
242246 std::shared_ptr<Tensor> grad_bias = nullptr ;
243247 if (bias) {
@@ -249,7 +253,7 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
249253 float alpha = 1 .0f ;
250254 float beta = 0 .0f ;
251255 cublasHandle_t handle;
252- cublasCreate (&handle);
256+ CUBLAS_CHECK ( cublasCreate (&handle) );
253257
254258 // TODO(zbl): use cublasSgemv if possible
255259 if (transpose) {
@@ -299,7 +303,7 @@ LinearBackward(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tenso
299303 out_features, static_cast <float *>(ones_ptr), 1 , &beta, static_cast <float *>(grad_bias->DataPtr ()), 1 ));
300304 }
301305
302- cublasDestroy (handle);
306+ CUBLAS_CHECK ( cublasDestroy (handle) );
303307
304308 return {grad_input, grad_weight, grad_bias};
305309}
0 commit comments