Skip to content

Commit 31c0af3

Browse files
Merge pull request #866 from InfiniTensor/issue/848_new
issue/848: add paged attention prefill for nvidia gpu with test pass
2 parents 10aa1c3 + 1ba0bcf commit 31c0af3

11 files changed

Lines changed: 1235 additions & 2 deletions

File tree

include/infiniop.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
#include "infiniop/ops/lp_norm.h"
1616
#include "infiniop/ops/mul.h"
1717
#include "infiniop/ops/ones.h"
18+
#include "infiniop/ops/paged_attention.h"
19+
#include "infiniop/ops/paged_attention_prefill.h"
20+
#include "infiniop/ops/paged_caching.h"
1821
#include "infiniop/ops/random_sample.h"
1922
#include "infiniop/ops/rearrange.h"
2023
#include "infiniop/ops/relu.h"
@@ -31,7 +34,5 @@
3134
#include "infiniop/ops/topksoftmax.h"
3235
#include "infiniop/ops/zeros.h"
3336
#include "infiniop/tensor_descriptor.h"
34-
#include "infiniop/ops/paged_attention.h"
35-
#include "infiniop/ops/paged_caching.h"
3637

3738
#endif // __INFINIOP_API_H__
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__

#include "../operator_descriptor.h"

// Opaque handle for a Paged Attention Prefill operator descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Attention Prefill operation.
 *
 * Queries for all sequences are packed into one flattened tensor; per-sequence
 * start positions are given by the offset tensor.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param out_desc Descriptor for the output tensor.
 * @param q_desc Descriptor for the query tensor (packed/flattened).
 * @param k_cache_desc Descriptor for the global physical key cache.
 * @param v_cache_desc Descriptor for the global physical value cache.
 * @param block_tables_desc Descriptor for the block tables mapping logical to physical blocks.
 * @param cache_lens_desc Descriptor for the total sequence lengths (history + current).
 * @param seq_lens_desc Descriptor for the current prefill sequence lengths.
 * @param offset_desc Descriptor for the start position of each sequence in the packed Q tensor.
 * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
 * @param scale The attention scaling factor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t cache_lens_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t offset_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale);

/**
 * @brief Retrieves the workspace size required for the Paged Attention Prefill operation.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
    infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size);

/**
 * @brief Executes the Paged Attention Prefill operation.
 *
 * @param desc The Paged Attention Prefill descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data (packed).
 * @param k_cache Pointer to the global key cache data.
 * @param v_cache Pointer to the global value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param cache_lens Pointer to the total sequence lengths data.
 * @param seq_lens Pointer to the current prefill sequence lengths data.
 * @param offset Pointer to the sequence start offsets data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The CUDA/device stream for the operation.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
    infiniopPagedAttentionPrefillDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *cache_lens,
    const void *seq_lens,
    const void *offset,
    const void *alibi_slopes,
    void *stream);

/**
 * @brief Destroys a Paged Attention Prefill descriptor.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
    infiniopPagedAttentionPrefillDescriptor_t desc);

#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
2+
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
3+
4+
namespace op::paged_attention_prefill::cuda {
5+
6+
// Helper: binary search to find which sequence a flattened global token
// index belongs to. `offset` is a prefix array of sequence start positions
// (the probe reads offset[mid + 1], so it is assumed to hold num_seqs + 1
// entries, the last being the total token count — TODO confirm with caller).
// Falls back to sequence 0 if the index lies outside every range.
__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) {
    int lo = 0;
    int hi = num_seqs - 1;
    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        const bool below = token_idx < offset[mid];
        if (!below && token_idx < offset[mid + 1]) {
            return mid; // token falls inside [offset[mid], offset[mid + 1])
        }
        if (below) {
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    return 0;
}
21+
22+
/**
 * Paged-attention prefill kernel (naive three-pass softmax attention).
 *
 * Grid layout: blockIdx.x = flattened global token index over the packed
 * batch, blockIdx.y = query head index. Each thread owns one element
 * (threadIdx.x) of the head_size-wide output vector; blockDim.x must be
 * >= head_size. Causality is enforced per token against the KV cache
 * (history + new tokens up to and including itself).
 *
 * NOTE(review): every thread of a block recomputes the full set of
 * attention logits (no shared-memory reduction) — correct but unoptimized.
 * Requires FLT_MAX from <cfloat> (expected to be included by the .cu
 * translation unit that includes this header).
 */
