From 4443a33d394bedb8401c2f1716a1c3ff92affe8d Mon Sep 17 00:00:00 2001 From: Jesse Perla Date: Thu, 14 May 2026 16:31:49 -0700 Subject: [PATCH] fix: unroll _transform_tuple to work around Enzyme.jl#3104 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The recursive Base.tail fold in _transform_tuple makes Enzyme.autodiff (Forward and Reverse) throw `AssertionError("conv == 37")` from Enzyme/src/rules/jitrules.jl:2073 once the tuple has ≥ 33 entries (EnzymeAD/Enzyme.jl#3104). Replace it with a @generated straight-line unroll that produces the same outputs bit-for-bit while emitting no self-invoke in the typed IR — which is what Enzyme trips on. Verified against the full Pkg.test() suite (all Pass = Total) and a 35-entry SW07-Pfeifer-style NamedTuple prior (fwd + rev both succeed). --- src/aggregation.jl | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/aggregation.jl b/src/aggregation.jl index d42f57a..5dfbce6 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -392,15 +392,25 @@ $(SIGNATURES) Helper function for transforming tuples. Used internally, to help type inference. Use via `transfom_tuple`. -""" -_transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) = - (), logjac_zero(flag, _ensure_float(eltype(x))), index -function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts) - tfirst = first(ts) - yfirst, ℓfirst, index′ = transform_with(flag, tfirst, x, index) - yrest, ℓrest, index′′ = _transform_tuple(flag, x, index′, Base.tail(ts)) - (yfirst, yrest...), ℓfirst + ℓrest, index′′ +Implemented as a `@generated` straight-line unroll over the static tuple length. +Equivalent to the natural `Base.tail` recursion, but emits non-recursive code +so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on +tuples of length ≥ 33 (EnzymeAD/Enzyme.jl#3104). +""" +@generated function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, + ts::Tuple{Vararg{AbstractTransform,N}}) where {N} + N == 0 && return :(((), logjac_zero(flag, _ensure_float(eltype(x))), index)) + ys = [Symbol(:y_, i) for i in 1:N] + ℓs = [Symbol(:ℓ_, i) for i in 1:N] + calls = [:(($(ys[i]), $(ℓs[i]), idx) = transform_with(flag, ts[$i], x, idx)) + for i in 1:N] + ℓ_sum = foldl((a, b) -> :($a + $b), ℓs) + return quote + idx = index + $(calls...) + (($(ys...),), $ℓ_sum, idx) + end end """