@@ -24,6 +24,7 @@ void sdpa_full_self_attention_nax(
2424 const float scale,
2525 array& o,
2626 bool do_causal_,
27+ bool causal_upper_left,
2728 const std::optional<array>& mask,
2829 const std::optional<array>& sinks) {
2930 using namespace mlx ::steel;
@@ -131,7 +132,7 @@ void sdpa_full_self_attention_nax(
131132
132133 /* int qL_rem = */ (qL - NQ_aligned * bq),
133134 /* int kL_rem = */ (kL - NK_aligned * bk),
134- /* int qL_off = */ (kL - qL),
135+ /* int qL_off = */ (causal_upper_left ? 0 : kL - qL),
135136
136137 /* int64_t Q_strides[3] = */ {q.strides (0 ), q.strides (1 ), q.strides (2 )},
137138 /* int64_t K_strides[3] = */ {k.strides (0 ), k.strides (1 ), k.strides (2 )},
@@ -172,6 +173,7 @@ void sdpa_full_self_attention_metal(
172173 const float scale,
173174 array& o,
174175 bool do_causal_,
176+ bool causal_upper_left,
175177 const std::optional<array>& mask,
176178 const std::optional<array>& sinks) {
177179 if (metal::is_nax_available () && q.shape (3 ) != 80 &&
@@ -185,6 +187,7 @@ void sdpa_full_self_attention_metal(
185187 /* const float scale = */ scale,
186188 /* array& o = */ o,
187189 /* bool do_causal_ = */ do_causal_,
190+ /* bool causal_upper_left = */ causal_upper_left,
188191 /* const std::optional<array>& mask = */ mask,
189192 /* const std::optional<array>& sinks = */ sinks);
190193 }
@@ -294,7 +297,7 @@ void sdpa_full_self_attention_metal(
294297
295298 /* int qL_rem = */ (qL - NQ_aligned * bq),
296299 /* int kL_rem = */ (kL - NK_aligned * bk),
297- /* int qL_off = */ (kL - qL),
300+ /* int qL_off = */ (causal_upper_left ? 0 : kL - qL),
298301
299302 /* int64_t Q_strides[3] = */ {q.strides (0 ), q.strides (1 ), q.strides (2 )},
300303 /* int64_t K_strides[3] = */ {k.strides (0 ), k.strides (1 ), k.strides (2 )},
@@ -335,6 +338,7 @@ void sdpa_vector(
335338 array& out,
336339 float scale,
337340 bool do_causal,
341+ bool causal_upper_left,
338342 const std::optional<array>& mask,
339343 const std::optional<array>& sinks) {
340344 // Set the kernel name
@@ -410,6 +414,8 @@ void sdpa_vector(
410414 compute_encoder.set_input_array (*sinks, 16 );
411415 compute_encoder.set_bytes (q.shape (1 ), 17 );
412416 }
417+ int32_t causal_offset = causal_upper_left ? 0 : N - q.shape (2 );
418+ compute_encoder.set_bytes (causal_offset, 18 );
413419
414420 // Launch
415421 compute_encoder.dispatch_threadgroups (grid_dims, group_dims);
@@ -424,6 +430,7 @@ void sdpa_vector_2pass(
424430 array& out,
425431 float scale,
426432 bool do_causal,
433+ bool causal_upper_left,
427434 const std::optional<array>& mask,
428435 const std::optional<array>& sinks) {
429436 // Set the kernel name
@@ -554,6 +561,8 @@ void sdpa_vector_2pass(
554561 if (has_sinks) {
555562 compute_encoder.set_input_array (*sinks, 18 );
556563 }
564+ int32_t causal_offset = causal_upper_left ? 0 : N - q.shape (2 );
565+ compute_encoder.set_bytes (causal_offset, 19 );
557566
558567 // Launch
559568 compute_encoder.dispatch_threadgroups (grid_dims, group_dims);
@@ -744,9 +753,11 @@ void ScaledDotProductAttention::eval_gpu(
744753 char devc = d.get_architecture ().back ();
745754 if (((devc == ' d' || devc == ' s' ) && k.shape (2 ) >= 1024 ) ||
746755 (k.shape (1 ) < q.shape (1 ) && k.shape (2 ) >= 4096 )) {
747- sdpa_vector_2pass (s, d, q, k, v, o, scale_, do_causal, mask, sinks);
756+ sdpa_vector_2pass (
757+ s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks);
748758 } else {
749- sdpa_vector (s, d, q, k, v, o, scale_, do_causal, mask, sinks);
759+ sdpa_vector (
760+ s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks);
750761 }
751762 }
752763
@@ -779,7 +790,7 @@ void ScaledDotProductAttention::eval_gpu(
779790 : std::nullopt ;
780791
781792 sdpa_full_self_attention_metal (
782- s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
793+ s, d, q, k, v, scale_, o, do_causal_, causal_upper_left_, mask, sinks);
783794 }
784795
785796 d.add_temporaries (std::move (copies), s.index );
0 commit comments