Apple Metal Performance Shaders Graph (MPSGraph) backend for the Burn deep learning framework.
Runs tensor computations on Apple GPUs (M1, M2, M3, M4 and later) by dispatching directly to MPSGraph — Apple's graph-based GPU compute engine that powers Core ML and Create ML.
```toml
# Cargo.toml
[dependencies]
burn-mpsgraph = "0.0.1"
burn = "0.21.0-pre.2"
```

```rust
use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;

type B = MpsGraph;

let device = MpsGraphDevice::default();
let a: Tensor<B, 2> = Tensor::random([128, 64], Distribution::Default, &device);
let b: Tensor<B, 2> = Tensor::random([64, 256], Distribution::Default, &device);
let c = a.matmul(b);      // runs on the Apple GPU
let data = c.into_data(); // copy the result back to the CPU
```

```
Burn Tensor API
      │
      ▼
FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
      │
      ▼
bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors,
             runs synchronously inside an ObjC autorelease pool
      │
      ▼
ffi.rs ── raw objc_msgSend calls to Metal, Foundation, MPS, MPSGraph
      │
      ▼
Apple GPU (Metal)
```
Each MpsGraphTensor wraps a retained MTLBuffer pointer.
On Apple Silicon the buffer uses shared memory — CPU and GPU share the
same physical pages with no copy overhead.
Data is only serialised to Vec<u8> when you call into_data().
All Apple framework calls go through direct objc_msgSend FFI in ffi.rs.
This removes the entire objc2 / objc2-foundation / objc2-metal family
from the dependency tree.
| | burn-mpsgraph | alternative (objc2-based) |
|---|---|---|
| Runtime deps | 4 | 30+ |
| `cargo build` time | ~1 s | ~20 s |
| Category | Operations |
|---|---|
| Arithmetic | add, sub, mul, div, remainder, powf, matmul |
| Unary math | exp, log, log1p, sqrt, abs, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, erf, recip, floor, ceil, round, trunc |
| Comparisons | equal, not_equal, greater, greater_equal, lower, lower_equal (tensor and scalar) |
| Reductions | sum, prod, mean, max, min, argmax, argmin — full tensor and per-axis |
| Cumulative | cumsum, cumprod, cummin, cummax |
| Sort | sort, argsort (ascending / descending) |
| Shape | reshape, transpose, permute, flip, slice, slice_assign, cat, expand, unfold |
| Masking | mask_where, mask_fill |
| Gather / scatter | gather, scatter_add, select, select_add |
| Convolution | conv1d, conv2d, conv3d, conv_transpose1d/2d/3d, deform_conv2d |
| Pooling | avg_pool2d (+ backward), max_pool2d (+ with_indices + backward), adaptive_avg_pool2d |
| Interpolation | nearest, bilinear (+ backward) |
| Attention | Fused scaled dot-product attention (single MPSGraph execution) |
| Embedding | forward + backward |
| Int | full arithmetic, cumulative, sort, bitwise (and, or, xor, not, left/right shift), cast |
| Bool | and, or, not, equal, scatter_or, cast-to-int/float |
| Quantization | Basic per-tensor symmetric quantize / dequantize |
| DType | Storage | Arithmetic | GPU-accelerated |
|---|---|---|---|
| F32 | ✓ | ✓ | ✓ |
| F16 | ✓ | ✓ | ✓ |
| BF16 | ✓ | ✓ | |
| I32 | ✓ | ✓ | |
| I16 | ✓ | ✓ | |
| I8 | ✓ | ✓ | |
| I64 | ✓ | ✓ | |
| U8–U64 | ✓ | ✓ | |
| Bool | ✓ | ✓ | |
| F64 | — | — | — (Metal unsupported; panics with a clear message) |
Benchmarks run on an Apple M3 Pro against:

- `burn-wgpu` with the `metal` feature (wgpu → MSL → Metal)
- `burn-ndarray` (CPU reference)
Matrix multiplication:

| Size | MPSGraph | metal-wgpu | ndarray (CPU) |
|---|---|---|---|
| 64×64 | 1.9 ms | ~2 ms | 45 µs |
| 256×256 | 2.4 ms | ~3 ms | 250 µs |
| 512×512 | 4.6 ms | ~5 ms | 1.1 ms |
| 1024×1024 | 8.0 ms | ~9 ms | 6.6 ms |
| 2048×2048 | 14 ms | ~15 ms | 50 ms |
| 4096×4096 | 30 ms | ~29 ms | 410 ms |
At 4096×4096 both GPU backends are ~14× faster than CPU.
Element-wise ops:

| Elements | MPSGraph | metal-wgpu | ndarray |
|---|---|---|---|
| 10K | 2.0 ms | ~1 ms | 6 µs |
| 100K | 2.4 ms | ~1 ms | 30 µs |
| 1M | 8.5 ms | ~2 ms | 300 µs |
Note: For small ops the per-graph-execution overhead (~2 ms) dominates. This is the primary area for future improvement — see Roadmap.
```shell
cargo run --example basic      # arithmetic, reductions, math, shapes
cargo run --example inference  # 2-layer MLP with softmax
cargo run --example conv       # conv2d + maxpool2d feature extractor
```
```rust
MpsGraph::seed(&device, 42);
let t: Tensor<B, 2> = Tensor::random([4, 4], Distribution::Default, &device);
```

The seed is stored globally per process. Two calls with the same seed produce identical tensors.
| Requirement | Version |
|---|---|
| macOS | 12.3+ (Monterey) |
| Xcode Command Line Tools | 14+ |
| GPU | Apple Silicon (M1 / M2 / M3 / M4) or supported AMD |
The build.rs links these frameworks automatically:
Foundation, Metal, MetalPerformanceShaders, MetalPerformanceShadersGraph.
- Lazy / deferred execution — accumulate ops into one graph before running. This would eliminate the ~2 ms per-op overhead and make element-wise chains as fast as native MPSGraph programs.
- MTLBuffer pool — reuse freed buffers instead of allocating new ones.
- Async execution — use `encodeToCommandBuffer` to pipeline work.
- BF16 arithmetic — currently storage-only.
- F64 via emulation — two F32 ops for `reduce_sum` / `mean`.
- Multi-device — map `MpsGraphDevice { index }` to discrete GPUs on Mac Pro.
Licensed under either of
at your option.