|
1 | 1 | // Copyright © 2026 Apple Inc. (contributed by Thump604) |
2 | 2 | // |
3 | | -// Reduction kernel for 2-pass chunked full attention. |
4 | | -// Merges partial (O, max, sum) from multiple key-range chunks |
5 | | -// using the online softmax algorithm. |
| 3 | +// Merge and finalize kernels for 2-pass chunked full attention. |
| 4 | +// |
| 5 | +// Streaming online softmax: after each chunk's steel_attention dispatch, |
| 6 | +// sdpa_full_merge folds the chunk's partial results into a float32 |
| 7 | +// accumulator. After all chunks, sdpa_full_finalize normalizes and |
| 8 | +// writes the output in the caller's stride layout. |
| 9 | +// |
| 10 | +// Float32 accumulation eliminates the precision loss that would occur |
| 11 | +// from storing intermediate results in half/bfloat16. |
6 | 12 |
|
7 | 13 | // clang-format off |
8 | 14 | #include "mlx/backend/metal/kernels/utils.h" |
9 | 15 |
|
10 | 16 | using namespace metal; |
11 | 17 |
|
// ---------------------------------------------------------------------------
// sdpa_full_merge — fold one chunk into the running float32 accumulator
// ---------------------------------------------------------------------------
//
// Grid: (D, qL, B*H) — one thread per output element.
// Group: (D, 1, 1) — all D threads for a query row in one threadgroup.
// Requires D <= 256 (true for all current models).
//
// For the first chunk (is_first=1): copies chunk → accum.
// For subsequent chunks: online softmax merge:
//   new_max = max(acc_max, c_max)
//   O  = O_acc * exp2(acc_max - new_max) + O_chunk * exp2(c_max - new_max)
//   sum = likewise
//
// A threadgroup_barrier separates the reads of accum_maxs/accum_sums from
// the write by the d_idx == 0 thread, preventing a race between SIMD groups
// in the same threadgroup. Because divergent barriers are undefined behavior
// in Metal, out-of-range threads do NOT early-return on the merge path; they
// are predicated off with `in_bounds` and still execute the barrier. This
// matters when the grid is padded up to a threadgroup multiple
// (dispatchThreadgroups); with an exact dispatchThreads grid the predicate
// is always true.

template <typename T>
[[kernel]] void sdpa_full_merge(
    const device T* chunk_os [[buffer(0)]],
    const device float* chunk_maxs [[buffer(1)]],
    const device float* chunk_sums [[buffer(2)]],
    device float* accum_os [[buffer(3)]],
    device float* accum_maxs [[buffer(4)]],
    device float* accum_sums [[buffer(5)]],
    constant int& is_first [[buffer(6)]],
    constant int& D [[buffer(7)]],
    constant int& qL [[buffer(8)]],
    uint3 tid [[thread_position_in_grid]],
    uint tid_in_tg [[thread_index_in_threadgroup]]) {
  // tid_in_tg is currently unused; kept so the host-side pipeline signature
  // (and the instantiation macros below) stay stable.
  (void)tid_in_tg;

  const int d_idx = tid.x;  // feature dim within the head
  const int q_idx = tid.y;  // query position
  const int bh_idx = tid.z; // fused batch*head index

  // Do NOT early-return here: every thread must reach the barrier below.
  const bool in_bounds = (d_idx < D) && (q_idx < qL);

  // Flat indices into the contiguous [B*H, qL] / [B*H, qL, D] accumulators.
  const int bhq = bh_idx * qL + q_idx;
  const long bhqd = long(bhq) * D + d_idx;

  if (is_first) {
    // is_first comes from a constant buffer and is uniform across the grid,
    // so every thread takes this branch together — returning before the
    // barrier cannot strand a thread waiting at it.
    if (in_bounds) {
      accum_os[bhqd] = float(chunk_os[bhqd]);
      if (d_idx == 0) {
        accum_maxs[bhq] = chunk_maxs[bhq];
        accum_sums[bhq] = chunk_sums[bhq];
      }
    }
    return;
  }

  // --- Online softmax merge (chunks 1+) ---

  float new_max = 0.0f;
  float merged_sum = 0.0f;

  if (in_bounds) {
    // Read the shared per-row state before the barrier-protected write below.
    const float acc_max = accum_maxs[bhq];
    const float acc_sum = accum_sums[bhq];
    const float c_max = chunk_maxs[bhq];
    const float c_sum = chunk_sums[bhq];

    new_max = max(acc_max, c_max);
    const float scale_old = fast::exp2(acc_max - new_max);
    const float scale_new = fast::exp2(c_max - new_max);

    // Per-element update — each thread writes a unique bhqd, no conflicts.
    accum_os[bhqd] =
        accum_os[bhqd] * scale_old + float(chunk_os[bhqd]) * scale_new;
    merged_sum = acc_sum * scale_old + c_sum * scale_new;
  }

  // Barrier: all threads (including predicated-off ones) arrive here, and
  // every read of the old accum_maxs/sums has completed before thread 0
  // overwrites them.
  threadgroup_barrier(mem_flags::mem_device);

  if (in_bounds && d_idx == 0) {
    accum_maxs[bhq] = new_max;
    accum_sums[bhq] = merged_sum;
  }
}
| 91 | + |
// ---------------------------------------------------------------------------
// sdpa_full_finalize — normalize float32 accumulator → output in type T
// ---------------------------------------------------------------------------
//
// Grid: (D, qL, B*H) — one thread per output element.
// Group: (min(D, 256), 1, 1).
//
// Handles the layout transposition from the contiguous BHLD accumulator
// to the caller's output stride layout (typically BLHD). The innermost
// (feature) dimension of the output is assumed to be contiguous: d_idx is
// added without a stride, i.e. O_strides[3] == 1 — TODO confirm on host.

