Skip to content

Commit caa61e9

Browse files
Merge pull request #868 from InfiniTensor/issue/847
Issue/847 paged attention prefill一段式接口
2 parents 31c0af3 + 99b940b commit caa61e9

15 files changed

Lines changed: 206 additions & 46 deletions

File tree

include/infinicore/common/hash.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "../tensor.hpp"
44

5+
#include <optional>
56
#include <type_traits>
67

78
namespace infinicore {
@@ -24,6 +25,15 @@ inline void hash_combine(size_t &seed, Tensor tensor) {
2425
}
2526
}
2627

28+
// Specialization for optional
29+
template <typename T>
30+
inline void hash_combine(size_t &seed, const std::optional<T> &opt) {
31+
hash_combine(seed, opt.has_value());
32+
if (opt) {
33+
hash_combine(seed, *opt);
34+
}
35+
}
36+
2737
// Specialization for std::string
2838
inline void hash_combine(size_t &seed, const std::string &str) {
2939
hash_combine(seed, std::hash<std::string>{}(str));

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "ops/matmul.hpp"
77
#include "ops/ones.hpp"
88
#include "ops/paged_attention.hpp"
9+
#include "ops/paged_attention_prefill.hpp"
910
#include "ops/paged_caching.hpp"
1011
#include "ops/random_sample.hpp"
1112
#include "ops/rearrange.hpp"

include/infinicore/ops/paged_attention.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ namespace infinicore::op {
99
class PagedAttention {
1010
public:
1111
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
12-
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float);
12+
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
1313
static common::OpDispatcher<schema> &dispatcher();
1414
};
1515

16-
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
17-
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
16+
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
17+
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
1818
} // namespace infinicore::op
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

// One-shot (prefill) paged-attention operator: a static dispatch facade that
// routes to a per-device-type backend registered in dispatcher().
class PagedAttentionPrefill {
public:
    // Backend kernel signature. Positional meaning mirrors execute() below:
    // (out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens,
    //  seq_offsets, alibi_slopes, scale).
    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
    // Runs the prefill attention kernel for the device of `out`, writing the
    // attention result into `out`.
    // NOTE(review): tensor shapes/layouts are not visible here — presumably
    // block_tables maps sequences to KV-cache blocks, cache_lens holds per-
    // sequence cached lengths, and seq_lens/seq_offsets describe the packed
    // query batch; confirm against the backend implementation.
    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float);
    // Registry mapping a device type to its backend kernel.
    static common::OpDispatcher<schema> &dispatcher();
};

// Allocating form: returns a freshly allocated output tensor.
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
// In-place form: writes the result into the caller-provided `out`.
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from infinicore.ops.mul import mul
4646
from infinicore.ops.narrow import narrow
4747
from infinicore.ops.paged_attention import paged_attention
48+
from infinicore.ops.paged_attention_prefill import paged_attention_prefill
4849
from infinicore.ops.paged_caching import paged_caching
4950
from infinicore.ops.rearrange import rearrange
5051
from infinicore.ops.squeeze import squeeze
@@ -119,6 +120,7 @@
119120
"from_torch",
120121
"paged_caching",
121122
"paged_attention",
123+
"paged_attention_prefill",
122124
"ones",
123125
"strided_empty",
124126
"strided_from_blob",

