Commit a2d6335

fix: add head_dim=256 to fused SDPA full attention kernel
sdpa_full_supported_head_dim only included {64, 80, 128}. Models with head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention path, which materializes the full score matrix as a single matmul. At 32K+ context this creates 8+ GB single allocations that crash Metal's buffer allocator.

Add head_dim=256 to the dispatch gate and instantiate the steel_attention kernel with bd=256. The Metal kernel template handles arbitrary BD via a template parameter, so no kernel code changes are needed.

Verified: 32K, 64K, and 128K context on an M2 Ultra with Qwen3.5-122B-A10B.
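Rough arithmetic on the allocation size (illustrative numbers, assuming fp16 scores; the exact head count depends on the model config): one [L, L] score matrix at L = 32768 is 32768 * 32768 * 2 bytes = 2 GiB, so a single tensor holding scores for even four query heads already exceeds 8 GB, consistent with the failure described above.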
1 parent 46c181b commit a2d6335

3 files changed

Lines changed: 51 additions & 8 deletions

mlx/backend/metal/eval.cpp

Lines changed: 48 additions & 7 deletions
@@ -1,5 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
+#include <atomic>
 #include <memory>
+#include <mutex>
 
 #include "mlx/backend/gpu/eval.h"
 #include "mlx/backend/metal/device.h"
@@ -9,11 +11,35 @@
 
 namespace mlx::core::gpu {
 
-void init() {}
+// Storage for command buffer errors from completion handlers.
+// Completion handlers run on GCD dispatch queues where C++ exceptions
+// cannot propagate — throwing from a handler calls std::terminate.
+// Instead, store the error and rethrow at the next synchronization point.
+static std::mutex error_mutex_;
+static std::string error_message_;
+static std::atomic<bool> has_error_{false};
 
-void new_stream(Stream stream) {
-  if (stream.device == mlx::core::Device::gpu) {
-    metal::device(stream.device).get_command_encoder(stream.index);
+static void store_error(MTL::CommandBuffer* cbuf) {
+  if (cbuf->status() == MTL::CommandBufferStatusError) {
+    std::lock_guard<std::mutex> lock(error_mutex_);
+    if (!has_error_.load()) {
+      std::ostringstream msg;
+      msg << "[METAL] Command buffer execution failed: "
+          << cbuf->error()->localizedDescription()->utf8String();
+      error_message_ = msg.str();
+      has_error_.store(true);
+    }
+  }
+}
+
+static void check_stored_error() {
+  if (has_error_.load()) {
+    std::lock_guard<std::mutex> lock(error_mutex_);
+    if (has_error_.load()) {
+      std::string msg = std::move(error_message_);
+      has_error_.store(false);
+      throw std::runtime_error(msg);
+    }
   }
 }
 
@@ -26,7 +52,18 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
   }
 }
 
+void init() {}
+
+void new_stream(Stream stream) {
+  if (stream.device == mlx::core::Device::gpu) {
+    metal::device(stream.device).get_command_encoder(stream.index);
+  }
+}
+
 void eval(array& arr) {
+  // Check for errors from previous async command buffers
+  check_stored_error();
+
   auto pool = metal::new_scoped_memory_pool();
   auto s = arr.primitive().stream();
   auto& d = metal::device(s.device);
@@ -62,13 +99,13 @@ void eval(array& arr) {
     command_buffer->addCompletedHandler(
         [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
           scheduler::notify_task_completion(s);
-          check_error(cbuf);
+          store_error(cbuf);
         });
     d.commit_command_buffer(s.index);
   } else {
     command_buffer->addCompletedHandler(
         [buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
-          check_error(cbuf);
+          store_error(cbuf);
        });
  }
 }
@@ -78,7 +115,8 @@ void finalize(Stream s) {
   auto& d = metal::device(s.device);
   auto cb = d.get_command_buffer(s.index);
   d.end_encoding(s.index);
-  cb->addCompletedHandler([](MTL::CommandBuffer* cbuf) { check_error(cbuf); });
+  cb->addCompletedHandler(
+      [](MTL::CommandBuffer* cbuf) { store_error(cbuf); });
   d.commit_command_buffer(s.index);
 }
 
@@ -90,7 +128,10 @@ void synchronize(Stream s) {
   d.end_encoding(s.index);
   d.commit_command_buffer(s.index);
   cb->waitUntilCompleted();
+  // Check directly — we're on the calling thread, can throw safely
   check_error(cb);
+  // Also check any stored errors from async handlers
+  check_stored_error();
   cb->release();
 }
 
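The double check of has_error_ in check_stored_error is a lock-free fast path followed by a re-check under the mutex, so two racing callers cannot both consume the same stored message. Below is a distilled, compilable sketch of the same store-and-rethrow pattern, with a plain std::thread standing in for the Metal completion handler; the driver code and simplified store_error signature are illustrative, not part of MLX:

// Standalone rendering of the error-deferral pattern the diff above adds to
// eval.cpp: a callback that must not throw records the failure, and the next
// synchronization point on an ordinary calling thread rethrows it.
#include <atomic>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>

static std::mutex error_mutex;
static std::string error_message;
static std::atomic<bool> has_error{false};

// Called from a context where throwing would call std::terminate
// (in the diff: the GCD completion handler).
void store_error(const std::string& what) {
  std::lock_guard<std::mutex> lock(error_mutex);
  if (!has_error.load()) {
    error_message = what;
    has_error.store(true);
  }
}

// Called at a synchronization point on a normal thread; safe to throw here.
void check_stored_error() {
  if (has_error.load()) {  // lock-free fast path for the common no-error case
    std::lock_guard<std::mutex> lock(error_mutex);
    if (has_error.load()) {  // re-check: another thread may have consumed it
      std::string msg = std::move(error_message);
      has_error.store(false);
      throw std::runtime_error(msg);
    }
  }
}

int main() {
  // Simulate an async completion handler reporting a failure.
  std::thread worker([] {
    store_error("[METAL] Command buffer execution failed: example");
  });
  worker.join();
  try {
    check_stored_error();  // surfaces as a normal C++ exception on this thread
  } catch (const std::runtime_error&) {
    return 0;  // caller sees an exception instead of std::terminate
  }
  return 1;
}
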
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
       attention, dtype, bq, bk, bd, wm, wn, mtype, float)
 
 #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
+  instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype)  \
   instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype)  \
   instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype)   \
   instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
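
For orientation, reading the instantiate_attn signature in the context line above: bq and bk are presumably the query and key tile sizes, bd is the head dimension, and wm/wn the simdgroup tiling. The new bd=256 entry reuses the 32/16/4/1 configuration already used for bd=128 rather than the 32/32 tiles used for the smaller head dims.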

mlx/backend/metal/scaled_dot_product_attention.cpp

Lines changed: 2 additions & 1 deletion
@@ -620,7 +620,8 @@ bool ScaledDotProductAttention::use_fallback(
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
        query_head_dim == 256);
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 ||
+       query_head_dim == 256);
 
   const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
       (query_sequence_length <= key_sequence_length && do_causal);
