[WIP] Add Metal backend to Dr.Jit #495
Open
Speierers wants to merge 1 commit into
Wires Dr.Jit-Core's new Metal backend (added in the bumped drjit-core
submodule) into Dr.Jit's Python and C++ layers. Mirrors the existing
CUDA / LLVM backend bindings.
Python bindings
---------------
src/python/metal.{h,cpp}
    nanobind module exposing the Metal backend's array types (Float32,
    Float16, Int32/64, UInt32/64, Bool, Mask, masked variants, tensor
    types) under drjit.metal.*. Mirrors src/python/cuda.cpp.

src/python/metal_ad.cpp
    Autodiff-enabled counterparts under drjit.metal.ad.*.

src/python/main.cpp / main_v.cpp / meta.cpp / alias.cpp / init.cpp /
base.cpp / dlpack.cpp / freeze.cpp / quat.cpp / resample.cpp /
scalar.cpp / print.cpp
    Per-backend dispatch points gain Metal cases (variant resolution,
    has_backend checks, dlpack device IDs, freeze / RecordedThreadState
    wiring, etc.).

drjit/__init__.py / drjit/metal/__init__.py / drjit/metal/ad.py /
drjit/scalar/ad.py
    Public-API surface: drjit.metal, drjit.metal.ad submodules.

drjit/interop.py / opt.py / _reduce.py / stubs.pat
    Backend-aware helpers (numpy/torch interop, optimizer dispatch,
    stub generation) updated for the new backend.
Type system
-----------
include/drjit/array_base.h / array_traits.h / autodiff.h / jit.h /
tensor.h / python.h
    Trait machinery extended with has_metal / is_metal_v predicates and
    appropriate JitBackend constants.

src/extra/call.cpp / math.cpp / resample.cpp
    Backend-agnostic Dr.Jit-Extra functions gain Metal-aware paths.
Benchmarks
----------
drjit/bench.py / docs/bench.rst / tests/test_bench.py
    New module: throughput / per-op benches with documentation. Used
    during Metal backend perf tuning, but is generic over any backend.
CMake
-----
CMakeLists.txt / src/extra/CMakeLists.txt /
src/python/CMakeLists.txt / tests/CMakeLists.txt
    Adds metal.cpp + metal_ad.cpp to the nanobind extension when APPLE +
    DRJIT_ENABLE_METAL is on.
Tests
-----
tests/test_freeze.py / test_arithmetic.py / test_autodiff.py /
test_format.py / test_init.py / test_memop.py / test_rand.py /
test_reduction.py / test_tensor.py / test_texture.py / test_wrap.py /
etc.
    Existing tests parameterised over backends gain Metal entries; new
    assertions for backend-specific behaviour where Metal differs (e.g.
    Float64 emulation flag, lack of hardware texture sampler).
Drjit: Metal backend support
This PR plumbs the Metal backend (added in the companion drjit-core PR)
through the top-level Drjit C++ traits, Python bindings, and test suite.
After this PR, user code can use dr.MetalArray<T>, dr.MetalDiffArray<T>, and the
drjit.metal / drjit.metal.ad Python modules in the same way as their CUDA / LLVM counterparts.
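As a minimal sketch of what using the backends interchangeably could look like from Python, the helper below returns the first backend that Dr.Jit reports as usable, together with its module. The name first_available_backend and the preference order are hypothetical; dr.has_backend and the drjit.metal module are taken from this description, and the sketch returns None when Dr.Jit is not installed at all.

```python
import importlib


def first_available_backend(candidates=("Metal", "CUDA", "LLVM")):
    """Return (name, module) for the first usable backend, or None.

    Hypothetical helper: the candidate order is an arbitrary choice made
    for this sketch, not something the PR prescribes.
    """
    try:
        import drjit as dr
    except ImportError:
        return None  # Dr.Jit is not installed in this environment

    for name in candidates:
        # Builds that predate this PR have no JitBackend.Metal member,
        # so look the enum entry up defensively.
        backend = getattr(dr.JitBackend, name, None)
        if backend is None or not dr.has_backend(backend):
            continue
        # drjit.metal / drjit.cuda / drjit.llvm share the same layout.
        return name, importlib.import_module(f"drjit.{name.lower()}")
    return None


choice = first_available_backend()
print("selected backend:", choice[0] if choice else "none")
```

Going through getattr rather than direct attribute access keeps the same code usable on builds with and without the Metal backend.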
This PR is a prerequisite for the Mitsuba metal_* variants (see the companion PR in the mitsuba3 repository).
Summary
- MetalArray<T> / MetalDiffArray<T> C++ aliases and full trait support (is_metal_v, enable_if_metal_array_t, backend<T>).
- drjit.metal and drjit.metal.ad Python modules.
- [...] dr.is_cuda_v, dr.is_llvm_v) gain a Metal sibling.
- [...] dr.wrap interop migrate data correctly between Metal and other backends.
C++ type system
- include/drjit/jit.h
  - JitArray::IsMetal
  - using MetalArray = JitArray<JitBackend::Metal, T>
- include/drjit/autodiff.h
  - DiffArray::IsMetal
  - using MetalDiffArray = DiffArray<JitBackend::Metal, T>
- include/drjit/array_traits.h
  - is_metal_v<T>, enable_if_metal_array_t<T>
  - backend<T> specialisation that returns JitBackend::Metal
- include/drjit/python.h
  - ArrayMeta::backend widened from 2 → 3 bits to accommodate JitBackend::Metal == 4. This is a small ABI change inside the bit-field (no public API impact, no source change required for downstream users), but it does mean any pre-built downstream binary that embeds ArrayMeta will need to be rebuilt.
Python bindings
Added symmetrically to the CUDA / LLVM modules:
After build, the new modules expose:
Interop
drjit/interop.py was extended so that dr.wrap and the DLPack / Torch bridges migrate data correctly when one side is on Metal:
- [...] Metal (Metal does not expose a DLPack-compatible device type).
- [...] and CPU are unchanged.
- dr.wrap: explicit migration when the function's backend differs from the input array's backend.
Test suite
- dr.has_backend(dr.JitBackend.Metal) is now consulted in conftest fixtures so Metal-specific tests are skipped automatically on non-Apple machines.
- dr.allclose no longer silently promotes NumPy float64 inputs to Float64 when the active variant is Metal (which has no FP64).
- [...] skip_on(...) markers for atomic kinds and operations that Metal doesn't support (FP64 atomics, etc.).
- [...] online; they exercise edge cases that were latent on CUDA / LLVM and are valuable independent of the Metal backend.
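The backend check used to skip Metal-specific tests could be sketched roughly as follows. This is not the PR's actual conftest code: metal_skip_reason is a hypothetical helper, and only the dr.has_backend(dr.JitBackend.Metal) call comes from the description above.

```python
def metal_skip_reason():
    """Return None when Metal tests can run, otherwise a skip reason.

    Hypothetical helper for illustration; the real fixtures may be
    organised differently.
    """
    try:
        import drjit as dr
    except ImportError:
        return "drjit is not installed"

    # Older builds do not define JitBackend.Metal at all.
    backend = getattr(dr.JitBackend, "Metal", None)
    if backend is None:
        return "this Dr.Jit build has no Metal backend"
    if not dr.has_backend(backend):
        return "Metal backend is unavailable on this machine"
    return None


# In a conftest.py the result could feed pytest's skip machinery, e.g.:
#   reason = metal_skip_reason()
#   requires_metal = pytest.mark.skipif(reason is not None, reason=str(reason))
print(metal_skip_reason())
```

Returning a reason string rather than a bare boolean keeps the skip message informative in test reports.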
New dependencies
None directly. The Metal frameworks and the metal-cpp headers are
linked in via the drjit-core submodule (see the drjit-core PR for the
detailed list). No new Python or third-party-source dependencies are
introduced at this level.
Build
The Metal backend is opt-in and Apple-only. To build with it, enable the
DRJIT_ENABLE_METAL CMake option. (The flag is picked up by drjit-core; the
top-level Drjit CMake just threads it through.) Builds with the flag off,
including all builds on non-Apple platforms, produce a binary that is
bit-for-bit identical to the pre-Metal build.
Limitations
- [...] MetalArray<double>. User code that relies on dr.Float64 must use the LLVM or CUDA backends.
- [...] cudaHostRegister, so the in-place Torch interop test is skipped on Metal.
- [...] sizes (per-thread scalar code with a simdgroup_matrix fast path; no WMMA equivalent on Metal).
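For code that must run on several backends, the FP64 limitation can be handled by filtering candidate backends up front. The sketch below is purely illustrative: the capability table and the usable_backends helper are hypothetical, and they encode only what this description states (no FP64 on Metal; dr.Float64 remains available on LLVM and CUDA).

```python
# Hypothetical capability table. Only the Float64 information is grounded
# in the limitation listed above.
SUPPORTS_FLOAT64 = {"metal": False, "cuda": True, "llvm": True}


def usable_backends(available, needs_float64):
    """Filter backend names, dropping those that cannot run Float64 code."""
    if not needs_float64:
        return list(available)
    # Unknown backends are conservatively treated as lacking Float64.
    return [name for name in available if SUPPORTS_FLOAT64.get(name, False)]


print(usable_backends(["metal", "llvm"], needs_float64=True))  # ['llvm']
```

Treating unknown names as unsupported errs on the side of failing early rather than at kernel compilation time.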
Test plan
- [...] regression for non-Apple users).
- [...] DRJIT_ENABLE_METAL=OFF.
- [...] DRJIT_ENABLE_METAL=ON:
  - [...] functions, cooperative vectors, NN, ray tracing).
  - [...] skip on Metal).
- dr.has_backend(dr.JitBackend.Metal) returns True on supported hardware and False everywhere else.