@@ -62,8 +62,8 @@ function EnzymeRules.reverse(
6262 Aval = something (cacheA, A. val)
6363 Bval = something (cacheB, B. val)
6464
65- Δα = isnothing (AB) ? nothing : project_scalar (α . val, inner (AB, C . dval))
66- Δβ = isa (β, Const) ? nothing : pullback_dβ (C . dval , Cval, β)
65+ Δα = pullback_dα (α, C, AB)
66+ Δβ = pullback_dβ (β, C , Cval)
6767
6868 if ! isa (A, Const)
6969 blas_contract_pullback_ΔA! (
@@ -75,7 +75,7 @@ function EnzymeRules.reverse(
7575 B. dval, C. dval, Aval, pA. val, Bval, pB. val, pAB. val, α. val, backend. val, allocator. val
7676 ) # this typically returns nothing
7777 end
78- pullback_dC! (C. dval, β. val) # this typically returns nothing
78+ ! isa (C, Const) && pullback_dC! (C. dval, β. val) # this typically returns nothing
7979 return nothing , nothing , nothing , nothing , nothing , nothing , Δα, Δβ, nothing , nothing
8080end
8181
@@ -179,20 +179,8 @@ function EnzymeRules.reverse(
179179 Aval = something (A_cache, A. val)
180180 Cval = something (C_cache, C. val)
181181 ! isa (A, Const) && ! isa (C, Const) && trace_permute_pullback_ΔA! (A. dval, C. dval, Aval, p. val, q. val, α. val, backend. val)
182- Δαr = if ! isa (C, Const) && ! isnothing (At)
183- project_scalar (α. val, inner (At, C. dval))
184- elseif ! isnothing (At)
185- zero (α. val)
186- else
187- nothing
188- end
189- Δβr = if ! isa (β, Const) && ! isa (C, Const)
190- pullback_dβ (C. dval, Cval, β)
191- elseif ! isa (β, Const)
192- zero (β. val)
193- else
194- nothing
195- end
182+ Δαr = pullback_dα (α, C, At)
183+ Δβr = pullback_dβ (β, C, Cval)
196184 ! isa (C, Const) && pullback_dC! (C. dval, β. val)
197185 return nothing , nothing , nothing , nothing , Δαr, Δβr, nothing
198186end
0 commit comments