template <typename Tdata, typename Tcompute>
__global__ void pagedAttentionPrefillKernel(
    Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
    const int64_t *block_tables_, const int64_t *cache_lens_, const int64_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_heads, const size_t num_kv_heads, const float scale,
    const size_t max_num_blocks_per_seq, const size_t block_size,
    const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
    const size_t head_size,
    const int64_t *offset_,
    const size_t num_seqs) {

    // --- 2D grid coordinates ---
    const int global_token_idx = blockIdx.x; // flattened global token index
    const int head_idx = blockIdx.y;         // query head index
    const int dim_idx = threadIdx.x;         // dimension within the head

    // Guard: blockDim.x may be rounded up past head_size.
    if (dim_idx >= head_size) {
        return;
    }

    // Locate the owning sequence via binary search over the offset array.
    int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);

    // Number of new (prefill) tokens for this sequence, and this token's
    // relative position among them.
    const int64_t cur_new_len = seq_lens_[seq_idx];
    int q_token_idx = global_token_idx - offset_[seq_idx];

    const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
    Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;

    // --- KV-cache bookkeeping ---
    // Causal window: this query attends to cached history plus new tokens
    // up to and including its own position.
    const int64_t total_seq_len = cache_lens_[seq_idx];
    const int64_t history_len = total_seq_len - cur_new_len;
    const int64_t causal_limit = history_len + q_token_idx;

    // Map the query head to its (possibly shared) KV head (GQA/MQA).
    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;

    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    // Attention logit for cache position t: dot(q, k_t) * scale plus the
    // optional ALiBi distance bias. Defined once and shared by all three
    // passes below so the score formula cannot drift between them.
    auto compute_score = [&](int64_t t) -> Tcompute {
        const int64_t physical_block_id = block_table[t / block_size];
        const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride
                           + kv_head_idx * kv_head_stride + (t % block_size) * head_size;
        Tcompute score = 0.0f;
        for (size_t d = 0; d < head_size; ++d) {
            score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
        }
        score *= static_cast<Tcompute>(scale);
        if (alibi_slope != 0.0f) {
            score += alibi_slope * static_cast<float>(t - causal_limit);
        }
        return score;
    };

    // Pass 1: find the maximum logit (numerically stable softmax).
    Tcompute max_score = -FLT_MAX;
    for (int64_t t = 0; t <= causal_limit; ++t) {
        const Tcompute score = compute_score(t);
        if (score > max_score) {
            max_score = score;
        }
    }

    // Pass 2: accumulate the softmax denominator.
    Tcompute sum_exp = 0.0f;
    for (int64_t t = 0; t <= causal_limit; ++t) {
        sum_exp += expf(static_cast<float>(compute_score(t) - max_score));
    }

    // Pass 3: probability-weighted sum of V along this thread's dimension.
    // The epsilon guards against a zero denominator if all logits underflow.
    Tcompute acc = 0.0f;
    Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
    for (int64_t t = 0; t <= causal_limit; ++t) {
        const Tcompute prob = expf(static_cast<float>(compute_score(t) - max_score)) * inv_sum;

        const int64_t physical_block_id = block_table[t / block_size];
        const Tdata *v_vec = v_cache_ + physical_block_id * kv_block_stride
                           + kv_head_idx * kv_head_stride + (t % block_size) * head_size;
        acc += prob * static_cast<Tcompute>(v_vec[dim_idx]);
    }

    out_ptr[dim_idx] = static_cast<Tdata>(acc);
}
131+
132+
} // namespace op::paged_attention_prefill::cuda
133+
134+
#endif
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
2+
#define __PAGED_ATTENTION_PREFILL_INFO_H__
3+
4+
#include "../../../utils.h"
5+
#include "../../tensor.h"
6+
#include <iostream>
7+
#include <optional>
8+
#include <vector>
9+
10+
namespace op::paged_attention_prefill {
11+
12+
class PagedAttentionPrefillInfo {
13+
PagedAttentionPrefillInfo() = default;
14+
15+
public:
16+
infiniDtype_t dtype;
17+
float scale;
18+
19+
size_t num_seqs;
20+
size_t num_heads;
21+
size_t num_kv_heads;
22+
size_t head_size;
23+
size_t block_size;
24+
size_t max_num_blocks_per_seq;
25+
size_t total_q_tokens;
26+
27+
ptrdiff_t q_stride;
28+
ptrdiff_t kv_block_stride;
29+
ptrdiff_t kv_head_stride;
30+
ptrdiff_t o_stride;
31+
32+
static utils::Result<PagedAttentionPrefillInfo> create(
33+
infiniopTensorDescriptor_t out_desc,
34+
infiniopTensorDescriptor_t q_desc,
35+
infiniopTensorDescriptor_t k_cache_desc,
36+
infiniopTensorDescriptor_t v_cache_desc,
37+
infiniopTensorDescriptor_t block_tables_desc,
38+
infiniopTensorDescriptor_t cache_lens_desc,
39+
infiniopTensorDescriptor_t seq_lens_desc,
40+
infiniopTensorDescriptor_t offset_desc,
41+
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
42+
float scale) {
43+
44+
auto dtype = q_desc->dtype();
45+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
46+
47+
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
48+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
49+
}
50+
if (offset_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
51+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
52+
}
53+
54+
if (alibi_slopes_desc.has_value() && alibi_slopes_desc.value() != nullptr) {
55+
std::cerr << "[Error] PagedAttentionPrefill: ALiBi slopes are not supported yet." << std::endl;
56+
return INFINI_STATUS_BAD_PARAM;
57+
}
58+
59+
// Q shape: [total_tokens, heads, dim] (3D)
60+
auto q_shape = q_desc->shape();
61+
if (q_shape.size() < 3) {
62+
return INFINI_STATUS_BAD_TENSOR_SHAPE;
63+
}
64+
size_t total_q_tokens = q_shape[0];
65+
66+
size_t num_heads = q_shape[q_shape.size() - 2];
67+
size_t head_size = q_shape[q_shape.size() - 1];
68+
69+
if (head_size != 128) {
70+
std::cerr << "[Error] PagedAttentionPrefill head_size = 128 supported, got " << head_size << std::endl;
71+
return INFINI_STATUS_BAD_TENSOR_SHAPE;
72+
}
73+
74+
// 从 seq_lens 获取 num_seqs
75+
size_t num_seqs = seq_lens_desc->shape()[0];
76+
77+
auto k_cache_shape = k_cache_desc->shape();
78+
size_t num_kv_heads = k_cache_shape[1];
79+
size_t block_size = v_cache_desc->shape()[2];
80+
size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];
81+
82+
// 提取步长,需要保持多个请求的 Q 连续
83+
ptrdiff_t q_stride = q_desc->stride(0);
84+
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
85+
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
86+
ptrdiff_t o_stride = out_desc->stride(0);
87+
88+
return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
89+
dtype,
90+
scale,
91+
num_seqs,
92+
num_heads,
93+
num_kv_heads,
94+
head_size,
95+
block_size,
96+
max_num_blocks_per_seq,
97+
total_q_tokens,
98+
q_stride,
99+
kv_block_stride,
100+
kv_head_stride,
101+
o_stride});
102+
}
103+
};
104+
105+
} // namespace op::paged_attention_prefill
106+
107+
#endif

0 commit comments

Comments
 (0)