1 | 1 | #ifndef __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ |
2 | 2 | #define __PAGED_ATTENTION_PREFILL_KERNEL_V2_CUH__ |
3 | 3 |
4 | | -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) |
| 4 | +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) |
5 | 5 | #include <cuda_bf16.h> |
6 | 6 | #include <cuda_fp16.h> |
7 | 7 | #include <cuda_runtime.h> |
@@ -194,8 +194,8 @@ __device__ void PagedAttentionPrefillWarpKernel( |
194 | 194 | l = l * alpha + beta; |
195 | 195 | m = m_new; |
196 | 196 | } |
197 | | - alpha = __shfl_sync(0xffffffff, alpha, 0); |
198 | | - beta = __shfl_sync(0xffffffff, beta, 0); |
| 197 | + alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); |
| 198 | + beta = op::paged_attention::cuda::warpBroadcast(beta, 0); |
199 | 199 |
200 | 200 | #if defined(__CUDA_ARCH__) |
201 | 201 | if constexpr (std::is_same_v<Tdata, half>) { |
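
The context lines in the hunk above (`l = l * alpha + beta; m = m_new;`) are the standard streaming-softmax update, and `alpha`/`beta` are apparently produced by a single lane (hence the broadcast from lane 0 that this patch routes through `warpBroadcast`). For reference, a minimal host-side sketch of that recurrence — illustration only, not code taken from this header:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Illustration of the online-softmax recurrence the kernels use: fold the
// scores s_j of one K block into the running max m and denominator l.
void online_softmax_update(float &m, float &l, const std::vector<float> &block_scores) {
    float m_new = m;
    for (float s : block_scores) {
        m_new = std::max(m_new, s);          // updated running max
    }
    const float alpha = std::exp(m - m_new); // rescales the old accumulator
    float beta = 0.0f;
    for (float s : block_scores) {
        beta += std::exp(s - m_new);         // mass contributed by this block
    }
    l = l * alpha + beta;                    // matches the context line above
    m = m_new;
}
```
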
@@ -233,7 +233,7 @@ __device__ void PagedAttentionPrefillWarpKernel( |
233 | 233 | if (lane == 0) { |
234 | 234 | inv_l = 1.0f / (l + 1e-6f); |
235 | 235 | } |
236 | | - inv_l = __shfl_sync(0xffffffff, inv_l, 0); |
| 236 | + inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); |
237 | 237 |
238 | 238 | #pragma unroll |
239 | 239 | for (int i = 0; i < DIMS_PER_THREAD; ++i) { |
@@ -411,8 +411,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel( |
411 | 411 | l = l * alpha + beta; |
412 | 412 | m = m_new; |
413 | 413 | } |
414 | | - alpha = __shfl_sync(0xffffffff, alpha, 0); |
415 | | - beta = __shfl_sync(0xffffffff, beta, 0); |
| 414 | + alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); |
| 415 | + beta = op::paged_attention::cuda::warpBroadcast(beta, 0); |
416 | 416 |
417 | 417 | #if defined(__CUDA_ARCH__) |
418 | 418 | if constexpr (std::is_same_v<Tdata, half>) { |
@@ -450,7 +450,11 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel( |
450 | 450 | if (lane == 0) { |
451 | 451 | inv_l = 1.0f / (l + 1e-6f); |
452 | 452 | } |
| 453 | +#ifdef ENABLE_ILUVATAR_API |
| 454 | + inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); |
| 455 | +#else |
453 | 456 | inv_l = __shfl_sync(0xffffffff, inv_l, 0); |
| 457 | +#endif |
454 | 458 |
455 | 459 | #pragma unroll |
456 | 460 | for (int i = 0; i < DIMS_PER_THREAD; ++i) { |
@@ -785,8 +789,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( |
785 | 789 | l = l * alpha + beta; |
786 | 790 | m = m_new; |
787 | 791 | } |
788 | | - alpha = __shfl_sync(0xffffffff, alpha, 0); |
789 | | - beta = __shfl_sync(0xffffffff, beta, 0); |
| 792 | + alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); |
| 793 | + beta = op::paged_attention::cuda::warpBroadcast(beta, 0); |
790 | 794 |
791 | 795 | #if defined(__CUDA_ARCH__) |
792 | 796 | if constexpr (std::is_same_v<Tdata, half>) { |
@@ -826,7 +830,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( |
826 | 830 | if (lane == 0) { |
827 | 831 | inv_l = 1.0f / (l + 1e-6f); |
828 | 832 | } |
829 | | - inv_l = __shfl_sync(0xffffffff, inv_l, 0); |
| 833 | + inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); |
830 | 834 |
831 | 835 | #pragma unroll |
832 | 836 | for (int i = 0; i < DIMS_PER_THREAD; ++i) { |
@@ -1270,7 +1274,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined( |
1270 | 1274 | if (lane == 0) { |
1271 | 1275 | inv_l = 1.0f / (l + 1e-6f); |
1272 | 1276 | } |
1273 | | - inv_l = __shfl_sync(0xffffffff, inv_l, 0); |
| 1277 | + inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); |
1274 | 1278 |
1275 | 1279 | #pragma unroll |
1276 | 1280 | for (int i = 0; i < DIMS_PER_THREAD; ++i) { |
@@ -1961,8 +1965,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( |
1961 | 1965 | l = l * alpha + beta; |
1962 | 1966 | m = m_new; |
1963 | 1967 | } |
1964 | | - alpha = __shfl_sync(0xffffffff, alpha, 0); |
1965 | | - beta = __shfl_sync(0xffffffff, beta, 0); |
| 1968 | + alpha = op::paged_attention::cuda::warpBroadcast(alpha, 0); |
| 1969 | + beta = op::paged_attention::cuda::warpBroadcast(beta, 0); |
1966 | 1970 |
1967 | 1971 | #if defined(__CUDA_ARCH__) |
1968 | 1972 | if constexpr (std::is_same_v<Tdata, half>) { |
@@ -2002,7 +2006,7 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( |
2002 | 2006 | if (lane == 0) { |
2003 | 2007 | inv_l = 1.0f / (l + 1e-6f); |
2004 | 2008 | } |
2005 | | - inv_l = __shfl_sync(0xffffffff, inv_l, 0); |
| 2009 | + inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); |
2006 | 2010 |
2007 | 2011 | #pragma unroll |
2008 | 2012 | for (int i = 0; i < DIMS_PER_THREAD; ++i) { |
@@ -2131,7 +2135,7 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow( |
2131 | 2135 | if (lane == 0) { |
2132 | 2136 | inv_l = 1.0f / (l + 1e-6f); |
2133 | 2137 | } |
2134 | | - inv_l = __shfl_sync(0xffffffff, inv_l, 0); |
| 2138 | + inv_l = op::paged_attention::cuda::warpBroadcast(inv_l, 0); |
2135 | 2139 |
2136 | 2140 | const int64_t q_token = q_start + static_cast<int64_t>(q_token_local); |
2137 | 2141 | half *out_ptr = out_ + q_token * o_stride + static_cast<int64_t>(head_idx) * o_head_stride; |
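
The definition of the helper that replaces the raw shuffles is not part of this diff, so the following is only a sketch of what `op::paged_attention::cuda::warpBroadcast` presumably looks like: a thin wrapper around a full-warp shuffle, so that call sites no longer hard-code the 32-lane mask `0xffffffff`. Centralising the broadcast is what makes these kernels portable to builds such as `ENABLE_ILUVATAR_API`, where the warp is not necessarily 32 lanes wide. The signature, template, and body below are assumptions, and `normalize_example` is a made-up caller for illustration.

```cpp
#include <cuda_runtime.h>

namespace op::paged_attention::cuda {

// Hypothetical portability wrapper: broadcast `val` from `src_lane` to every
// lane of the calling warp. On stock CUDA this is simply a full-mask
// __shfl_sync; a backend with a different warp width would substitute its own
// full-warp shuffle here instead of spelling out 0xffffffff at each call site.
template <typename T>
__device__ __forceinline__ T warpBroadcast(T val, int src_lane) {
    return __shfl_sync(0xffffffffu, val, src_lane);
}

} // namespace op::paged_attention::cuda

// Made-up caller mirroring the patched kernels: lane 0 owns the running
// denominator l, computes the normalizer once, and every lane then reads
// lane 0's value through the broadcast.
__device__ float normalize_example(float l, int lane) {
    float inv_l = 0.0f;
    if (lane == 0) {
        inv_l = 1.0f / (l + 1e-6f);
    }
    return op::paged_attention::cuda::warpBroadcast(inv_l, 0);
}
```

One hunk in `PagedAttentionPrefillWarpGlobalKernel` keeps the original `__shfl_sync` behind an `#ifdef ENABLE_ILUVATAR_API` / `#else` split instead of calling the helper unconditionally; the two forms should be equivalent on NVIDIA builds if the helper reduces to the same full-mask shuffle there.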