|
| 1 | +#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__ |
| 2 | +#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__ |
| 3 | + |
namespace op::paged_attention_prefill::cuda {

// Binary search over the prefix-sum array `offset` to find which sequence owns
// the flattened token index `token_idx`, i.e. the index s such that
// offset[s] <= token_idx < offset[s + 1].
//
// NOTE(review): assumes `offset` is ascending and has num_seqs + 1 entries
// (it reads offset[mid + 1]) — confirm against the host-side launch code.
// Falls back to sequence 0 when token_idx lies outside the table.
__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) {
    int low = 0, high = num_seqs - 1;
    while (low <= high) {
        // Overflow-safe midpoint (equivalent to (low + high) / 2 for valid ranges).
        int mid = low + ((high - low) >> 1);
        if (token_idx < offset[mid]) {
            high = mid - 1;
        } else if (token_idx >= offset[mid + 1]) {
            low = mid + 1;
        } else {
            return mid;
        }
    }
    return 0;
}

// Element offset (in units of Tdata) of token `t`'s K/V vector for KV head
// `kv_head_idx` inside the paged cache. `block_table` maps logical block
// index (t / block_size) to a physical block id.
__device__ __forceinline__ int64_t pagedKvOffset(
    const int64_t *block_table, int t, size_t block_size,
    ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride,
    size_t kv_head_idx, size_t head_size) {
    const int64_t physical_block_id = block_table[t / block_size];
    const int64_t t_off = static_cast<int64_t>(t % block_size);
    return physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
}

// Scaled dot-product attention logit between query vector `q` (head_size
// elements) and the cached key of token `t`, plus the optional ALiBi bias.
// `causal_limit` is the absolute position of the query token; the ALiBi term
// is a linear penalty on the (non-positive) distance t - causal_limit.
// Shared by all three softmax passes so the arithmetic stays identical.
template <typename Tcompute, typename Tdata>
__device__ __forceinline__ Tcompute attentionLogit(
    const Tdata *q, const Tdata *k_cache, const int64_t *block_table,
    int t, size_t block_size,
    ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride,
    size_t kv_head_idx, size_t head_size,
    float scale, float alibi_slope, int64_t causal_limit) {
    const Tdata *k_vec = k_cache + pagedKvOffset(block_table, t, block_size,
                                                 kv_block_stride, kv_head_stride,
                                                 kv_head_idx, head_size);
    Tcompute score = 0.0f;
    for (size_t d = 0; d < head_size; ++d) {
        score += static_cast<Tcompute>(q[d]) * static_cast<Tcompute>(k_vec[d]);
    }
    score *= static_cast<Tcompute>(scale);
    if (alibi_slope != 0.0f) {
        score += alibi_slope * static_cast<float>(t - causal_limit);
    }
    return score;
}

// Paged-attention prefill kernel (causal attention over paged KV cache).
//
// Launch layout: gridDim.x = total number of new (prefill) tokens across all
// sequences, gridDim.y = num_heads, blockDim.x >= head_size. Each thread
// writes one output element out[token][head][dim_idx]; threads with
// dim_idx >= head_size exit immediately.
//
// Inputs:
//   out_/q_          - [total_tokens, num_heads, head_size], contiguous.
//   k_cache_/v_cache_- paged caches addressed via block_tables_ with strides
//                      kv_block_stride (per physical block) and kv_head_stride
//                      (per KV head); tokens within a block are head_size apart.
//   block_tables_    - [num_seqs, max_num_blocks_per_seq] physical block ids.
//   cache_lens_      - total (history + new) KV length per sequence.
//   seq_lens_        - number of new tokens being prefilled per sequence.
//   alibi_slopes_    - optional per-head ALiBi slopes; nullptr disables ALiBi.
//   offset_          - prefix sums of seq_lens_ (num_seqs + 1 entries) mapping
//                      flat token index -> sequence (see find_seq_id).
//
// Algorithm: three passes over the causal KV range per thread —
//   (1) max logit, (2) sum of exp(logit - max), (3) softmax-weighted V sum.
// NOTE(review): every thread in the block recomputes the full QK dot products
// redundantly (no shared-memory staging); correct but bandwidth-heavy.
// NOTE(review): uses FLT_MAX and expf — the including translation unit must
// provide <cfloat> (this header does not include it).
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const int64_t *block_tables_, const int64_t *cache_lens_, const int64_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_heads, const size_t num_kv_heads, const float scale,
    const size_t max_num_blocks_per_seq, const size_t block_size,
    const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
    const size_t head_size,
    const int64_t *offset_,
    const size_t num_seqs) {

    // 2D grid coordinates: one block per (flattened token, head).
    const int global_token_idx = blockIdx.x;
    const int head_idx = blockIdx.y;
    const int dim_idx = threadIdx.x; // element within the head

    if (dim_idx >= head_size) {
        return;
    }

    // Map the flat token index back to its owning sequence.
    const int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);

    // Number of tokens being prefilled for this sequence in this call.
    const int64_t cur_new_len = seq_lens_[seq_idx];

    // Position of this token relative to the start of its sequence's new tokens.
    const int q_token_idx = global_token_idx - offset_[seq_idx];

    const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
    Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;

    // Causal window: this query may attend to KV positions [0, causal_limit],
    // where causal_limit is its absolute position in the full (history + new)
    // sequence.
    const int64_t total_seq_len = cache_lens_[seq_idx];
    const int64_t history_len = total_seq_len - cur_new_len;
    const int64_t causal_limit = history_len + q_token_idx;

    // Grouped-query attention: several query heads share one KV head.
    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;

    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    // Pass 1: find the maximum logit for numerically stable softmax.
    Tcompute max_score = -FLT_MAX;
    for (int t = 0; t <= causal_limit; ++t) {
        const Tcompute score = attentionLogit<Tcompute>(
            q_ptr_base, k_cache_, block_table, t, block_size,
            kv_block_stride, kv_head_stride, kv_head_idx, head_size,
            scale, alibi_slope, causal_limit);
        if (score > max_score) {
            max_score = score;
        }
    }

    // Pass 2: sum of exp(logit - max).
    Tcompute sum_exp = 0.0f;
    for (int t = 0; t <= causal_limit; ++t) {
        const Tcompute score = attentionLogit<Tcompute>(
            q_ptr_base, k_cache_, block_table, t, block_size,
            kv_block_stride, kv_head_stride, kv_head_idx, head_size,
            scale, alibi_slope, causal_limit);
        sum_exp += expf(static_cast<float>(score - max_score));
    }

    // Pass 3: softmax-weighted sum over V for this thread's output element.
    // NOTE(review): the 1e-6f epsilon guards an empty window (causal_limit < 0)
    // but slightly biases the normalization; sum_exp >= 1 whenever the loops run.
    Tcompute acc = 0.0f;
    const Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
    for (int t = 0; t <= causal_limit; ++t) {
        const Tcompute score = attentionLogit<Tcompute>(
            q_ptr_base, k_cache_, block_table, t, block_size,
            kv_block_stride, kv_head_stride, kv_head_idx, head_size,
            scale, alibi_slope, causal_limit);
        const Tcompute prob = expf(static_cast<float>(score - max_score)) * inv_sum;

        const Tdata *v_vec = v_cache_ + pagedKvOffset(block_table, t, block_size,
                                                      kv_block_stride, kv_head_stride,
                                                      kv_head_idx, head_size);
        acc += prob * static_cast<Tcompute>(v_vec[dim_idx]);
    }

    out_ptr[dim_idx] = static_cast<Tdata>(acc);
}

} // namespace op::paged_attention_prefill::cuda
| 133 | + |
| 134 | +#endif |