@@ -194,8 +194,13 @@ __device__ void PagedAttentionPrefillWarpKernel(
194194 l = l * alpha + beta;
195195 m = m_new;
196196 }
197+ #ifdef ENABLE_ILUVATAR_API
197198 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
198199 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
200+ #else
201+ alpha = __shfl_sync (0xffffffff , alpha, 0 );
202+ beta = __shfl_sync (0xffffffff , beta, 0 );
203+ #endif
199204
200205#if defined(__CUDA_ARCH__)
201206 if constexpr (std::is_same_v<Tdata, half>) {
@@ -233,7 +238,11 @@ __device__ void PagedAttentionPrefillWarpKernel(
233238 if (lane == 0 ) {
234239 inv_l = 1 .0f / (l + 1e-6f );
235240 }
241+ #ifdef ENABLE_ILUVATAR_API
236242 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
243+ #else
244+ inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
245+ #endif
237246
238247#pragma unroll
239248 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -411,8 +420,13 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
411420 l = l * alpha + beta;
412421 m = m_new;
413422 }
423+ #ifdef ENABLE_ILUVATAR_API
414424 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
415425 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
426+ #else
427+ alpha = __shfl_sync (0xffffffff , alpha, 0 );
428+ beta = __shfl_sync (0xffffffff , beta, 0 );
429+ #endif
416430
417431#if defined(__CUDA_ARCH__)
418432 if constexpr (std::is_same_v<Tdata, half>) {
@@ -450,7 +464,11 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
450464 if (lane == 0 ) {
451465 inv_l = 1 .0f / (l + 1e-6f );
452466 }
467+ #ifdef ENABLE_ILUVATAR_API
453468 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
469+ #else
470+ inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
471+ #endif
454472
455473#pragma unroll
456474 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -785,8 +803,13 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
785803 l = l * alpha + beta;
786804 m = m_new;
787805 }
806+ #ifdef ENABLE_ILUVATAR_API
788807 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
789808 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
809+ #else
810+ alpha = __shfl_sync (0xffffffff , alpha, 0 );
811+ beta = __shfl_sync (0xffffffff , beta, 0 );
812+ #endif
790813
791814#if defined(__CUDA_ARCH__)
792815 if constexpr (std::is_same_v<Tdata, half>) {
@@ -826,7 +849,11 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
826849 if (lane == 0 ) {
827850 inv_l = 1 .0f / (l + 1e-6f );
828851 }
852+ #ifdef ENABLE_ILUVATAR_API
829853 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
854+ #else
855+ inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
856+ #endif
830857
831858#pragma unroll
832859 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -1270,7 +1297,11 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
12701297 if (lane == 0 ) {
12711298 inv_l = 1 .0f / (l + 1e-6f );
12721299 }
1300+ #ifdef ENABLE_ILUVATAR_API
12731301 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
1302+ #else
1303+ inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
1304+ #endif
12741305
12751306#pragma unroll
12761307 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -1961,8 +1992,13 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
19611992 l = l * alpha + beta;
19621993 m = m_new;
19631994 }
1995+ #ifdef ENABLE_ILUVATAR_API
19641996 alpha = op::paged_attention::cuda::warpBroadcast (alpha, 0 );
19651997 beta = op::paged_attention::cuda::warpBroadcast (beta, 0 );
1998+ #else
1999+ alpha = __shfl_sync (0xffffffff , alpha, 0 );
2000+ beta = __shfl_sync (0xffffffff , beta, 0 );
2001+ #endif
19662002
19672003#if defined(__CUDA_ARCH__)
19682004 if constexpr (std::is_same_v<Tdata, half>) {
@@ -2002,7 +2038,11 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
20022038 if (lane == 0 ) {
20032039 inv_l = 1 .0f / (l + 1e-6f );
20042040 }
2041+ #ifdef ENABLE_ILUVATAR_API
20052042 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
2043+ #else
2044+ inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
2045+ #endif
20062046
20072047#pragma unroll
20082048 for (int i = 0 ; i < DIMS_PER_THREAD; ++i) {
@@ -2131,7 +2171,11 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
21312171 if (lane == 0 ) {
21322172 inv_l = 1 .0f / (l + 1e-6f );
21332173 }
2174+ #ifdef ENABLE_ILUVATAR_API
21342175 inv_l = op::paged_attention::cuda::warpBroadcast (inv_l, 0 );
2176+ #else
2177+ inv_l = __shfl_sync (0xffffffff , inv_l, 0 );
2178+ #endif
21352179
21362180 const int64_t q_token = q_start + static_cast <int64_t >(q_token_local);
21372181 half *out_ptr = out_ + q_token * o_stride + static_cast <int64_t >(head_idx) * o_head_stride;
0 commit comments