Skip to content

Commit 8d4b379

Browse files
committed
fix: chunked SDPA correctness — scale, causal offset, precision, overflow
Fix four bugs in sdpa_full_self_attention_2pass and redesign the reduction pipeline for numerical precision and memory scalability. Bugs fixed: 1. Double M_LOG2E: host pre-multiplied scale by M_LOG2E, but the kernel already multiplies by M_LOG2E_F. Result was scale^2 * M_LOG2E^2 instead of scale * M_LOG2E. Fix: pass raw scale (matches single-pass). 2. Wrong causal qL_off: formula was (k_start + chunk_kL) - qL = k_end - qL. Correct formula is (kL - qL) - k_start. The old formula broke causal masking completely — early chunks masked everything, late chunks nothing. 3. int32 overflow: B*H*qL*D overflows signed int32 at H=64, qL=131072, D=256 (exactly 2^31). Fix: use int64_t for bhq/bhqd. 4. simdgroup_barrier: guard was BD == 128, but BD=256 (head_dim=256 models) needs the same V@P read-after-write barrier. Fix: BD >= 128. Redesign: streaming merge with float32 accumulator - Old: allocate [n_chunks, B, H, qL, D] partials in input dtype, then reduce all chunks at once. Memory scales linearly with chunk count, precision limited by half/bfloat16 round-trip. - New: one chunk buffer (type T, reused) + one float32 accumulator. After each chunk's steel_attention dispatch, sdpa_full_merge folds results into the accumulator via online softmax. sdpa_full_finalize normalizes and writes output with correct stride layout (BLHD). - Memory: O(B*H*qL*D) constant regardless of chunk count. - Precision: float32 throughout accumulation, only final output cast to T. Verified against manual float32 reference (matmul + softmax): - Non-causal kL=65537: max_diff=0.000199, mean_diff=0.0000098 - Causal kL=65537: max_diff=0.008, mean_diff=0.0000193 - GQA D=256 causal: finite, correct magnitude - 3-chunk kL=131073: finite, correct magnitude
1 parent 4c39174 commit 8d4b379

3 files changed

Lines changed: 229 additions & 129 deletions

File tree

mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ template <
431431
for (short id = 0; id < TD; id++) {
432432
STEEL_PRAGMA_UNROLL
433433
for (short ik = 0; ik < TK; ik++) {
434-
if constexpr (BD == 128) {
434+
if constexpr (BD >= 128) {
435435
simdgroup_barrier(mem_flags::mem_none);
436436
}
437437

@@ -441,7 +441,7 @@ template <
441441
Vtile.template load<T, 1, 1, LDV_tgp, 1>(
442442
&Vs[Vs_offset + kk * LDV_tgp + dd]);
443443

444-
if constexpr (BD == 128) {
444+
if constexpr (BD >= 128) {
445445
simdgroup_barrier(mem_flags::mem_none);
446446
}
447447

Lines changed: 143 additions & 62 deletions
Original file line numberDiff line numberDiff line change
// Copyright © 2026 Apple Inc. (contributed by Thump604)
//
// Merge and finalize kernels for 2-pass chunked full attention.
//
// Streaming online softmax: after each chunk's steel_attention dispatch,
// sdpa_full_merge folds the chunk's partial results into a float32
// accumulator. After all chunks, sdpa_full_finalize normalizes and
// writes the output in the caller's stride layout.
//
// Float32 accumulation eliminates the precision loss that would occur
// from storing intermediate results in half/bfloat16.

// clang-format off
#include "mlx/backend/metal/kernels/utils.h"

using namespace metal;

// ---------------------------------------------------------------------------
// sdpa_full_merge — fold one chunk into the running float32 accumulator
// ---------------------------------------------------------------------------
//
// Streaming online-softmax merge: steel_attention writes one chunk's partial
// output plus per-query-row softmax state; this kernel folds that chunk into
// a float32 accumulator so only one chunk buffer is ever live, regardless of
// how many chunks are processed.
//
// Buffers:
//   chunk_os   [B*H*qL*D] — chunk partial O, in the input dtype T
//   chunk_maxs [B*H*qL]   — per-row running max (log2 domain: the merge
//                           rescales with exp2, matching steel_attention)
//   chunk_sums [B*H*qL]   — per-row softmax denominator for this chunk
//   accum_os/maxs/sums    — float32 accumulator state, reused across chunks
//   is_first              — 1 → plain copy into accum; 0 → online merge
//
// Grid: (D, qL, B*H) — one thread per output element.
// Group: (D, 1, 1) — all D threads for a query row in one threadgroup.
// Requires D <= 256 (true for all current models).
//
// A threadgroup_barrier separates reads of accum_maxs from the write by
// thread 0, preventing a race between SIMD groups in the same threadgroup.

template <typename T>
[[kernel]] void sdpa_full_merge(
    const device T* chunk_os [[buffer(0)]],
    const device float* chunk_maxs [[buffer(1)]],
    const device float* chunk_sums [[buffer(2)]],
    device float* accum_os [[buffer(3)]],
    device float* accum_maxs [[buffer(4)]],
    device float* accum_sums [[buffer(5)]],
    constant int& is_first [[buffer(6)]],
    constant int& D [[buffer(7)]],
    constant int& qL [[buffer(8)]],
    uint3 tid [[thread_position_in_grid]],
    uint tid_in_tg [[thread_index_in_threadgroup]]) { // currently unused

  const int d_idx = tid.x;  // position within the head dimension
  const int q_idx = tid.y;  // query position
  const int bh_idx = tid.z; // fused batch*head index

  // Guard for ragged dispatch.
  // NOTE(review): threads returning here skip the threadgroup_barrier below;
  // that is only well-defined if a threadgroup never straddles this guard
  // (i.e. the dispatched grid x-extent is exactly D) — confirm against the
  // host-side dispatch.
  if (d_idx >= D || q_idx >= qL)
    return;

  // Row index into the [B*H*qL] state arrays, and element index into the
  // contiguous BHLD O buffers. 64-bit: bhq*D can exceed 2^31
  // (e.g. H=64, qL=131072, D=256 is exactly 2^31).
  const int bhq = bh_idx * qL + q_idx;
  const long bhqd = long(bhq) * D + d_idx;

  const float chunk_o = float(chunk_os[bhqd]);

  // First chunk: initialize the accumulator directly — nothing to merge.
  // is_first comes from a constant buffer, so the whole threadgroup takes
  // the same path and the early return is uniform.
  if (is_first) {
    accum_os[bhqd] = chunk_o;
    if (d_idx == 0) {
      // One thread per row owns the scalar per-row state.
      accum_maxs[bhq] = chunk_maxs[bhq];
      accum_sums[bhq] = chunk_sums[bhq];
    }
    return;
  }

  // --- Online softmax merge (chunks 1+) ---

  // Read shared per-row state before the barrier-protected write below.
  const float acc_max = accum_maxs[bhq];
  const float acc_sum = accum_sums[bhq];
  const float c_max = chunk_maxs[bhq];
  const float c_sum = chunk_sums[bhq];

  // Rescale both sides to the common (larger) max so every exponential is
  // <= 1 and cannot overflow.
  const float new_max = max(acc_max, c_max);
  const float scale_old = fast::exp2(acc_max - new_max);
  const float scale_new = fast::exp2(c_max - new_max);

  // Per-element update — each thread writes a unique bhqd, no conflicts.
  accum_os[bhqd] = accum_os[bhqd] * scale_old + chunk_o * scale_new;

  // Barrier: all threads must have read the old accum_maxs/sums above
  // before thread 0 overwrites them below.
  threadgroup_barrier(mem_flags::mem_device);

  if (d_idx == 0) {
    accum_maxs[bhq] = new_max;
    accum_sums[bhq] = acc_sum * scale_old + c_sum * scale_new;
  }
}
// ---------------------------------------------------------------------------
93+
// sdpa_full_finalize — normalize float32 accumulator → output in type T
94+
// ---------------------------------------------------------------------------
95+
//
96+
// Grid: (D, qL, B*H) — one thread per output element.
97+
// Group: (min(D, 256), 1, 1).
98+
//
99+
// Handles the layout transposition from contiguous BHLD accumulator
100+
// to the caller's output stride layout (typically BLHD).
101+
102+
template <typename T>
103+
[[kernel]] void sdpa_full_finalize(
104+
const device float* accum_os [[buffer(0)]],
105+
const device float* accum_sums [[buffer(1)]],
106+
device T* output [[buffer(2)]],
107+
constant int& D [[buffer(3)]],
108+
constant int& H [[buffer(4)]],
109+
constant int& qL [[buffer(5)]],
110+
constant int64_t* O_strides [[buffer(6)]],
111+
uint3 tid [[thread_position_in_grid]]) {
112+
113+
const int d_idx = tid.x;
114+
const int q_idx = tid.y;
115+
const int bh_idx = tid.z;
116+
117+
if (d_idx >= D || q_idx >= qL)
118+
return;
119+
120+
const int b = bh_idx / H;
121+
const int h = bh_idx % H;
122+
123+
// Contiguous BHLD index into accumulator
124+
const int bhq = bh_idx * qL + q_idx;
125+
const long bhqd = long(bhq) * D + d_idx;
69126

70-
output[bhqd] = T(running_o / running_sum);
127+
// Strided index into output (may be BLHD or other layout)
128+
const long out_idx = long(b) * O_strides[0] +
129+
long(h) * O_strides[1] +
130+
long(q_idx) * O_strides[2] +
131+
d_idx;
132+
133+
const float sum = accum_sums[bhq];
134+
output[out_idx] = T(accum_os[bhqd] / sum);
71135
}
72136

// ---------------------------------------------------------------------------
// Template instantiations
// ---------------------------------------------------------------------------
//
// host_name strings ("sdpa_full_merge_<tname>", "sdpa_full_finalize_<tname>")
// are the lookup keys used by the host-side dispatch — keep them in sync
// with the caller. dtype is the attention input/output type T; the
// accumulator buffers are always float32.

#define instantiate_merge(tname, dtype) \
  template [[host_name("sdpa_full_merge_" #tname)]] \
  [[kernel]] void sdpa_full_merge<dtype>( \
      const device dtype*, const device float*, const device float*, \
      device float*, device float*, device float*, \
      constant int&, constant int&, constant int&, \
      uint3, uint);

#define instantiate_finalize(tname, dtype) \
  template [[host_name("sdpa_full_finalize_" #tname)]] \
  [[kernel]] void sdpa_full_finalize<dtype>( \
      const device float*, const device float*, \
      device dtype*, \
      constant int&, constant int&, constant int&, constant int64_t*, \
      uint3);

instantiate_merge(float16, half);
instantiate_merge(bfloat16, bfloat16_t);
instantiate_merge(float32, float);

instantiate_finalize(float16, half);
instantiate_finalize(bfloat16, bfloat16_t);
instantiate_finalize(float32, float);
// clang-format on

0 commit comments

Comments
 (0)