python/infinicore/ops/paged_attention.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def paged_attention(
77
k_cache: Tensor,
88
v_cache: Tensor,
99
block_tables: Tensor,
10-
seq_lens: Tensor,
10+
cache_lens: Tensor,
1111
alibi_slopes: Tensor | None = None,
1212
scale: float = 1.0,
1313
*,
@@ -20,7 +20,7 @@ def paged_attention(
2020
k_cache._underlying,
2121
v_cache._underlying,
2222
block_tables._underlying,
23-
seq_lens._underlying,
23+
cache_lens._underlying,
2424
alibi_slopes._underlying if alibi_slopes is not None else None,
2525
scale,
2626
)
@@ -32,7 +32,7 @@ def paged_attention(
3232
k_cache._underlying,
3333
v_cache._underlying,
3434
block_tables._underlying,
35-
seq_lens._underlying,
35+
cache_lens._underlying,
3636
alibi_slopes._underlying if alibi_slopes is not None else None,
3737
scale,
3838
)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from infinicore.lib import _infinicore
2+
from infinicore.tensor import Tensor
3+
4+
5+
def paged_attention_prefill(
    q: Tensor,
    k_cache: Tensor,
    v_cache: Tensor,
    block_tables: Tensor,
    cache_lens: Tensor,
    seq_lens: Tensor,
    seq_offsets: Tensor,
    alibi_slopes: Tensor | None = None,
    scale: float = 1.0,
    *,
    out: Tensor | None = None,
):
    """Run the one-shot (prefill) paged-attention kernel.

    Thin wrapper over the native ``_infinicore`` bindings. When ``out`` is
    given, the in-place variant fills it and returns it; otherwise the
    allocating variant creates and returns a new ``Tensor``.
    """
    # Optional ALiBi slopes: unwrap to the native handle, or pass None through.
    slopes = alibi_slopes._underlying if alibi_slopes is not None else None

    if out is not None:
        # In-place path: the native op writes directly into `out`.
        _infinicore.paged_attention_prefill_(
            out._underlying,
            q._underlying,
            k_cache._underlying,
            v_cache._underlying,
            block_tables._underlying,
            cache_lens._underlying,
            seq_lens._underlying,
            seq_offsets._underlying,
            slopes,
            scale,
        )
        return out

    # Allocating path: the native op returns a fresh underlying tensor.
    return Tensor(
        _infinicore.paged_attention_prefill(
            q._underlying,
            k_cache._underlying,
            v_cache._underlying,
            block_tables._underlying,
            cache_lens._underlying,
            seq_lens._underlying,
            seq_offsets._underlying,
            slopes,
            scale,
        )
    )

src/infinicore/ops/paged_attention/paged_attention.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
99
return dispatcher_;
1010
};
1111

12-
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
13-
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
12+
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
13+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
1414
infinicore::context::setDevice(out->device());
15-
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
15+
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
1616
}
1717

18-
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
18+
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
1919
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
20-
paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
20+
paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
2121
return out;
2222
}
2323

24-
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
25-
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
24+
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
25+
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
2626
}
2727

2828
} // namespace infinicore::op

src/infinicore/ops/paged_attention/paged_attention_infiniop.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
1515
}
1616
});
1717

18-
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
19-
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);
18+
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
19+
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
2020

2121
auto device = context::getDevice();
2222
auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
2727
if (!desc_opt) {
2828
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
2929
context::getInfiniopHandle(device), &desc,
30-
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(),
30+
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
3131
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
3232
scale));
3333
cache.put(seed, desc);
@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
4141

4242
INFINICORE_CHECK_ERROR(infiniopPagedAttention(
4343
desc, workspace->data(), workspace_size,
44-
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(),
44+
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
4545
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
4646
context::getStream()));
4747
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include "infinicore/ops/paged_attention_prefill.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

// Lazily-constructed registry mapping a device type to the backend kernel
// implementing the prefill paged-attention schema.
common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::dispatcher() {
    static common::OpDispatcher<PagedAttentionPrefill::schema> dispatcher_;
    return dispatcher_;
}

// Validates device placement, binds the context to the output's device, and
// forwards all arguments to the backend registered for that device type.
void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    // Fix: also assert seq_lens and seq_offsets live on the same device —
    // they are handed to the kernel just like the other tensors, but the
    // original check omitted them (the decode-path PagedAttention::execute
    // checks every Tensor it receives).
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets);
    infinicore::context::setDevice(out->device());
    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}

// Allocating form: the output shares q's shape/dtype/device.
// NOTE(review): assumes the prefill result has the same shape as q — matches
// the decode-path convention in paged_attention.cc; confirm for the packed
// prefill layout.
Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
    return out;
}

// In-place form: thin forwarding wrapper around PagedAttentionPrefill::execute.
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}

} // namespace infinicore::op

0 commit comments

Comments
 (0)