# burn-mpsgraph

Apple Metal Performance Shaders Graph (MPSGraph) backend for the Burn deep learning framework.

Runs tensor computations on Apple GPUs (M1, M2, M3, M4 and later) by dispatching directly to MPSGraph — Apple's graph-based GPU compute engine that powers Core ML and Create ML.



## Quick start

```toml
# Cargo.toml
[dependencies]
burn-mpsgraph = "0.0.1"
burn           = "0.21.0-pre.2"
```

```rust
use burn::tensor::{Distribution, Tensor};
use burn_mpsgraph::prelude::*;

type B = MpsGraph;

let device = MpsGraphDevice::default();

let a: Tensor<B, 2> = Tensor::random([128, 64], Distribution::Default, &device);
let b: Tensor<B, 2> = Tensor::random([64, 256], Distribution::Default, &device);
let c = a.matmul(b);          // runs on Apple GPU
let data = c.into_data();     // copy result back to CPU
```

## Design

```text
Burn Tensor API
      │
      ▼
FloatTensorOps / IntTensorOps / BoolTensorOps / ModuleOps
      │
      ▼
bridge.rs ── builds an MPSGraph per op, feeds MTLBuffer tensors,
             runs synchronously inside an ObjC autorelease pool
      │
      ▼
ffi.rs ── raw objc_msgSend calls to Metal, Foundation, MPS, MPSGraph
      │
      ▼
Apple GPU (Metal)
```

### GPU-resident tensors

Each `MpsGraphTensor` wraps a retained `MTLBuffer` pointer.
On Apple Silicon the buffer uses shared memory — CPU and GPU share the same physical pages with no copy overhead.
Data is only serialised to `Vec<u8>` when you call `into_data()`.
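The ownership model can be sketched as a newtype over a retained buffer pointer that releases on `Drop`. This is a self-contained mock, not the crate's actual code: `MockMtlBuffer`, `GpuTensor`, and the live-buffer counter are illustrative stand-ins for the real retained `MTLBuffer` and ObjC retain/release messages.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Stand-in for the ObjC retain count; the real crate sends retain/release
// messages to an MTLBuffer via objc_msgSend.
static LIVE_BUFFERS: AtomicUsize = AtomicUsize::new(0);

struct MockMtlBuffer {
    bytes: Vec<u8>, // on Apple Silicon this would be CPU/GPU shared memory
}

struct GpuTensor {
    buffer: *mut MockMtlBuffer, // retained pointer, analogous to MTLBuffer*
    shape: Vec<usize>,
}

impl GpuTensor {
    fn new(shape: Vec<usize>) -> Self {
        let len: usize = shape.iter().product::<usize>() * 4; // f32 bytes
        let buffer = Box::into_raw(Box::new(MockMtlBuffer { bytes: vec![0; len] }));
        LIVE_BUFFERS.fetch_add(1, Ordering::SeqCst);
        GpuTensor { buffer, shape }
    }

    // Only this call copies data out to a plain Vec, mirroring into_data().
    fn into_data(self) -> Vec<u8> {
        unsafe { (*self.buffer).bytes.clone() }
        // `self` is dropped here, which releases the buffer.
    }
}

impl Drop for GpuTensor {
    fn drop(&mut self) {
        // Mirrors sending `release` to the MTLBuffer.
        unsafe { drop(Box::from_raw(self.buffer)) };
        LIVE_BUFFERS.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let t = GpuTensor::new(vec![2, 3]);
    let data = t.into_data(); // consumes t; the mock buffer is released
    println!("{} bytes copied out", data.len());
}
```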

### No objc2 dependency

All Apple framework calls go through direct `objc_msgSend` FFI in `ffi.rs`. This removes the entire `objc2` / `objc2-foundation` / `objc2-metal` family from the dependency tree.
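The pattern looks roughly like the following — a hedged sketch, not the crate's actual `ffi.rs`. `objc_getClass`, `sel_registerName`, and `objc_msgSend` are the real ObjC runtime entry points; everything else here (the `NSObject`/`new` demo) is illustrative, and the `cfg` guard keeps the sketch compiling off macOS.

```rust
use std::ffi::c_void;

#[cfg(target_os = "macos")]
mod ffi {
    use std::ffi::{c_char, c_void};

    // Raw ObjC runtime entry points, linked from libobjc.
    #[link(name = "objc")]
    extern "C" {
        pub fn objc_getClass(name: *const c_char) -> *mut c_void;
        pub fn sel_registerName(name: *const c_char) -> *mut c_void;
        // objc_msgSend is untyped at the FFI boundary; callers cast it to
        // the concrete signature of the method being invoked.
        pub fn objc_msgSend();
    }
}

#[cfg(target_os = "macos")]
fn msg_send_demo() {
    use std::mem::transmute;
    unsafe {
        // Equivalent of [NSObject new]: look up the class and selector,
        // then cast objc_msgSend to `fn(receiver, selector) -> id`.
        let cls = ffi::objc_getClass(b"NSObject\0".as_ptr().cast());
        let sel = ffi::sel_registerName(b"new\0".as_ptr().cast());
        let send: unsafe extern "C" fn(*mut c_void, *mut c_void) -> *mut c_void =
            transmute(ffi::objc_msgSend as *const c_void);
        let obj = send(cls, sel);
        println!("allocated NSObject at {obj:p}");
    }
}

// Returns true when the ObjC path actually ran.
fn run() -> bool {
    #[cfg(target_os = "macos")]
    {
        msg_send_demo();
        true
    }
    #[cfg(not(target_os = "macos"))]
    {
        println!("objc_msgSend FFI is macOS-only");
        false
    }
}

fn main() {
    run();
}
```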

|                    | burn-mpsgraph | alternative (objc2-based) |
|--------------------|---------------|---------------------------|
| Runtime deps       | 4             | 30+                       |
| `cargo build` time | ~1 s          | ~20 s                     |

## Supported operations

| Category | Operations |
|----------|------------|
| Arithmetic | `add`, `sub`, `mul`, `div`, `remainder`, `powf`, `matmul` |
| Unary math | `exp`, `log`, `log1p`, `sqrt`, `abs`, `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `sinh`, `cosh`, `tanh`, `asinh`, `acosh`, `atanh`, `atan2`, `erf`, `recip`, `floor`, `ceil`, `round`, `trunc` |
| Comparisons | `equal`, `not_equal`, `greater`, `greater_equal`, `lower`, `lower_equal` (tensor and scalar) |
| Reductions | `sum`, `prod`, `mean`, `max`, `min`, `argmax`, `argmin` — full tensor and per-axis |
| Cumulative | `cumsum`, `cumprod`, `cummin`, `cummax` |
| Sort | `sort`, `argsort` (ascending / descending) |
| Shape | `reshape`, `transpose`, `permute`, `flip`, `slice`, `slice_assign`, `cat`, `expand`, `unfold` |
| Masking | `mask_where`, `mask_fill` |
| Gather / scatter | `gather`, `scatter_add`, `select`, `select_add` |
| Convolution | `conv1d`, `conv2d`, `conv3d`, `conv_transpose1d/2d/3d`, `deform_conv2d` |
| Pooling | `avg_pool2d` (+ backward), `max_pool2d` (+ with_indices + backward), `adaptive_avg_pool2d` |
| Interpolation | nearest, bilinear (+ backward) |
| Attention | Fused scaled dot-product (single MPSGraph execution) |
| Embedding | forward + backward |
| Int | full arithmetic, cumulative, sort, bitwise (`and`, `or`, `xor`, `not`, left/right shift), cast |
| Bool | `and`, `or`, `not`, `equal`, `scatter_or`, cast-to-int/float |
| Quantization | Basic per-tensor symmetric quantize / dequantize |

## Supported dtypes

| DType | Support |
|-------|---------|
| F32, F16 | storage + arithmetic |
| BF16 | storage only (arithmetic is on the roadmap) |
| I32, I16, I8, I64, U8–U64 | storage + arithmetic |
| Bool | storage + logic ops |
| F64 | unsupported by Metal; panics with a clear message |

## Benchmarks

Benchmarks run on Apple M3 Pro against:

- `burn-wgpu` with the `metal` feature (wgpu → MSL → Metal)
- `burn-ndarray` (CPU reference)

### Matmul (f32, square matrices)

| Size | MPSGraph | metal-wgpu | ndarray (CPU) |
|------|----------|------------|---------------|
| 64×64 | 1.9 ms | ~2 ms | 45 µs |
| 256×256 | 2.4 ms | ~3 ms | 250 µs |
| 512×512 | 4.6 ms | ~5 ms | 1.1 ms |
| 1024×1024 | 8.0 ms | ~9 ms | 6.6 ms |
| 2048×2048 | 14 ms | ~15 ms | 50 ms |
| 4096×4096 | 30 ms | ~29 ms | 410 ms |

At 4096×4096 both GPU backends are ~14× faster than CPU.

### Element-wise add

| Elements | MPSGraph | metal-wgpu | ndarray |
|----------|----------|------------|---------|
| 10K | 2.0 ms | ~1 ms | 6 µs |
| 100K | 2.4 ms | ~1 ms | 30 µs |
| 1M | 8.5 ms | ~2 ms | 300 µs |

Note: For small ops the per-graph-execution overhead (~2 ms) dominates. This is the primary area for future improvement — see Roadmap.
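A back-of-envelope model shows why a fixed per-dispatch cost dominates small ops. This is a self-contained sketch: the ~2 ms overhead is taken from the tables above, but the bandwidth figure is purely illustrative.

```rust
// Model: time ≈ fixed graph-build/dispatch overhead + bytes moved / bandwidth.
fn modeled_add_time_us(elements: usize) -> f64 {
    const OVERHEAD_US: f64 = 2_000.0;  // ~2 ms per graph execution (from benchmarks)
    const BANDWIDTH_GBPS: f64 = 100.0; // illustrative effective bandwidth
    let bytes = (elements * 4 * 3) as f64; // two f32 inputs + one f32 output
    // GB/s = 1e3 bytes/µs, so bytes / (GB/s * 1e3) gives µs.
    OVERHEAD_US + bytes / (BANDWIDTH_GBPS * 1e3)
}

fn main() {
    for n in [10_000, 100_000, 1_000_000] {
        let t = modeled_add_time_us(n);
        let overhead_share = 2_000.0 / t * 100.0;
        println!("{n:>9} elems: {t:8.1} µs  ({overhead_share:.0}% overhead)");
    }
}
```

Even at 1M elements the memory traffic is only ~12 MB, so the fixed dispatch cost remains the dominant term — which is exactly what the lazy-execution item in the Roadmap targets.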


## Examples

```sh
cargo run --example basic      # arithmetic, reductions, math, shapes
cargo run --example inference  # 2-layer MLP with softmax
cargo run --example conv       # conv2d + maxpool2d feature extractor
```

## Seeded RNG

```rust
MpsGraph::seed(&device, 42);
let t: Tensor<B, 2> = Tensor::random([4, 4], Distribution::Default, &device);
```

The seed is stored globally per process. Two calls with the same seed produce identical tensors.
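The guarantee is the usual seeded-PRNG property: same seed, same stream. A self-contained illustration using splitmix64 — the backend's actual generator is not specified here, splitmix64 merely stands in to demonstrate the determinism contract:

```rust
// splitmix64: a tiny deterministic PRNG; same seed => identical sequence.
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

// Derive uniform f32 values in [0, 1) from the top 24 bits of each draw.
fn random_floats(seed: u64, n: usize) -> Vec<f32> {
    let mut s = seed;
    (0..n)
        .map(|_| (splitmix64(&mut s) >> 40) as f32 / (1u64 << 24) as f32)
        .collect()
}

fn main() {
    let a = random_floats(42, 4);
    let b = random_floats(42, 4);
    assert_eq!(a, b); // same seed reproduces the values exactly
    println!("seed 42 draws: {a:?}");
}
```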


## Platform requirements

| Requirement | Version |
|-------------|---------|
| macOS | 12.3+ (Monterey) |
| Xcode Command Line Tools | 14+ |
| Apple Silicon or AMD GPU | M1 / M2 / M3 / M4 or supported AMD |

The `build.rs` links these frameworks automatically: `Foundation`, `Metal`, `MetalPerformanceShaders`, `MetalPerformanceShadersGraph`.


## Roadmap

- Lazy / deferred execution — accumulate ops into one graph before running. This would eliminate the ~2 ms per-op overhead and make element-wise chains as fast as native MPSGraph programs.
- `MTLBuffer` pool — reuse freed buffers instead of allocating new ones.
- Async execution — use `encodeToCommandBuffer` to pipeline work.
- BF16 arithmetic — currently storage-only.
- F64 via emulation — two F32 ops for `reduce_sum` / `mean`.
- Multi-device — map `MpsGraphDevice { index }` to discrete GPUs on Mac Pro.
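Of these, the buffer pool is the most self-contained. A minimal size-bucketed free list might look like the following — a hypothetical design sketch, not the crate's code; `Buffer` stands in for a retained `MTLBuffer` and the allocation counter for the Metal allocation call:

```rust
use std::collections::HashMap;

// Stand-in for a retained MTLBuffer; a real pool would hold raw pointers.
struct Buffer {
    len: usize,
}

/// Size-bucketed free list: alloc() reuses a freed buffer of the same
/// length instead of asking Metal for a fresh allocation.
#[derive(Default)]
struct BufferPool {
    free: HashMap<usize, Vec<Buffer>>,
    allocations: usize, // counts "real" allocations for the demo
}

impl BufferPool {
    fn alloc(&mut self, len: usize) -> Buffer {
        if let Some(buf) = self.free.get_mut(&len).and_then(Vec::pop) {
            return buf; // reuse a previously freed buffer
        }
        self.allocations += 1; // a real pool would allocate via Metal here
        Buffer { len }
    }

    fn release(&mut self, buf: Buffer) {
        // Instead of releasing to Metal, park the buffer for reuse.
        self.free.entry(buf.len).or_default().push(buf);
    }
}

fn main() {
    let mut pool = BufferPool::default();
    let a = pool.alloc(1024);
    pool.release(a);
    let _b = pool.alloc(1024); // same size: reused, no new allocation
    println!("real allocations: {}", pool.allocations);
}
```

A production version would also need an eviction policy (the free lists otherwise grow without bound) and thread-safe access, since Burn tensors can be dropped from any thread.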

## License

Licensed under either of

- Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE))
- MIT license ([LICENSE-MIT](LICENSE-MIT))

at your option.
