@@ -194,13 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
194194 l = l * alpha + beta;
195195 m = m_new;
196196 }
197- #ifdef ENABLE_ILUVATAR_API
198197 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
199198 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
200- #else
201- alpha = __shfl_sync (0xffffffff , alpha, 0 );
202- beta = __shfl_sync (0xffffffff , beta, 0 );
203- #endif
204199
205200#if defined(__CUDA_ARCH__)
206201 if constexpr (std::is_same_v<Tdata, half>) {
@@ -238,11 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel(
238233 if (lane == 0 ) {
239234 inv_l = 1 .0f / (l + 1e-6f );
240235 }
241- #ifdef ENABLE_ILUVATAR_API
242236 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
243- #else
244- inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
245- #endif
246237
247238#pragma unroll
248239 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -420,13 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
420411 l = l * alpha + beta;
421412 m = m_new;
422413 }
423- #ifdef ENABLE_ILUVATAR_API
424414 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
425415 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
426- #else
427- alpha = __shfl_sync (0xffffffff , alpha, 0 );
428- beta = __shfl_sync (0xffffffff , beta, 0 );
429- #endif
430416
431417#if defined(__CUDA_ARCH__)
432418 if constexpr (std::is_same_v<Tdata, half>) {
@@ -803,13 +789,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
803789 l = l * alpha + beta;
804790 m = m_new;
805791 }
806- #ifdef ENABLE_ILUVATAR_API
807792 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
808793 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
809- #else
810- alpha = __shfl_sync (0xffffffff , alpha, 0 );
811- beta = __shfl_sync (0xffffffff , beta, 0 );
812- #endif
813794
814795#if defined(__CUDA_ARCH__)
815796 if constexpr (std::is_same_v<Tdata, half>) {
@@ -849,11 +830,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
849830 if (lane == 0 ) {
850831 inv_l = 1 .0f / (l + 1e-6f );
851832 }
852- #ifdef ENABLE_ILUVATAR_API
853833 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
854- #else
855- inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
856- #endif
857834
858835#pragma unroll
859836 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -1297,11 +1274,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
12971274 if (lane == 0 ) {
12981275 inv_l = 1 .0f / (l + 1e-6f );
12991276 }
1300- #ifdef ENABLE_ILUVATAR_API
13011277 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
1302- #else
1303- inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
1304- #endif
13051278
13061279#pragma unroll
13071280 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -1992,13 +1965,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
19921965 l = l * alpha + beta;
19931966 m = m_new;
19941967 }
1995- #ifdef ENABLE_ILUVATAR_API
19961968 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
19971969 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
1998- #else
1999- alpha = __shfl_sync (0xffffffff , alpha, 0 );
2000- beta = __shfl_sync (0xffffffff , beta, 0 );
2001- #endif
20021970
20031971#if defined(__CUDA_ARCH__)
20041972 if constexpr (std::is_same_v<Tdata, half>) {
@@ -2038,11 +2006,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
20382006 if (lane == 0 ) {
20392007 inv_l = 1 .0f / (l + 1e-6f );
20402008 }
2041- #ifdef ENABLE_ILUVATAR_API
20422009 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
2043- #else
2044- inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
2045- #endif
20462010
20472011#pragma unroll
20482012 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -2171,11 +2135,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
21712135 if (lane == 0 ) {
21722136 inv_l = 1 .0f / (l + 1e-6f );
21732137 }
2174- #ifdef ENABLE_ILUVATAR_API
21752138 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
2176- #else
2177- inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
2178- #endif
21792139
21802140 const int64_t q_token = q_start + static_cast <int64_t >(q_token_local);
21812141 half *out_ptr = out_ + q_token * o_stride + static_cast <int64_t >(head_idx) * o_head_stride;
0 commit comments