@@ -61,59 +61,29 @@ INFINICORE_GRAPH_OP_CLASS(
 //
 // Returns:
 //   [total_q, nheads, head_dim]
-void infllmv2_varlen_(Tensor out,
-                      const Tensor &q,
-                      const Tensor &k,
-                      const Tensor &v,
-                      const Tensor &cu_seqlens_q,
-                      const Tensor &cu_seqlens_k,
-                      int max_seqlen_q,
-                      int max_seqlen_k,
-                      float scale,
-                      bool causal,
-                      int window_size_left = -1,
-                      int window_size_right = -1);
-Tensor infllmv2_varlen(const Tensor &q,
-                       const Tensor &k,
-                       const Tensor &v,
-                       const Tensor &cu_seqlens_q,
-                       const Tensor &cu_seqlens_k,
-                       int max_seqlen_q,
-                       int max_seqlen_k,
-                       float scale,
-                       bool causal,
-                       int window_size_left = -1,
-                       int window_size_right = -1);
-
-// Preferred names (attention-disambiguated). These are header-only aliases to the
-// backward-compatible `infllmv2_*` symbols to avoid adding extra exported ABI.
-inline void infllmv2_attention_varlen_(Tensor out,
-                                       const Tensor &q,
-                                       const Tensor &k,
-                                       const Tensor &v,
-                                       const Tensor &cu_seqlens_q,
-                                       const Tensor &cu_seqlens_k,
-                                       int max_seqlen_q,
-                                       int max_seqlen_k,
-                                       float scale,
-                                       bool causal,
-                                       int window_size_left = -1,
-                                       int window_size_right = -1) {
-    infllmv2_varlen_(out, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right);
-}
-inline Tensor infllmv2_attention_varlen(const Tensor &q,
-                                        const Tensor &k,
-                                        const Tensor &v,
-                                        const Tensor &cu_seqlens_q,
-                                        const Tensor &cu_seqlens_k,
-                                        int max_seqlen_q,
-                                        int max_seqlen_k,
-                                        float scale,
-                                        bool causal,
-                                        int window_size_left = -1,
-                                        int window_size_right = -1) {
-    return infllmv2_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, causal, window_size_left, window_size_right);
-}
+void infllmv2_attention_varlen_(Tensor out,
+                                const Tensor &q,
+                                const Tensor &k,
+                                const Tensor &v,
+                                const Tensor &cu_seqlens_q,
+                                const Tensor &cu_seqlens_k,
+                                int max_seqlen_q,
+                                int max_seqlen_k,
+                                float scale,
+                                bool causal,
+                                int window_size_left = -1,
+                                int window_size_right = -1);
+Tensor infllmv2_attention_varlen(const Tensor &q,
+                                 const Tensor &k,
+                                 const Tensor &v,
+                                 const Tensor &cu_seqlens_q,
+                                 const Tensor &cu_seqlens_k,
+                                 int max_seqlen_q,
+                                 int max_seqlen_k,
+                                 float scale,
+                                 bool causal,
+                                 int window_size_left = -1,
+                                 int window_size_right = -1);
 
 // Decode-time InfLLM-V2 attention with KV cache.
 //
@@ -125,104 +95,55 @@ inline Tensor infllmv2_attention_varlen(const Tensor &q,
 //
 // Returns:
 //   [batch, seqlen_q, nheads, head_dim]
-void infllmv2_kvcache_(Tensor out,
-                       const Tensor &q,
-                       const Tensor &k_cache,
-                       const Tensor &v_cache,
-                       const Tensor &cache_lens,
-                       float scale,
-                       bool causal,
-                       int window_size_left = -1,
-                       int window_size_right = -1);
-Tensor infllmv2_kvcache(const Tensor &q,
-                        const Tensor &k_cache,
-                        const Tensor &v_cache,
-                        const Tensor &cache_lens,
-                        float scale,
-                        bool causal,
-                        int window_size_left = -1,
-                        int window_size_right = -1);
+void infllmv2_attention_kvcache_(Tensor out,
+                                 const Tensor &q,
+                                 const Tensor &k_cache,
+                                 const Tensor &v_cache,
+                                 const Tensor &cache_lens,
+                                 float scale,
+                                 bool causal,
+                                 int window_size_left = -1,
+                                 int window_size_right = -1);
+Tensor infllmv2_attention_kvcache(const Tensor &q,
+                                  const Tensor &k_cache,
+                                  const Tensor &v_cache,
+                                  const Tensor &cache_lens,
+                                  float scale,
+                                  bool causal,
+                                  int window_size_left = -1,
+                                  int window_size_right = -1);
 
-inline void infllmv2_attention_kvcache_(Tensor out,
+// Decode-time InfLLM-V2 attention with KV cache, updating cache in-place.
+//
+// Shapes:
+//   q          : [batch, seqlen_q, nheads, head_dim]
+//   k_cache    : [batch, seqlen_cache, nheads_k, head_dim] (dense cache)
+//   v_cache    : same as k_cache
+//   k_new/v_new: [batch, seqlen_new, nheads_k, head_dim] (new KV to append at cache_lens offsets)
+//   cache_lens : [batch] (int32) current KV length per sequence BEFORE appending
+//
+// Returns:
+//   [batch, seqlen_q, nheads, head_dim]
+void infllmv2_attention_kvcache_update_(Tensor out,
                                         const Tensor &q,
                                         const Tensor &k_cache,
                                         const Tensor &v_cache,
+                                        const Tensor &k_new,
+                                        const Tensor &v_new,
                                         const Tensor &cache_lens,
                                         float scale,
                                         bool causal,
                                         int window_size_left = -1,
-                                        int window_size_right = -1) {
-    infllmv2_kvcache_(out, q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right);
-}
-inline Tensor infllmv2_attention_kvcache(const Tensor &q,
+                                        int window_size_right = -1);
+Tensor infllmv2_attention_kvcache_update(const Tensor &q,
                                          const Tensor &k_cache,
                                          const Tensor &v_cache,
+                                         const Tensor &k_new,
+                                         const Tensor &v_new,
                                          const Tensor &cache_lens,
                                          float scale,
                                          bool causal,
                                          int window_size_left = -1,
-                                         int window_size_right = -1) {
-    return infllmv2_kvcache(q, k_cache, v_cache, cache_lens, scale, causal, window_size_left, window_size_right);
-}
-
-// Decode-time InfLLM-V2 attention with KV cache, updating cache in-place.
-//
-// Shapes:
-//   q          : [batch, seqlen_q, nheads, head_dim]
-//   k_cache    : [batch, seqlen_cache, nheads_k, head_dim] (dense cache)
-//   v_cache    : same as k_cache
-//   k_new/v_new: [batch, seqlen_new, nheads_k, head_dim] (new KV to append at cache_lens offsets)
-//   cache_lens : [batch] (int32) current KV length per sequence BEFORE appending
-//
-// Returns:
-//   [batch, seqlen_q, nheads, head_dim]
-void infllmv2_kvcache_update_(Tensor out,
-                              const Tensor &q,
-                              const Tensor &k_cache,
-                              const Tensor &v_cache,
-                              const Tensor &k_new,
-                              const Tensor &v_new,
-                              const Tensor &cache_lens,
-                              float scale,
-                              bool causal,
-                              int window_size_left = -1,
-                              int window_size_right = -1);
-Tensor infllmv2_kvcache_update(const Tensor &q,
-                               const Tensor &k_cache,
-                               const Tensor &v_cache,
-                               const Tensor &k_new,
-                               const Tensor &v_new,
-                               const Tensor &cache_lens,
-                               float scale,
-                               bool causal,
-                               int window_size_left = -1,
-                               int window_size_right = -1);
-
-inline void infllmv2_attention_kvcache_update_(Tensor out,
-                                               const Tensor &q,
-                                               const Tensor &k_cache,
-                                               const Tensor &v_cache,
-                                               const Tensor &k_new,
-                                               const Tensor &v_new,
-                                               const Tensor &cache_lens,
-                                               float scale,
-                                               bool causal,
-                                               int window_size_left = -1,
-                                               int window_size_right = -1) {
-    infllmv2_kvcache_update_(out, q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right);
-}
-inline Tensor infllmv2_attention_kvcache_update(const Tensor &q,
-                                                const Tensor &k_cache,
-                                                const Tensor &v_cache,
-                                                const Tensor &k_new,
-                                                const Tensor &v_new,
-                                                const Tensor &cache_lens,
-                                                float scale,
-                                                bool causal,
-                                                int window_size_left = -1,
-                                                int window_size_right = -1) {
-    return infllmv2_kvcache_update(q, k_cache, v_cache, k_new, v_new, cache_lens, scale, causal, window_size_left, window_size_right);
-}
+                                         int window_size_right = -1);
 
 } // namespace infinicore::op
-
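
For reference, a minimal sketch of how a caller might drive the renamed prefill entry point. It assumes this header is included; the example reopens `namespace infinicore::op` so `Tensor` resolves exactly as it does in the declarations above. The helper name `run_prefill_example`, the `head_dim` parameter, and the k/v and `cu_seqlens_*` shapes noted in the comments are illustrative assumptions (those doc lines fall outside this hunk), not part of the API.

```cpp
#include <cmath>

namespace infinicore::op {

// Illustrative only: causal variable-length prefill over a packed batch using
// the renamed infllmv2_attention_varlen entry point declared above.
inline Tensor run_prefill_example(const Tensor &q,            // [total_q, nheads, head_dim]
                                  const Tensor &k,            // packed K (shape assumed analogous to q)
                                  const Tensor &v,            // packed V (same layout as k)
                                  const Tensor &cu_seqlens_q, // cumulative Q lengths (assumed int32, [batch + 1])
                                  const Tensor &cu_seqlens_k, // cumulative K lengths (assumed int32, [batch + 1])
                                  int max_seqlen_q,
                                  int max_seqlen_k,
                                  int head_dim) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    // window_size_left/right keep their -1 defaults, i.e. no sliding-window limit.
    // Result shape: [total_q, nheads, head_dim].
    return infllmv2_attention_varlen(q, k, v,
                                     cu_seqlens_q, cu_seqlens_k,
                                     max_seqlen_q, max_seqlen_k,
                                     scale, /*causal=*/true);
}

} // namespace infinicore::op
```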
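Similarly, a sketch of a decode step that only reads an already-populated dense cache through the in-place `infllmv2_attention_kvcache_` variant. `decode_read_only_example` is a hypothetical helper; it assumes the caller pre-allocates `out` with the documented [batch, seqlen_q, nheads, head_dim] result shape and that `cache_lens` carries the per-sequence valid KV lengths.

```cpp
#include <cmath>

namespace infinicore::op {

// Illustrative only: one decode step over a dense KV cache using the in-place
// variant, which writes the attention output into a caller-provided tensor.
inline void decode_read_only_example(Tensor out,               // pre-allocated [batch, seqlen_q, nheads, head_dim]
                                     const Tensor &q,          // [batch, seqlen_q, nheads, head_dim]
                                     const Tensor &k_cache,    // dense K cache
                                     const Tensor &v_cache,    // dense V cache
                                     const Tensor &cache_lens, // per-sequence valid KV lengths
                                     int head_dim) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    // Defaults leave the attention window unlimited (window_size_left/right = -1).
    infllmv2_attention_kvcache_(out, q, k_cache, v_cache, cache_lens,
                                scale, /*causal=*/true);
}

} // namespace infinicore::op
```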
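Finally, a sketch of the fused append-and-attend decode path. `decode_append_example` is again a hypothetical helper; per the comments in the diff, `cache_lens` holds each sequence's KV length before the append, and `k_new`/`v_new` are written into the cache at those offsets before the query attends over the updated cache.

```cpp
#include <cmath>

namespace infinicore::op {

// Illustrative only: append this step's K/V into the dense cache and attend
// over the updated cache in a single call.
inline Tensor decode_append_example(const Tensor &q,          // [batch, seqlen_q, nheads, head_dim]
                                    const Tensor &k_cache,    // [batch, seqlen_cache, nheads_k, head_dim]
                                    const Tensor &v_cache,    // same shape as k_cache
                                    const Tensor &k_new,      // [batch, seqlen_new, nheads_k, head_dim]
                                    const Tensor &v_new,      // same shape as k_new
                                    const Tensor &cache_lens, // [batch] (int32), KV lengths BEFORE appending
                                    int head_dim) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    // Returns [batch, seqlen_q, nheads, head_dim]; k_cache/v_cache are updated in place.
    return infllmv2_attention_kvcache_update(q, k_cache, v_cache, k_new, v_new,
                                             cache_lens, scale, /*causal=*/true);
}

} // namespace infinicore::op
```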