@@ -11,18 +11,13 @@ namespace infini::ops {
1111
1212class FlashAttention : public Operator <FlashAttention> {
1313 public:
14- FlashAttention (
15- const Tensor query, const Tensor key, const Tensor value,
16- std::optional<Tensor> cu_seqlens_q,
17- std::optional<Tensor> cu_seqlens_kv,
18- std::optional<Tensor> block_table,
19- int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
20- double scale,
21- bool causal,
22- int64_t window_left,
23- int64_t window_right,
24- int64_t block_size,
25- Tensor output)
14+ FlashAttention (const Tensor query, const Tensor key, const Tensor value,
15+ std::optional<Tensor> cu_seqlens_q,
16+ std::optional<Tensor> cu_seqlens_kv,
17+ std::optional<Tensor> block_table, int64_t num_heads,
18+ int64_t num_kv_heads, int64_t head_size, double scale,
19+ bool causal, int64_t window_left, int64_t window_right,
20+ int64_t block_size, Tensor output)
2621 : num_tokens_{query.size (0 )},
2722 num_heads_{num_heads},
2823 num_kv_heads_{num_kv_heads},
@@ -50,18 +45,15 @@ class FlashAttention : public Operator<FlashAttention> {
5045 " `FlashAttention` requires query to be 3D [T, N, D]" );
5146 }
5247
53- virtual void operator ()(
54- const Tensor query, const Tensor key, const Tensor value,
55- std::optional<Tensor> cu_seqlens_q,
56- std::optional<Tensor> cu_seqlens_kv,
57- std::optional<Tensor> block_table,
58- int64_t num_heads, int64_t num_kv_heads, int64_t head_size,
59- double scale,
60- bool causal,
61- int64_t window_left,
62- int64_t window_right,
63- int64_t block_size,
64- Tensor output) const = 0;
48+ virtual void operator ()(const Tensor query, const Tensor key,
49+ const Tensor value,
50+ std::optional<Tensor> cu_seqlens_q,
51+ std::optional<Tensor> cu_seqlens_kv,
52+ std::optional<Tensor> block_table, int64_t num_heads,
53+ int64_t num_kv_heads, int64_t head_size, double scale,
54+ bool causal, int64_t window_left,
55+ int64_t window_right, int64_t block_size,
56+ Tensor output) const = 0;
6557
6658 protected:
6759 Tensor::Size num_tokens_{0 };
0 commit comments