template <typename T>
[[kernel]] void sdpa_full_finalize(
    const device float* accum_os [[buffer(0)]],
    const device float* accum_sums [[buffer(1)]],
    device T* output [[buffer(2)]],
    constant int& D [[buffer(3)]],
    constant int& H [[buffer(4)]],
    constant int& qL [[buffer(5)]],
    constant int64_t* O_strides [[buffer(6)]],
    uint3 tid [[thread_position_in_grid]]) {

  const int d_idx = tid.x;  // feature dim within the head
  const int q_idx = tid.y;  // query position
  const int bh_idx = tid.z; // fused batch*head index

  if (d_idx >= D || q_idx >= qL)
    return;

  // Split the fused batch*head index back into (b, h) for strided output.
  const int b = bh_idx / H;
  const int h = bh_idx % H;

  // Contiguous BHLD index into the accumulator.
  const int bhq = bh_idx * qL + q_idx;
  const long bhqd = long(bhq) * D + d_idx;

  // Strided index into output (may be BLHD or other layout).
  const long out_idx = long(b) * O_strides[0] +
      long(h) * O_strides[1] +
      long(q_idx) * O_strides[2] +
      d_idx;

  // Guard against a fully-masked query row: its softmax sum is 0 and an
  // unguarded divide would emit NaN/Inf. Write 0 instead.
  const float sum = accum_sums[bhq];
  output[out_idx] = (sum != 0.0f) ? T(accum_os[bhqd] / sum) : T(0.0f);
}
72 | 136 |
|
// ---------------------------------------------------------------------------
// Template instantiations
// ---------------------------------------------------------------------------
// One named pipeline per input dtype; the accumulator buffers are always
// float32, so only the chunk input (merge) / final output (finalize) vary.

#define instantiate_merge(type_name, itype)                           \
  template [[host_name("sdpa_full_merge_" #type_name)]]               \
  [[kernel]] void sdpa_full_merge<itype>(                             \
      const device itype*, const device float*, const device float*,  \
      device float*, device float*, device float*,                    \
      constant int&, constant int&, constant int&,                    \
      uint3, uint);

#define instantiate_finalize(type_name, itype)                        \
  template [[host_name("sdpa_full_finalize_" #type_name)]]            \
  [[kernel]] void sdpa_full_finalize<itype>(                          \
      const device float*, const device float*,                       \
      device itype*,                                                  \
      constant int&, constant int&, constant int&, constant int64_t*, \
      uint3);

// Instantiate both kernels for a dtype in one shot.
#define instantiate_sdpa_full(type_name, itype) \
  instantiate_merge(type_name, itype);          \
  instantiate_finalize(type_name, itype);

instantiate_sdpa_full(float16, half);
instantiate_sdpa_full(bfloat16, bfloat16_t);
instantiate_sdpa_full(float32, float);
83 | 164 | // clang-format on |
0 commit comments