Skip to content

Commit 827495a

Browse files
committed
fix: fix rebase error
1 parent 4f7b2a4 commit 827495a

1 file changed

Lines changed: 0 additions & 24 deletions

File tree

src/cuda/rms_norm/kernel.h

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,13 @@ class CudaRmsNorm : public RmsNorm {
4343
using T = TypeMapType<ListGet<0>(list_tag)>;
4444
constexpr int kBlockSize = ListGet<1>(list_tag);
4545

46-
<<<<<<< HEAD
47-
#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \
48-
RmsNormKernel<BLOCK_SIZE, float, T, T> \
49-
<<<num_blocks, BLOCK_SIZE, 0, cuda_stream>>>( \
50-
reinterpret_cast<T*>(out.data()), stride_out_batch, \
51-
stride_out_nhead, reinterpret_cast<const T*>(input.data()), \
52-
stride_input_batch, stride_input_nhead, \
53-
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps);
54-
55-
if (block_size == CUDA_BLOCK_SIZE_2048) {
56-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_2048)
57-
} else if (block_size == CUDA_BLOCK_SIZE_1024) {
58-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_1024)
59-
} else if (block_size == CUDA_BLOCK_SIZE_512) {
60-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_512)
61-
} else if (block_size == CUDA_BLOCK_SIZE_256) {
62-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_256)
63-
} else {
64-
LAUNCH_RMS_NORM_KERNEL(CUDA_BLOCK_SIZE_128)
65-
}
66-
67-
#undef LAUNCH_RMS_NORM_KERNEL
68-
=======
6946
RmsNormKernel<kBlockSize, float, T, T>
7047
<<<num_blocks, kBlockSize, 0, cuda_stream>>>(
7148
reinterpret_cast<T*>(out.data()), stride_out_batch,
7249
stride_out_nhead, reinterpret_cast<const T*>(input.data()),
7350
stride_input_batch, stride_input_nhead,
7451
reinterpret_cast<const T*>(weight.data()), nhead_, dim_,
7552
eps_);
76-
>>>>>>> ae94669 (feat: add a convenient interface for any `int64_t`-convertible types and use `DispatchFunc()` to dispatch `DataType` and block sizes with a single call.)
7753
},
7854
"CudaRmsNorm::operator()");
7955
}

0 commit comments

Comments
 (0)