Skip to content

Commit b5d6a6b

Browse files
Chamberlain0w0 authored and kilinchange committed
fix: remove unnecessary code and comments
1 parent 0eda77d commit b5d6a6b

2 files changed

Lines changed: 1 addition & 14 deletions

File tree

infini_train/src/kernels/cuda/linear.cu

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,7 @@ std::shared_ptr<Tensor> MatmulForward(const std::shared_ptr<Tensor> &input, cons
6969
int64_t stride_a = n * k;
7070
int64_t stride_b = k * m;
7171
int64_t stride_c = m * n;
72-
// TODO(zbl): check GEMM algo
73-
// CUBLAS_GEMM_DEFAULT might requires TensorCore
74-
// Use CUBLAS_GEMM_ALGO0 to disable TensorCore algos
75-
72+
// NOTE(zbl): the last cublasGemmAlgo_t param has no effect on GPU arch >= sm_80(Ampere)
7673
CUBLAS_CHECK(cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, other->DataPtr(),
7774
CUDA_R_32F, lda, stride_a, input->DataPtr(), CUDA_R_32F, ldb, stride_b,
7875
&beta, output->DataPtr(), CUDA_R_32F, ldc, stride_c, bs, CUDA_R_32F,

infini_train/src/kernels/cuda/softmax.cu

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,6 @@
1010
#include "infini_train/include/tensor.h"
1111

1212
namespace infini_train::kernels::cuda {
13-
14-
#define CUDA_CHECK(call) \
15-
do { \
16-
cudaError_t status = call; \
17-
if (status != cudaSuccess) { \
18-
LOG(FATAL) << "CUDA Error: " << cudaGetErrorString(status) << " at " << __FILE__ << ":" << __LINE__; \
19-
} \
20-
} while (0)
21-
2213
template <size_t BLOCK_SIZE, typename T>
2314
__global__ void SoftmaxForwardKernel(T *output, const T *input, int64_t outer_size, int64_t axis_size,
2415
int64_t inner_size) {
@@ -116,7 +107,6 @@ std::shared_ptr<Tensor> SoftmaxForward(const std::shared_ptr<Tensor> &input, int
116107
default:
117108
LOG(FATAL) << "CUDA softmax forward: 'Unsupported data type' at " << __FILE__ << ":" << __LINE__;
118109
}
119-
CUDA_CHECK(cudaDeviceSynchronize());
120110
return output;
121111
}
122112

0 commit comments

Comments (0)