Skip to content

[WIP] Add Metal backend to Dr.Jit#495

Open
Speierers wants to merge 1 commit into
masterfrom
metal_support
Open

[WIP] Add Metal backend to Dr.Jit#495
Speierers wants to merge 1 commit into
masterfrom
metal_support

Conversation

@Speierers
Copy link
Copy Markdown
Member

Drjit: Metal backend support

This PR plumbs the Metal backend (added in the companion drjit-core PR)
through the top-level Drjit C++ traits, Python bindings, and test suite.
After this PR, user code can use dr.MetalArray<T>, dr.MetalDiffArray<T>,
and the drjit.metal / drjit.metal.ad Python modules in the same way as
their CUDA / LLVM counterparts.

This PR is a prerequisite for the Mitsuba metal_* variants (see the
companion PR in the mitsuba3 repository).

Summary

  • New MetalArray<T> / MetalDiffArray<T> C++ aliases and full trait
    support (is_metal_v, enable_if_metal_array_t, backend<T>).
  • New drjit.metal and drjit.metal.ad Python modules.
  • Backend-detection paths (dr.is_cuda_v, dr.is_llvm_v) gain a Metal
    sibling.
  • DLPack / Torch / dr.wrap interop migrate data correctly between Metal
    and other backends.
  • Drjit-core submodule bumped to the Metal-enabled commit.

C++ type system

  • include/drjit/jit.h
    • JitArray::IsMetal
    • using MetalArray = JitArray<JitBackend::Metal, T>
  • include/drjit/autodiff.h
    • DiffArray::IsMetal
    • using MetalDiffArray = DiffArray<JitBackend::Metal, T>
  • include/drjit/array_traits.h
    • is_metal_v<T>, enable_if_metal_array_t<T>
    • backend<T> specialisation that returns JitBackend::Metal
  • include/drjit/python.h
    • ArrayMeta::backend widened from 2 → 3 bits to accommodate
      JitBackend::Metal == 4. This is a small ABI change inside the
      bit-field — no public API impact, no source change required for
      downstream users — but it does mean any pre-built downstream binary
      that embeds ArrayMeta will need to be rebuilt.

Python bindings

Added symmetric to the CUDA / LLVM modules:

src/python/metal.cpp     ↔ cuda.cpp
src/python/metal.h       ↔ cuda.h
src/python/metal_ad.cpp  ↔ cuda_ad.cpp
drjit/metal/__init__.py
drjit/metal/ad.py
drjit/scalar/ad.py        # tiny stub, mirrors drjit/{cuda,llvm}/ad

After build, the new modules expose:

import drjit as dr
import drjit.metal       as m       # non-AD types
import drjit.metal.ad    as mad     # AD types

x = m.Float([1, 2, 3])              # JitArray<Metal, float>
y = mad.Float([1, 2, 3])            # DiffArray<Metal, float>
assert dr.is_metal_v(x)

Interop

drjit/interop.py was extended so that dr.wrap and the DLPack / Torch
bridges migrate data correctly when one side is on Metal:

  • DLPack: copies through host memory when crossing CUDA ↔ Metal or LLVM ↔
    Metal (Metal does not expose a DLPack-compatible device type).
  • Torch: data flows through CPU tensors; the existing fast paths for CUDA
    and CPU are unchanged.
  • dr.wrap: explicit migration when the function's backend differs from
    the input array's backend.

Test suite

  • dr.has_backend(dr.JitBackend.Metal) is now consulted in conftest
    fixtures so Metal-specific tests are skipped automatically on
    non-Apple machines.
  • dr.allclose no longer silently promotes Numpy float64 inputs to
    Float64 when the active variant is Metal (which has no FP64).
  • Several existing tests were widened with skip_on(...) markers for
    atomic kinds and operations that Metal doesn't support
    (FP64 atomics, etc.).
  • New tests added for bugs uncovered while bringing the Metal backend
    online — they exercise edge cases that were latent on CUDA / LLVM and
    are valuable independent of the Metal backend.

New dependencies

None directly. The Metal frameworks and the metal-cpp headers are
linked in via the drjit-core submodule (see the drjit-core PR for the
detailed list). No new Python or third-party-source dependencies are
introduced at this level.

Build

The Metal backend is opt-in and Apple-only. To build with it:

cmake -DDRJIT_ENABLE_METAL=ON ..

