Skip to content

Commit 4c39174

Browse files
committed
feat: chunked full-attention SDPA to avoid GPU watchdog timeout
The steel_attention kernel processes all keys in a single Metal dispatch. At 65K+ key sequence length, the dispatch can exceed the macOS GPU watchdog threshold (~5s), causing kIOGPUCommandBufferCallbackErrorImpactingInteractivity and process termination. Add 2-pass chunked full attention (sdpa_full_self_attention_2pass): - Splits the key sequence into 65K-token chunks - Each chunk dispatches the existing steel_attention kernel with write_partial=true, outputting unnormalized O + partial max/sum - Reduction kernel (sdpa_full_reduce) merges chunk results via online softmax: O_final = sum(O_chunk * exp2(max_chunk - max_all)) / sum(sum_chunk * exp2(max_chunk - max_all)) - Routes automatically when kL >= 65536 Verified: Qwen3.5-2B at 512K context completes successfully (was crashing with GPU watchdog before this fix). TTFT=887s, 36.7 tok/s. The existing 1-pass path is unchanged for kL < 65536. The vector SDPA (qL <= 8, decode) already had 2-pass chunking. Files: - params.h: add chunked attention fields to AttnParams - steel_attention.h: add write_partial function constant for partial output mode (unnormalized O + max/sum to global memory) - steel_attention_reduce.metal: new reduction kernel for merging chunk results via online softmax - scaled_dot_product_attention.cpp: add sdpa_full_self_attention_2pass host function, route kL >= 65536 to it - CMakeLists.txt: build reduction kernel
1 parent 1432557 commit 4c39174

5 files changed

Lines changed: 345 additions & 13 deletions

File tree

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ if(NOT MLX_METAL_JIT)
153153
build_kernel(steel/gemm/kernels/steel_gemm_segmented ${STEEL_HEADERS})
154154
build_kernel(gemv_masked steel/utils.h)
155155
build_kernel(steel/attn/kernels/steel_attention ${STEEL_ATTN_HEADERS})
156+
build_kernel(steel/attn/kernels/steel_attention_reduce)
156157

157158
if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL
158159
26.2))

mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ constant bool align_K [[function_constant(201)]];
1414
constant bool has_mask [[function_constant(300)]];
1515
constant bool do_causal [[function_constant(301)]];
1616
constant bool has_sinks [[function_constant(302)]];
17+
constant bool write_partial [[function_constant(303)]];
1718

1819
struct MaxOp {
1920
template <typename T>
@@ -76,6 +77,8 @@ template <
7677
const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
7778
const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
7879
const device T* sinks [[buffer(7), function_constant(has_sinks)]],
80+
device float* partial_maxs [[buffer(8), function_constant(write_partial)]],
81+
device float* partial_sums [[buffer(9), function_constant(write_partial)]],
7982
uint simd_lane_id [[thread_index_in_simdgroup]],
8083
uint simd_group_id [[simdgroup_index_in_threadgroup]],
8184
uint3 tid [[threadgroup_position_in_grid]],
@@ -456,21 +459,54 @@ template <
456459
loader_v.next();
457460
}
458461

459-
// Normalize output
460-
Otile.template row_bin_op<DivOp>(sum_score);
461-
threadgroup_barrier(mem_flags::mem_none);
462+
if (write_partial) {
463+
// Write unnormalized O, max_score, sum_score for 2-pass reduction.
464+
// O is NOT divided by sum — the reduction kernel handles normalization.
465+
threadgroup_barrier(mem_flags::mem_none);
462466

463-
// Store results
464-
O += (tm + sm) * params->O_strides[2] + sn;
467+
O += (tm + sm) * params->O_strides[2] + sn;
465468

466-
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
467-
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
469+
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
470+
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
471+
if (dst_tile_dims.x > 0 && dst_tile_dims.y > 0) {
472+
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
473+
}
474+
} else {
475+
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
476+
}
468477

469-
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
470-
return;
478+
// Write per-row max and sum to global memory.
479+
// Layout: [B, H, qL] — one value per query position.
480+
// Each thread writes its owned rows (determined by simd position).
481+
int base_row = int(tid.x) * BQ + tm + sm;
482+
int global_idx = int(tidl.z) * params->H * params->qL +
483+
int(tidl.y) * params->qL + base_row;
471484

