@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
     return dispatcher_;
 };
 
-void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
+void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
     infinicore::context::setDevice(out->device());
-    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
 }
 
-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+    paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
     return out;
 }
 
-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
 }
 
 } // namespace infinicore::op
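
For context, a minimal caller-side sketch of the renamed entry point. Only the `paged_attention` signature, the `Tensor` type, and the `infinicore::op` namespace are taken from this diff; the tensor shapes, the `head_size` parameter, and the `run_decode_step` wrapper are illustrative assumptions, not part of the change.

```cpp
#include <cmath>    // std::sqrt
#include <optional> // std::nullopt

// Assumed layouts (hypothetical, for illustration only):
//   q:               [num_seqs, num_heads, head_size] queries for one decode step
//   k_cache/v_cache: paged KV caches stored in fixed-size blocks
//   block_tables:    [num_seqs, max_blocks_per_seq] logical-to-physical block map
//   cache_lens:      [num_seqs] tokens currently cached per sequence
//                    (this diff renames the parameter from seq_lens)
using infinicore::op::paged_attention;

Tensor run_decode_step(Tensor q, Tensor k_cache, Tensor v_cache,
                       Tensor block_tables, Tensor cache_lens,
                       size_t head_size) {
    // Standard attention scaling factor 1/sqrt(d).
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
    // Allocating variant; paged_attention_ instead writes into a
    // caller-provided `out` tensor of the same shape as `q`.
    return paged_attention(q, k_cache, v_cache, block_tables, cache_lens,
                           /*alibi_slopes=*/std::nullopt, scale);
}
```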