Skip to content

Commit 68026bd

Browse files
committed
issue/1008: use warpBroadcast api
1 parent 3d54ce8 commit 68026bd

1 file changed

Lines changed: 0 additions & 40 deletions

File tree

src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
194194
l = l * alpha + beta;
195195
m = m_new;
196196
}
197-
#ifdef ENABLE_ILUVATAR_API
198197
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
199198
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
200-
#else
201-
alpha = __shfl_sync(0xffffffff, alpha, 0);
202-
beta = __shfl_sync(0xffffffff, beta, 0);
203-
#endif
204199

205200
#if defined(__CUDA_ARCH__)
206201
if constexpr (std::is_same_v<Tdata, half>) {
@@ -238,11 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel(
238233
if (lane == 0) {
239234
inv_l = 1.0f / (l + 1e-6f);
240235
}
241-
#ifdef ENABLE_ILUVATAR_API
242236
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
243-
#else
244-
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
245-
#endif
246237

247238
#pragma unroll
248239
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -420,13 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
420411
l = l * alpha + beta;
421412
m = m_new;
422413
}
423-
#ifdef ENABLE_ILUVATAR_API
424414
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
425415
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
426-
#else
427-
alpha = __shfl_sync(0xffffffff, alpha, 0);
428-
beta = __shfl_sync(0xffffffff, beta, 0);
429-
#endif
430416

431417
#if defined(__CUDA_ARCH__)
432418
if constexpr (std::is_same_v<Tdata, half>) {
@@ -803,13 +789,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
803789
l = l * alpha + beta;
804790
m = m_new;
805791
}
806-
#ifdef ENABLE_ILUVATAR_API
807792
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
808793
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
809-
#else
810-
alpha = __shfl_sync(0xffffffff, alpha, 0);
811-
beta = __shfl_sync(0xffffffff, beta, 0);
812-
#endif
813794

814795
#if defined(__CUDA_ARCH__)
815796
if constexpr (std::is_same_v<Tdata, half>) {
@@ -849,11 +830,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
849830
if (lane == 0) {
850831
inv_l = 1.0f / (l + 1e-6f);
851832
}
852-
#ifdef ENABLE_ILUVATAR_API
853833
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
854-
#else
855-
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
856-
#endif
857834

858835
#pragma unroll
859836
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -1297,11 +1274,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
12971274
if (lane == 0) {
12981275
inv_l = 1.0f / (l + 1e-6f);
12991276
}
1300-
#ifdef ENABLE_ILUVATAR_API
13011277
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
1302-
#else
1303-
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
1304-
#endif
13051278

13061279
#pragma unroll
13071280
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -1992,13 +1965,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
19921965
l = l * alpha + beta;
19931966
m = m_new;
19941967
}
1995-
#ifdef ENABLE_ILUVATAR_API
19961968
alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0);
19971969
beta = op::paged_attention::cuda::warpBroadcast(beta, 0);
1998-
#else
1999-
alpha = __shfl_sync(0xffffffff, alpha, 0);
2000-
beta = __shfl_sync(0xffffffff, beta, 0);
2001-
#endif
20021970

20031971
#if defined(__CUDA_ARCH__)
20041972
if constexpr (std::is_same_v<Tdata, half>) {
@@ -2038,11 +2006,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
20382006
if (lane == 0) {
20392007
inv_l = 1.0f / (l + 1e-6f);
20402008
}
2041-
#ifdef ENABLE_ILUVATAR_API
20422009
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
2043-
#else
2044-
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
2045-
#endif
20462010

20472011
#pragma unroll
20482012
for (int i = 0; i < DIMS_PER_THREAD; ++i) {
@@ -2171,11 +2135,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
21712135
if (lane == 0) {
21722136
inv_l = 1.0f / (l + 1e-6f);
21732137
}
2174-
#ifdef ENABLE_ILUVATAR_API
21752138
inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0);
2176-
#else
2177-
inv_l = __shfl_sync(0xffffffff, inv_l, 0);
2178-
#endif
21792139

21802140
const int64_t q_token = q_start + static_cast<int64_t>(q_token_local);
21812141
half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride;

0 commit comments

Comments
 (0)