472-
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
485+
STEEL_PRAGMA_UNROLL
486+
for (short i = 0; i < kRowsPT; ++i) {
487+
int row = base_row + i * kFragSize;
488+
if (row < params->qL) {
489+
int idx = global_idx + i * kFragSize;
490+
partial_maxs[idx] = max_score[i];
491+
partial_sums[idx] = sum_score[i];
492+
}
493+
}
473494
} else {
474-
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
495+
// Normal path: normalize and write final output.
496+
Otile.template row_bin_op<DivOp>(sum_score);
497+
threadgroup_barrier(mem_flags::mem_none);
498+
499+
O += (tm + sm) * params->O_strides[2] + sn;
500+
501+
if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
502+
auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
503+
504+
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
505+
return;
506+
507+
Otile.template store_safe<T, 1, 1>(O, params->O_strides[2], dst_tile_dims);
508+
} else {
509+
Otile.template store<T, 1, 1>(O, params->O_strides[2]);
510+
}
475511
}
476512
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright © 2026 Apple Inc. (contributed by Thump604)
2+
//
3+
// Reduction kernel for 2-pass chunked full attention.
4+
// Merges partial (O, max, sum) from multiple key-range chunks
5+
// using the online softmax algorithm.
6+
7+
// clang-format off
8+
#include "mlx/backend/metal/kernels/utils.h"
9+
10+
using namespace metal;
11+
12+
// Merge partial attention results using online softmax.
13+
// For each query position:
14+
// new_max = max(max_a, max_b)
15+
// scale_a = exp2(max_a - new_max)
16+
// scale_b = exp2(max_b - new_max)
17+
// O = O_a * scale_a + O_b * scale_b
18+
// sum = sum_a * scale_a + sum_b * scale_b
19+
// After all chunks merged: O_final = O / sum
20+
21+
// Merge per-chunk partial attention results for a single output element.
//
// Each thread owns one (batch*head, query, feature) element. It walks the
// n_chunks partial results, folding (O, max, sum) together with the online
// softmax recurrence:
//
//   new_max = max(max_a, max_b)
//   O   = O_a   * exp2(max_a - new_max) + O_b   * exp2(max_b - new_max)
//   sum = sum_a * exp2(max_a - new_max) + sum_b * exp2(max_b - new_max)
//
// and finally writes the normalized output O_final = O / sum.
//
// All flat offsets are computed in 64-bit (`long`): the per-chunk stride is
// B*H*qL*D elements, which can exceed INT_MAX at the long-context sizes this
// kernel exists for (e.g. H=16, qL=512K, D=128 gives a 2^30-element stride,
// so `c * stride` overflows 32-bit int for c >= 2).
template <typename T>
[[kernel]] void sdpa_full_reduce(
    const device T* partials [[buffer(0)]], // [n_chunks, B*H*qL*D]
    const device float* maxs [[buffer(1)]], // [n_chunks, B*H*qL]
    const device float* sums [[buffer(2)]], // [n_chunks, B*H*qL]
    device T* output [[buffer(3)]], // [B*H*qL*D]
    constant int& n_chunks [[buffer(4)]],
    constant int& D [[buffer(5)]],
    constant int& qL [[buffer(6)]],
    uint3 tid [[thread_position_in_grid]], // (d, q, bh)
    uint3 grid [[threads_per_grid]]) {
  int d_idx = int(tid.x);
  int q_idx = int(tid.y);
  int bh_idx = int(tid.z);

  // Guard the ragged edge when the dispatch grid is rounded up.
  if (d_idx >= D || q_idx >= qL)
    return;

  // Flat offsets of this thread's row / element within one chunk.
  long bhq = long(bh_idx) * qL + q_idx;
  long bhqd = bhq * D + d_idx;

  float running_max = -INFINITY;
  float running_sum = 0.0f;
  float running_o = 0.0f;

  // Element distance between consecutive chunks in the packed intermediate
  // buffers. grid.z == B * H, so one chunk spans B*H*qL rows.
  long chunk_stride_bhq = long(grid.z) * qL;
  long chunk_stride_bhqd = chunk_stride_bhq * D;

  for (int c = 0; c < n_chunks; c++) {
    float chunk_max = maxs[c * chunk_stride_bhq + bhq];
    float chunk_sum = sums[c * chunk_stride_bhq + bhq];
    float chunk_o = float(partials[c * chunk_stride_bhqd + bhqd]);

    if (c == 0) {
      // First chunk seeds the running state directly; this avoids
      // exp2(-INF - (-INF)) = NaN if a chunk's row max is -INF.
      running_max = chunk_max;
      running_sum = chunk_sum;
      running_o = chunk_o;
    } else {
      // Online softmax merge: rescale both sides to the new shared max.
      float new_max = max(running_max, chunk_max);
      float scale_old = fast::exp2(running_max - new_max);
      float scale_new = fast::exp2(chunk_max - new_max);

      running_o = running_o * scale_old + chunk_o * scale_new;
      running_sum = running_sum * scale_old + chunk_sum * scale_new;
      running_max = new_max;
    }
  }

  // NOTE(review): a fully-masked row would give running_sum == 0 and emit
  // NaN/Inf here, matching the 1-pass kernel's divide — confirm upstream
  // masking never produces an all-masked query row on this path.
  output[bhqd] = T(running_o / running_sum);
}
72+
73+
// Instantiate a host-visible specialization of sdpa_full_reduce for one
// element type. The host side looks kernels up by the exact string
// "sdpa_full_reduce_<tname>", so these host_name strings are part of the
// host/device contract — keep them in sync with the dispatch code.
#define instantiate_reduce(tname, dtype) \
  template [[host_name("sdpa_full_reduce_" #tname)]] \
  [[kernel]] void sdpa_full_reduce<dtype>( \
      const device dtype*, const device float*, const device float*, \
      device dtype*, constant int&, constant int&, constant int&, \
      uint3, uint3);

// One specialization per supported attention element type; the per-chunk
// max/sum intermediates stay float regardless of the element type.
instantiate_reduce(float16, half);
instantiate_reduce(bfloat16, bfloat16_t);
instantiate_reduce(float32, float);
// clang-format on

mlx/backend/metal/kernels/steel/attn/params.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ struct AttnParams {
3434
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
3535
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
3636
int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
37+
38+
// Chunked attention parameters (for 2-pass to avoid GPU watchdog).
39+
// When nk_chunk_end > 0, the kernel processes only keys in
40+
// [nk_chunk_start, nk_chunk_end) and writes partial softmax state
41+
// to intermediate buffers. The reduction pass merges chunk results.
42+
int nk_chunk_start; ///< First key block to process (0 = from beginning)
43+
int nk_chunk_end; ///< Last key block (exclusive, 0 = use NK)
44+
int chunk_idx; ///< Index of this chunk (for indexing intermediates)
3745
};
3846

3947
struct AttnMaskParams {

0 commit comments

Comments
 (0)