
Commit 947d42a

resolve comments

Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent f8a3368 commit 947d42a

10 files changed

Lines changed: 264 additions & 458 deletions

include/infinicore/ops/infllmv2_attention.hpp

Lines changed: 59 additions & 138 deletions
@@ -61,59 +61,29 @@ INFINICORE_GRAPH_OP_CLASS(
 //
 // Returns:
 //   [total_q, nheads, head_dim]
-void infllmv2_varlen_(Tensor out,
-                      const Tensor &q,
-                      const Tensor &k,
-                      const Tensor &v,
-                      const Tensor &cu_seqlens_q,
-                      const Tensor &cu_seqlens_k,
-                      int max_seqlen_q,
-                      int max_seqlen_k,
-                      float scale,
-                      bool causal,
-                      int window_size_left = -1,
-                      int window_size_right = -1);
-Tensor infllmv2_varlen(const Tensor &q,
-                       const Tensor &k,
-                       const Tensor &v,
-                       const Tensor &cu_seqlens_q,
-                       const Tensor &cu_seqlens_k,
-                       int max_seqlen_q,
-                       int max_seqlen_k,
-                       float scale,
-                       bool causal,
-                       int window_size_left = -1,
-                       int window_size_right = -1);
-
-// Preferred names (attention-disambiguated). These are header-only aliases to the
-// backward-compatible `infllmv2_*` symbols to avoid adding extra exported ABI.
-inline void infllmv2_attention_varlen_(Tensor out,
-                                       const Tensor &q,
-                                       const Tensor &k,
-                                       const Tensor &v,
-                                       const Tensor &cu_seqlens_q,
-                                       const Tensor &cu_seqlens_k,
-                                       int max_seqlen_q,
-                                       int max_seqlen_k,
-                                       float scale,
-                                       bool causal,
-                                       int window_size_left = -1,
-                                       int window_size_right = -1) {
-    infllmv2_varlen_(out, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right);
-}
-inline Tensor infllmv2_attention_varlen(const Tensor &q,
-                                        const Tensor &k,
-                                        const Tensor &v,
-                                        const Tensor &cu_seqlens_q,
-                                        const Tensor &cu_seqlens_k,
-                                        int max_seqlen_q,
-                                        int max_seqlen_k,
-                                        float scale,
-                                        bool causal,
-                                        int window_size_left = -1,
-                                        int window_size_right = -1) {
-    return infllmv2_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right);
-}
+void infllmv2_attention_varlen_(Tensor out,
+                                const Tensor &q,
+                                const Tensor &k,
+                                const Tensor &v,
+                                const Tensor &cu_seqlens_q,
+                                const Tensor &cu_seqlens_k,
+                                int max_seqlen_q,
+                                int max_seqlen_k,
+                                float scale,
+                                bool causal,
+                                int window_size_left = -1,
+                                int window_size_right = -1);
+Tensor infllmv2_attention_varlen(const Tensor &q,
+                                 const Tensor &k,
+                                 const Tensor &v,
+                                 const Tensor &cu_seqlens_q,
+                                 const Tensor &cu_seqlens_k,
+                                 int max_seqlen_q,
+                                 int max_seqlen_k,
+                                 float scale,
+                                 bool causal,
+                                 int window_size_left = -1,
+                                 int window_size_right = -1);
 
 // Decode-time InfLLM-V2 attention with KV cache.
 //
@@ -125,104 +95,55 @@ inline Tensor infllmv2_attention_varlen(const Tensor &q,
 //
 // Returns:
 //   [batch, seqlen_q, nheads, head_dim]
-void infllmv2_kvcache_(Tensor out,
-                       const Tensor &q,
-                       const Tensor &k_cache,
-                       const Tensor &v_cache,
-                       const Tensor &cache_lens,
-                       float scale,
-                       bool causal,
-                       int window_size_left = -1,
-                       int window_size_right = -1);
-Tensor infllmv2_kvcache(const Tensor &q,
-                        const Tensor &k_cache,
-                        const Tensor &v_cache,
-                        const Tensor &cache_lens,
-                        float scale,
-                        bool causal,
-                        int window_size_left = -1,
-                        int window_size_right = -1);
+void infllmv2_attention_kvcache_(Tensor out,
+                                 const Tensor &q,
+                                 const Tensor &k_cache,
+                                 const Tensor &v_cache,
+                                 const Tensor &cache_lens,
+                                 float scale,
+                                 bool causal,
+                                 int window_size_left = -1,
+                                 int window_size_right = -1);
+Tensor infllmv2_attention_kvcache(const Tensor &q,
+                                  const Tensor &k_cache,
+                                  const Tensor &v_cache,
+                                  const Tensor &cache_lens,
+                                  float scale,
+                                  bool causal,
+                                  int window_size_left = -1,
+                                  int window_size_right = -1);
 
-inline void infllmv2_attention_kvcache_(Tensor out,
+// Decode-time InfLLM-V2 attention with KV cache, updating cache in-place.
+//
+// Shapes:
+//   q          : [batch, seqlen_q, nheads, head_dim]
+//   k_cache    : [batch, seqlen_cache, nheads_k, head_dim]  (dense cache)
+//   v_cache    : same as k_cache
+//   k_new/v_new: [batch, seqlen_new, nheads_k, head_dim]  (new KV to append at cache_lens offsets)
+//   cache_lens : [batch] (int32) current KV length per sequence BEFORE appending
+//
+// Returns:
+//   [batch, seqlen_q, nheads, head_dim]
+void infllmv2_attention_kvcache_update_(Tensor out,
                                         const Tensor &q,
                                         const Tensor &k_cache,
                                         const Tensor &v_cache,
+                                        const Tensor &k_new,
+                                        const Tensor &v_new,
                                         const Tensor &cache_lens,
                                         float scale,
                                         bool causal,
                                         int window_size_left = -1,
-                                        int window_size_right = -1) {
-    infllmv2_kvcache_(out, q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right);
-}
-inline Tensor infllmv2_attention_kvcache(const Tensor &q,
+                                        int window_size_right = -1);
+Tensor infllmv2_attention_kvcache_update(const Tensor &q,
                                          const Tensor &k_cache,
                                          const Tensor &v_cache,
+                                         const Tensor &k_new,
+                                         const Tensor &v_new,
                                          const Tensor &cache_lens,
                                          float scale,
                                          bool causal,
                                          int window_size_left = -1,
-                                         int window_size_right = -1) {
-    return infllmv2_kvcache(q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right);
-}
-
-// Decode-time InfLLM-V2 attention with KV cache, updating cache in-place.
-//
-// Shapes:
-//   q          : [batch, seqlen_q, nheads, head_dim]
-//   k_cache    : [batch, seqlen_cache, nheads_k, head_dim]  (dense cache)
-//   v_cache    : same as k_cache
-//   k_new/v_new: [batch, seqlen_new, nheads_k, head_dim]  (new KV to append at cache_lens offsets)
-//   cache_lens : [batch] (int32) current KV length per sequence BEFORE appending
-//
-// Returns:
-//   [batch, seqlen_q, nheads, head_dim]
-void infllmv2_kvcache_update_(Tensor out,
-                              const Tensor &q,
-                              const Tensor &k_cache,
-                              const Tensor &v_cache,
-                              const Tensor &k_new,
-                              const Tensor &v_new,
-                              const Tensor &cache_lens,
-                              float scale,
-                              bool causal,
-                              int window_size_left = -1,
-                              int window_size_right = -1);
-Tensor infllmv2_kvcache_update(const Tensor &q,
-                               const Tensor &k_cache,
-                               const Tensor &v_cache,
-                               const Tensor &k_new,
-                               const Tensor &v_new,
-                               const Tensor &cache_lens,
-                               float scale,
-                               bool causal,
-                               int window_size_left = -1,
-                               int window_size_right = -1);
-
-inline void infllmv2_attention_kvcache_update_(Tensor out,
-                                               const Tensor &q,
-                                               const Tensor &k_cache,
-                                               const Tensor &v_cache,
-                                               const Tensor &k_new,
-                                               const Tensor &v_new,
-                                               const Tensor &cache_lens,
-                                               float scale,
-                                               bool causal,
-                                               int window_size_left = -1,
-                                               int window_size_right = -1) {
-    infllmv2_kvcache_update_(out, q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right);
-}
-inline Tensor infllmv2_attention_kvcache_update(const Tensor &q,
-                                                const Tensor &k_cache,
-                                                const Tensor &v_cache,
-                                                const Tensor &k_new,
-                                                const Tensor &v_new,
-                                                const Tensor &cache_lens,
-                                                float scale,
-                                                bool causal,
-                                                int window_size_left = -1,
-                                                int window_size_right = -1) {
-    return infllmv2_kvcache_update(q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right);
-}
+                                         int window_size_right = -1);
 
 } // namespace infinicore::op
-

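A note on the renamed varlen entry points: they consume unpadded (packed) q/k/v plus cu_seqlens_q/cu_seqlens_k prefix sums, as the header comments above describe. A minimal NumPy sketch of that packing convention — pure illustration, it never touches the infinicore bindings:

import numpy as np

# cu_seqlens_* are int32 prefix sums of shape [batch + 1]; sequence i
# occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
seqlens_q = [3, 5, 2]                      # per-sequence query lengths
cu_seqlens_q = np.zeros(len(seqlens_q) + 1, dtype=np.int32)
cu_seqlens_q[1:] = np.cumsum(seqlens_q)    # -> [0, 3, 8, 10]

nheads, head_dim = 8, 64
total_q = int(cu_seqlens_q[-1])            # 10 tokens across the whole batch
q_packed = np.random.randn(total_q, nheads, head_dim).astype(np.float16)

# max_seqlen_q is the largest individual length, not the total.
max_seqlen_q = max(seqlens_q)
for i in range(len(seqlens_q)):
    assert len(q_packed[cu_seqlens_q[i]:cu_seqlens_q[i + 1]]) == seqlens_q[i]
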
python/infinicore/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -83,7 +83,10 @@
 from infinicore.ops.hypot import hypot
 from infinicore.ops.index_add import index_add
 from infinicore.ops.index_copy import index_copy
-from infinicore.ops.infllmv2_attention import infllmv2_kvcache, infllmv2_varlen
+from infinicore.ops.infllmv2_attention import (
+    infllmv2_attention_kvcache,
+    infllmv2_attention_varlen,
+)
 from infinicore.ops.inner import inner
 from infinicore.ops.kron import kron
 from infinicore.ops.kthvalue import kthvalue
@@ -195,8 +198,8 @@
     "block_diag",
     "kron",
     "bitwise_right_shift",
-    "infllmv2_varlen",
-    "infllmv2_kvcache",
+    "infllmv2_attention_varlen",
+    "infllmv2_attention_kvcache",
     "simple_gla_attention",
     "simple_gla_decode_step",
     "simple_gla_prefill",

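With the export list updated, the renamed ops resolve from the package root. A one-line smoke test (assuming InfiniCore is installed; per the wrapper module below, the import succeeds even without the native backend, and only calling the ops raises NotImplementedError):

from infinicore import infllmv2_attention_kvcache, infllmv2_attention_varlen
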
python/infinicore/ops/infllmv2_attention.py

Lines changed: 13 additions & 9 deletions
@@ -6,17 +6,21 @@
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor
 
-_native_infllmv2_varlen = getattr(_infinicore, "infllmv2_varlen", None)
-_native_infllmv2_kvcache = getattr(_infinicore, "infllmv2_kvcache", None)
+_native_infllmv2_attention_varlen = getattr(
+    _infinicore, "infllmv2_attention_varlen", None
+)
+_native_infllmv2_attention_kvcache = getattr(
+    _infinicore, "infllmv2_attention_kvcache", None
+)
 
 _MISSING_MSG = (
-    "infllmv2_varlen / infllmv2_kvcache not found in _infinicore. "
+    "infllmv2_attention_varlen / infllmv2_attention_kvcache not found in _infinicore. "
     "Build InfiniCore with: xmake f --aten=y --infllmv2=y (auto-detect under third_party/infllmv2_cuda_impl) "
     "or --infllmv2=/abs/path/to/libinfllm_v2.so (recommended), then xmake build/install."
 )
 
 
-def infllmv2_varlen(
+def infllmv2_attention_varlen(
     q: Tensor,
     k: Tensor,
     v: Tensor,
@@ -30,10 +34,10 @@ def infllmv2_varlen(
     window_size_right: int = -1,
 ):
     """InfLLM-V2 varlen attention. q,k,v unpadded; cu_seqlens_q/k [batch+1]. Returns [total_q, nheads, head_dim]."""
-    if _native_infllmv2_varlen is None:
+    if _native_infllmv2_attention_varlen is None:
         raise NotImplementedError(_MISSING_MSG)
     return Tensor(
-        _native_infllmv2_varlen(
+        _native_infllmv2_attention_varlen(
             q._underlying,
             k._underlying,
             v._underlying,
@@ -49,7 +53,7 @@ def infllmv2_varlen(
     )
 
 
-def infllmv2_kvcache(
+def infllmv2_attention_kvcache(
     q: Tensor,
     k_cache: Tensor,
     v_cache: Tensor,
@@ -60,10 +64,10 @@ def infllmv2_kvcache(
     window_size_right: int = -1,
 ):
     """InfLLM-V2 KV-cache (decode) attention. Returns [batch, seqlen_q, nheads, head_dim]."""
-    if _native_infllmv2_kvcache is None:
+    if _native_infllmv2_attention_kvcache is None:
         raise NotImplementedError(_MISSING_MSG)
     return Tensor(
-        _native_infllmv2_kvcache(
+        _native_infllmv2_attention_kvcache(
             q._underlying,
             k_cache._underlying,
             v_cache._underlying,

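The wrapper module probes the native extension with getattr(..., None) so that importing never fails when the optional backend is absent; failure is deferred to call time. A self-contained sketch of the same pattern (the module and symbol names here are stand-ins, not the real _infinicore layout):

import math  # stands in for the optional native module

# Probe for an optional symbol; None means "backend not built".
_native_fast_op = getattr(math, "no_such_symbol", None)

_MISSING_MSG = "fast_op not available: rebuild with the optional backend enabled."


def fast_op(x):
    """Dispatch to the native symbol if present, else fail loudly at call time."""
    if _native_fast_op is None:
        raise NotImplementedError(_MISSING_MSG)
    return _native_fast_op(x)


try:
    fast_op(1.0)
except NotImplementedError as exc:
    print(exc)  # the import above succeeded; only the call reports the gap
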
python/infinicore/tensor.py

Lines changed: 0 additions & 8 deletions
@@ -80,14 +80,6 @@ def is_pinned(self):
     def copy_(self, src):
         self._underlying.copy_(src._underlying)
 
-    def write_i32(self, linear_index, value):
-        """Write one int32 element at a contiguous linear index (metadata fast path)."""
-        self._underlying.write_i32(linear_index, int(value))
-
-    def write_i64(self, linear_index, value):
-        """Write one int64 element at a contiguous linear index (metadata fast path)."""
-        self._underlying.write_i64(linear_index, int(value))
-
     def to(self, *args, **kwargs):
         return Tensor(
             self._underlying.to(*tuple(arg._underlying for arg in args), **kwargs)

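For readers tracking this removal: write_i32/write_i64 wrote one element at a flat index into a contiguous buffer. The equivalent addressing on a NumPy array, shown only to document what the deleted fast path did (this is not an infinicore API):

import numpy as np

# Linear index 4 in a contiguous 2x3 int32 buffer is row 1, col 1 in
# C (row-major) order -- the addressing the removed methods used.
meta = np.zeros((2, 3), dtype=np.int32)
meta.reshape(-1)[4] = 7
assert meta[1, 1] == 7
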
src/infinicore/context/allocators/pinnable_block_allocator.cc

Lines changed: 0 additions & 15 deletions
@@ -5,7 +5,6 @@
 #include "../../utils.hpp"
 
 #include <algorithm>
-#include <cstdlib>
 #include <infinirt.h>
 #include <stdexcept>
 
@@ -73,13 +72,6 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
     block->frozen = pinned_mode_;
     block->in_use = true;
 
-    if (std::getenv("INFINICORE_DEBUG_ALLOC") != nullptr) {
-        infiniDevice_t dev;
-        int dev_id;
-        infinirtGetDevice(&dev, &dev_id);
-        spdlog::warn("PinnableBlockAllocator cudaMalloc request: requested={} aligned={} class={} device={} id={}",
-                     size, size, cls.block_size, static_cast<int>(dev), dev_id);
-    }
     INFINICORE_CHECK_ERROR(infinirtMalloc(&block->ptr, block->size));
 
     all_blocks_[block->ptr] = block;
@@ -105,13 +97,6 @@ std::byte *PinnableBlockAllocator::allocate(size_t size) {
     block->frozen = pinned_mode_;
     block->in_use = true;
 
-    if (std::getenv("INFINICORE_DEBUG_ALLOC") != nullptr) {
-        infiniDevice_t dev;
-        int dev_id;
-        infinirtGetDevice(&dev, &dev_id);
-        spdlog::warn("PinnableBlockAllocator cudaMalloc request (large): requested={} aligned={} device={} id={}",
-                     size, size, static_cast<int>(dev), dev_id);
-    }
     INFINICORE_CHECK_ERROR(infinirtMalloc(&block->ptr, block->size));
 
     large_blocks_.push_back(block);

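Context for the two deleted hunks: both sat in PinnableBlockAllocator::allocate between marking a block in-use and the infinirtMalloc call, one in the size-classed path and one in the large-block path. A rough Python sketch of that bookkeeping, inferred from the fields visible above (all_blocks_, large_blocks_, pinned/frozen state) — the real allocator's details are assumptions here:

class Block:
    def __init__(self, size):
        self.size = size
        self.ptr = None      # filled in by the backend malloc
        self.frozen = False  # pinned blocks survive cache resets
        self.in_use = False


class PinnableBlockAllocator:
    """Sketch of the allocate() flow around the removed debug hunks."""

    def __init__(self, pinned_mode=False, large_threshold=1 << 20):
        self.pinned_mode = pinned_mode
        self.large_threshold = large_threshold  # assumed size-class cutoff
        self.all_blocks = {}    # ptr -> Block, size-classed path
        self.large_blocks = []  # oversized allocations tracked separately

    def allocate(self, size):
        block = Block(size)
        block.frozen = self.pinned_mode
        block.in_use = True
        # ...the removed INFINICORE_DEBUG_ALLOC logging sat here...
        block.ptr = object()    # stands in for infinirtMalloc
        if size >= self.large_threshold:
            self.large_blocks.append(block)
        else:
            self.all_blocks[block.ptr] = block
        return block.ptr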