(The flag is picked up by drjit-core; the top-level Drjit CMake just
threads it through.) Builds with the flag off — including all non-Apple
platforms — produce a binary that is bit-for-bit identical to the
pre-Metal build.

Limitations

  • No FP64 on Metal. There is no MetalArray<double>. User code that
    relies on dr.Float64 must use the LLVM or CUDA backends.
  • Metal does not support zero-copy GPU↔CPU sharing like CUDA's
    cudaHostRegister, so the in-place Torch interop test is skipped on
    Metal.
  • Cooperative vectors are slower than CUDA for very large matrix
    sizes (per-thread scalar code with a simdgroup_matrix fast path; no
    WMMA equivalent on Metal).

Test plan

  • Existing C++ + Python unit tests pass on Linux / Windows (no
    regression for non-Apple users).
  • Existing tests pass on macOS with DRJIT_ENABLE_METAL=OFF.
  • On macOS with DRJIT_ENABLE_METAL=ON:
    • Tier 1–8 test suites pass (kernel history, vcalls, frozen
      functions, cooperative vectors, NN, ray tracing).
    • DLPack / Torch interop tests pass (with the documented in-place
      skip on Metal).
  • dr.has_backend(dr.JitBackend.Metal) returns True on supported
    hardware and False everywhere else.

@Speierers Speierers marked this pull request as draft April 29, 2026 13:41
@Speierers Speierers marked this pull request as ready for review May 5, 2026 07:06
Wires Dr.Jit-Core's new Metal backend (added in the bumped drjit-core
submodule) into Dr.Jit's Python and C++ layers. Mirrors the existing
CUDA / LLVM backend bindings.

Python bindings
---------------
  src/python/metal.{h,cpp}        nanobind module exposing the Metal
                                  backend's array types (Float32, Float16,
                                  Int32/64, UInt32/64, Bool, Mask, masked
                                  variants, tensor types) under
                                  drjit.metal.*. Mirrors src/python/cuda.cpp.
  src/python/metal_ad.cpp         Autodiff-enabled counterparts under
                                  drjit.metal.ad.*.
  src/python/main.cpp /
   main_v.cpp / meta.cpp /
   alias.cpp / init.cpp /
   base.cpp / dlpack.cpp /
   freeze.cpp / quat.cpp /
   resample.cpp / scalar.cpp /
   print.cpp                      Per-backend dispatch points gain Metal
                                  cases (variant resolution, has_backend
                                  checks, dlpack device IDs, freeze /
                                  RecordedThreadState wiring, etc.).
  drjit/__init__.py /
   drjit/metal/__init__.py /
   drjit/metal/ad.py /
   drjit/scalar/ad.py             Public-API surface — drjit.metal,
                                  drjit.metal.ad submodules.
  drjit/interop.py / opt.py /
   _reduce.py / stubs.pat         Backend-aware helpers (numpy/torch
                                  interop, optimizer dispatch, stub
                                  generation) updated for the new backend.

Type system
-----------
  include/drjit/array_base.h      Trait machinery extended with
   /array_traits.h /              has_metal / is_metal_v predicates and
   autodiff.h / jit.h /           appropriate JitBackend constants.
   tensor.h / python.h
  src/extra/call.cpp / math.cpp / Backend-agnostic Dr.Jit-Extra functions
   resample.cpp                   gain Metal-aware paths.

Benchmarks
----------
  drjit/bench.py                  New module: throughput / per-op benches
   docs/bench.rst                 with documentation. Used during Metal
   tests/test_bench.py            backend perf tuning, but is generic over
                                  any backend.

CMake
-----
  CMakeLists.txt /
   src/extra/CMakeLists.txt /     Adds metal.cpp + metal_ad.cpp to the
   src/python/CMakeLists.txt /    nanobind extension when APPLE +
   tests/CMakeLists.txt           DRJIT_ENABLE_METAL is on.

Tests
-----
  tests/test_freeze.py /
   test_arithmetic.py /           Existing tests parameterised over
   test_autodiff.py /             backends gain Metal entries; new
   test_format.py /               assertions for backend-specific
   test_init.py / test_memop.py / behaviour where Metal differs (e.g.
   test_rand.py / test_reduction.py /  Float64 emulation flag, lack of
   test_tensor.py / test_texture.py /  hardware texture sampler).
   test_wrap.py / etc.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant