Commit 6b2d587

add causal_upper_left mask option to scaled_dot_product_attention

1 parent 5d17004

6 files changed: 112 additions & 25 deletions

mlx/backend/metal/kernels/sdpa_vector.h

Lines changed: 4 additions & 2 deletions
@@ -36,6 +36,7 @@ template <typename T, int D, int V = D>
     const device T* sinks [[buffer(16), function_constant(has_sinks)]],
     const constant int& num_q_heads
         [[buffer(17), function_constant(has_sinks)]],
+    const constant int& causal_offset [[buffer(18)]],
     uint3 tid [[threadgroup_position_in_grid]],
     uint3 tpg [[threadgroups_per_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -99,7 +100,7 @@ template <typename T, int D, int V = D>
   for (int i = simd_gid; i < N; i += BN) {
     bool use_key = true;
     if (do_causal) {
-      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
+      use_key = i <= (causal_offset + int(q_seq_idx));
     } else if (bool_mask) {
       use_key = bmask[0];
     } else if (float_mask) {
@@ -199,6 +200,7 @@ template <typename T, int D, int V = D>
     const constant int& mask_head_stride
         [[buffer(17), function_constant(has_mask)]],
     const device T* sinks [[buffer(18), function_constant(has_sinks)]],
+    const constant int& causal_offset [[buffer(19)]],
     uint3 tptg [[threads_per_threadgroup]],
     uint3 tidtg [[thread_position_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -263,7 +265,7 @@ template <typename T, int D, int V = D>
   for (int i = block_idx; i < N; i += blocks) {
     bool use_key = true;
     if (do_causal) {
-      use_key = i <= (N - q_seq_len + int(q_seq_idx));
+      use_key = i <= (causal_offset + int(q_seq_idx));
     } else if (bool_mask) {
       use_key = bmask[0];
     } else if (float_mask) {
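
For intuition, here is a minimal Python sketch (not part of the commit) of the predicate these vector kernels now evaluate. The host passes causal_offset = 0 for upper-left alignment and causal_offset = kL - qL (previously computed in-kernel as N - q_seq_len) for lower-right alignment; key i is visible to query position q iff i <= causal_offset + q.

# Hypothetical illustration of the kernel's per-key `use_key` predicate.
def use_key(i, q, causal_offset):
    # Mirrors `use_key = i <= (causal_offset + int(q_seq_idx))` above.
    return i <= causal_offset + q

qL, kL = 2, 4  # illustrative sizes with fewer queries than keys
for name, offset in [("upper_left", 0), ("lower_right", kL - qL)]:
    print(name, [[int(use_key(i, q, offset)) for i in range(kL)] for q in range(qL)])
# upper_left  [[1, 0, 0, 0], [1, 1, 0, 0]]
# lower_right [[1, 1, 1, 0], [1, 1, 1, 1]]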

mlx/backend/metal/scaled_dot_product_attention.cpp

Lines changed: 16 additions & 5 deletions
@@ -24,6 +24,7 @@ void sdpa_full_self_attention_nax(
     const float scale,
     array& o,
     bool do_causal_,
+    bool causal_upper_left,
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
   using namespace mlx::steel;
@@ -131,7 +132,7 @@ void sdpa_full_self_attention_nax(
 
       /* int qL_rem = */ (qL - NQ_aligned * bq),
       /* int kL_rem = */ (kL - NK_aligned * bk),
-      /* int qL_off = */ (kL - qL),
+      /* int qL_off = */ (causal_upper_left ? 0 : kL - qL),
 
       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
@@ -172,6 +173,7 @@ void sdpa_full_self_attention_metal(
     const float scale,
     array& o,
     bool do_causal_,
+    bool causal_upper_left,
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
   if (metal::is_nax_available() && q.shape(3) != 80 &&
@@ -185,6 +187,7 @@
         /* const float scale = */ scale,
         /* array& o = */ o,
         /* bool do_causal_ = */ do_causal_,
+        /* bool causal_upper_left = */ causal_upper_left,
         /* const std::optional<array>& mask = */ mask,
         /* const std::optional<array>& sinks = */ sinks);
   }
@@ -294,7 +297,7 @@ void sdpa_full_self_attention_metal(
 
       /* int qL_rem = */ (qL - NQ_aligned * bq),
       /* int kL_rem = */ (kL - NK_aligned * bk),
-      /* int qL_off = */ (kL - qL),
+      /* int qL_off = */ (causal_upper_left ? 0 : kL - qL),
 
       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
@@ -335,6 +338,7 @@ void sdpa_vector(
     array& out,
     float scale,
     bool do_causal,
+    bool causal_upper_left,
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
   // Set the kernel name
@@ -410,6 +414,8 @@
     compute_encoder.set_input_array(*sinks, 16);
     compute_encoder.set_bytes(q.shape(1), 17);
   }
+  int32_t causal_offset = causal_upper_left ? 0 : N - q.shape(2);
+  compute_encoder.set_bytes(causal_offset, 18);
 
   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -424,6 +430,7 @@ void sdpa_vector_2pass(
     array& out,
     float scale,
     bool do_causal,
+    bool causal_upper_left,
     const std::optional<array>& mask,
     const std::optional<array>& sinks) {
   // Set the kernel name
@@ -554,6 +561,8 @@
   if (has_sinks) {
     compute_encoder.set_input_array(*sinks, 18);
   }
+  int32_t causal_offset = causal_upper_left ? 0 : N - q.shape(2);
+  compute_encoder.set_bytes(causal_offset, 19);
 
   // Launch
   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
@@ -744,9 +753,11 @@ void ScaledDotProductAttention::eval_gpu(
   char devc = d.get_architecture().back();
   if (((devc == 'd' || devc == 's') && k.shape(2) >= 1024) ||
       (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) {
-    sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
+    sdpa_vector_2pass(
+        s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks);
   } else {
-    sdpa_vector(s, d, q, k, v, o, scale_, do_causal, mask, sinks);
+    sdpa_vector(
+        s, d, q, k, v, o, scale_, do_causal, causal_upper_left_, mask, sinks);
   }
 }
 
@@ -779,7 +790,7 @@
           : std::nullopt;
 
   sdpa_full_self_attention_metal(
-      s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
+      s, d, q, k, v, scale_, o, do_causal_, causal_upper_left_, mask, sinks);
   }
 
   d.add_temporaries(std::move(copies), s.index);

mlx/fast.cpp

Lines changed: 20 additions & 7 deletions
@@ -628,26 +628,31 @@ array scaled_dot_product_attention(
     }
   }
   // Check valid mask
-  if (mask_mode != "" && mask_mode != "causal" && mask_mode != "array") {
+  bool is_causal_mode = mask_mode == "causal" ||
+      mask_mode == "causal_lower_right" || mask_mode == "causal_upper_left";
+  if (mask_mode != "" && !is_causal_mode && mask_mode != "array") {
     std::ostringstream msg;
-    msg << "[scaled_dot_product_attention] Invalid mask_mode " << mask_mode
-        << ". mask_mode must be 'causal', 'array' or ''.";
+    msg << "[scaled_dot_product_attention] Invalid mask_mode '" << mask_mode
+        << "'. Must be 'causal', 'causal_lower_right', "
+        << "'causal_upper_left', 'array' or ''.";
     throw std::invalid_argument(msg.str());
   }
 
   bool do_causal = false;
+  bool causal_upper_left = false;
   bool has_mask = false;
   bool has_arr_mask = false;
   bool has_bool_mask = false;
 
-  if (mask_mode == "causal") {
+  if (is_causal_mode) {
     has_mask = true;
     do_causal = true;
+    causal_upper_left = (mask_mode == "causal_upper_left");
 
     if (mask_arr) {
       std::ostringstream msg;
       msg << "[scaled_dot_product_attention] Invalid mask_arr for mask_mode "
-          << "'casusal'. No array mask should be passed.";
+          << "'" << mask_mode << "'. No array mask should be passed.";
       throw std::invalid_argument(msg.str());
     }
   } else if (mask_arr) {
@@ -718,6 +723,7 @@ array scaled_dot_product_attention(
       n_q_heads,
       n_kv_heads,
       do_causal,
+      causal_upper_left,
       has_sinks,
       has_arr_mask,
       s](const std::vector<array>& inputs) {
@@ -737,7 +743,7 @@
     if (do_causal) {
       int kL = k.shape(-2);
       int qL = q.shape(-2);
-      int offset = kL - qL;
+      int offset = causal_upper_left ? 0 : kL - qL;
       auto q_idx = arange(offset, qL + offset, s);
       auto k_idx = arange(0, kL, s);
       q_idx = expand_dims(q_idx, 1, s);
@@ -846,7 +852,13 @@
   }
   Shape out_shape{q.shape(0), q.shape(1), q.shape(2), v.shape(-1)};
   auto primitive = std::make_shared<ScaledDotProductAttention>(
-      stream, fallback, scale, do_causal, has_sinks, output_logsumexp);
+      stream,
+      fallback,
+      scale,
+      do_causal,
+      causal_upper_left,
+      has_sinks,
+      output_logsumexp);
   if (output_logsumexp) {
     return array::make_arrays(
         {std::move(out_shape), Shape{q.shape(0), q.shape(1), q.shape(2), 1}},
@@ -911,6 +923,7 @@ bool ScaledDotProductAttention::is_equivalent(const Primitive& other) const {
   const ScaledDotProductAttention& a_other =
       static_cast<const ScaledDotProductAttention&>(other);
   return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ &&
+      causal_upper_left_ == a_other.causal_upper_left_ &&
       has_sinks_ == a_other.has_sinks_ &&
       output_logsumexp_ == a_other.output_logsumexp_;
 }

mlx/fast_primitives.h

Lines changed: 9 additions & 1 deletion
@@ -210,11 +210,13 @@ class ScaledDotProductAttention : public Custom {
       std::function<std::vector<array>(std::vector<array>)> fallback,
       float scale,
       bool do_causal,
+      bool causal_upper_left,
       bool has_sinks,
       bool output_logsumexp)
       : Custom(stream, std::move(fallback)),
         scale_(scale),
         do_causal_(do_causal),
+        causal_upper_left_(causal_upper_left),
         has_sinks_(has_sinks),
         output_logsumexp_(output_logsumexp) {}
 
@@ -250,12 +252,18 @@
   DEFINE_INPUT_OUTPUT_SHAPE()
   auto state() const {
     return std::make_tuple(
-        nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_);
+        nullptr,
+        scale_,
+        do_causal_,
+        causal_upper_left_,
+        has_sinks_,
+        output_logsumexp_);
   }
 
  private:
   float scale_;
   bool do_causal_;
+  bool causal_upper_left_;
   bool has_sinks_;
   bool output_logsumexp_;
 };

python/src/fast.cpp

Lines changed: 19 additions & 9 deletions
@@ -206,10 +206,13 @@ void init_fast(nb::module_& parent_module) {
         if (has_mask) {
           if (has_str_mask) {
             auto mask_str = std::get<std::string>(mask);
-            if (mask_str != "causal") {
+            if (mask_str != "causal" && mask_str != "causal_lower_right" &&
+                mask_str != "causal_upper_left") {
               std::ostringstream msg;
               msg << "[scaled_dot_product_attention] invalid mask option '"
-                  << mask_str << "'. Must be 'causal', or an array.";
+                  << mask_str
+                  << "'. Must be 'causal', 'causal_lower_right', "
+                  << "'causal_upper_left', or an array.";
               throw std::invalid_argument(msg.str());
             }
             return mx::fast::scaled_dot_product_attention(
@@ -267,13 +270,20 @@
       scale (float): Scale for queries (typically ``1.0 / sqrt(q.shape(-1)``).
       mask (str or array, optional): The mask to apply to the
         query-key scores. The mask can be an array or a string indicating
-        the mask type. The only supported string type is ``"causal"``. If
-        the mask is an array it can be a boolean or additive mask. The mask
-        can have at most 4 dimensions and must be broadcast-compatible with
-        the shape ``[B, N, T_q, T_kv]``. If an additive mask is given its
-        type must promote to the promoted type of ``q``, ``k``, and ``v``.
-        The ``"causal"`` mask uses lower-right alignment where the
-        last query aligns with the last key.
+        the mask type. Supported string types are:
+
+        * ``"causal"`` or ``"causal_lower_right"``: Lower-right
+          aligned causal mask. The last query attends to the last key.
+          This is the standard mask for autoregressive decoding.
+        * ``"causal_upper_left"``: Upper-left aligned causal mask.
+          Query ``i`` attends to keys ``0..i``. This matches PyTorch's
+          default ``is_causal=True`` behavior.
+
+        If the mask is an array it can be a boolean or additive mask.
+        The mask can have at most 4 dimensions and must be
+        broadcast-compatible with the shape ``[B, N, T_q, T_kv]``. If
+        an additive mask is given its type must promote to the promoted
+        type of ``q``, ``k``, and ``v``.
       sinks (array, optional): An optional array of attention sinks.
         Default: ``None``.
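
With the new options, switching the alignment is a one-argument change at the Python API. A brief usage sketch (shapes are illustrative):

import math
import mlx.core as mx

q = mx.random.normal((1, 8, 4, 64))   # [B, N, T_q, D]
k = mx.random.normal((1, 8, 16, 64))  # [B, N, T_kv, D]
v = mx.random.normal((1, 8, 16, 64))
scale = 1.0 / math.sqrt(q.shape[-1])

# Lower-right alignment ("causal" and "causal_lower_right" are equivalent):
# the last query attends to the last key.
out_lr = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")

# Upper-left alignment: query i attends to keys 0..i, like PyTorch's
# is_causal=True.
out_ul = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal_upper_left")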

python/tests/test_fast_sdpa.py

Lines changed: 44 additions & 1 deletion
@@ -26,7 +26,6 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None, sinks=None):
     scores = q @ mx.swapaxes(k, -1, -2)
     is_causal = mask == "causal"
     if mask is not None:
-
         if is_causal:
             offset = kL - L
             q_indices = mx.arange(L) + offset
@@ -642,6 +641,50 @@ def test_sdpa_sliced(self):
         tolerance = {"rtol": 1e-2, "atol": 1e-2}
         self.assertTrue(mx.allclose(ref, out, **tolerance))
 
+    def test_causal_mask_alignment(self):
+        B, H, D = 1, 2, 64
+        qL, kL = 4, 8
+        scale = 1.0 / math.sqrt(D)
+
+        mx.random.seed(0)
+        q = mx.random.normal((B, H, qL, D))
+        k = mx.random.normal((B, H, kL, D))
+        v = mx.random.normal((B, H, kL, D))
+
+        # "causal" and "causal_lower_right" should be identical
+        out_causal = mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask="causal"
+        )
+        out_lr = mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask="causal_lower_right"
+        )
+        self.assertTrue(mx.allclose(out_causal, out_lr, atol=1e-6, rtol=1e-5))
+
+        # "causal_upper_left" should match a manual upper-left mask
+        q_idx = mx.arange(qL)
+        k_idx = mx.arange(kL)
+        ul_mask = q_idx[:, None] >= k_idx[None]
+        out_ul = mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask="causal_upper_left"
+        )
+        out_manual = mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=scale, mask=ul_mask
+        )
+        self.assertTrue(mx.allclose(out_ul, out_manual, atol=1e-5, rtol=1e-4))
+
+        # upper-left != lower-right when qL != kL
+        self.assertFalse(mx.allclose(out_ul, out_lr, atol=1e-2, rtol=1e-2))
+
+        # when qL == kL, both should be identical
+        q_eq = mx.random.normal((B, H, kL, D))
+        out_lr_eq = mx.fast.scaled_dot_product_attention(
+            q_eq, k, v, scale=scale, mask="causal_lower_right"
+        )
+        out_ul_eq = mx.fast.scaled_dot_product_attention(
+            q_eq, k, v, scale=scale, mask="causal_upper_left"
+        )
+        self.assertTrue(mx.allclose(out_lr_eq, out_ul_eq, atol=1e-6, rtol=1e-5))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner(failfast=True)
