Skip to content

Commit 45bbbdb

Browse files
committed
A bit of cleanup for alpha and beta
1 parent 8f0486a commit 45bbbdb

5 files changed

Lines changed: 27 additions & 90 deletions

File tree

ext/TensorKitEnzymeExt/indexmanipulations.jl

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,8 @@ for transform in (:permute, :transpose)
5959
TK.$add_transform!(A.dval, C.dval, pΔA, conj.val), One(), bavs...)
6060
end
6161
end
62-
Δαr = if !isnothing(Ap) && !isa(C, Const)
63-
project_scalar.val, inner(Ap, C.dval))
64-
elseif !isnothing(Ap)
65-
zero.val)
66-
else
67-
nothing
68-
end
69-
Δβr = if !isa(C, Const) && !isa(β, Const)
70-
pullback_dβ(C.dval, Cval, β)
71-
elseif !isa(β, Const)
72-
zero.val)
73-
else
74-
nothing
75-
end
62+
Δα = pullback_dα(α, C, Ap)
63+
Δβ = pullback_dβ(β, C, Cval)
7664
!isa(C, Const) && pullback_dC!(C.dval, β.val)
7765
return nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
7866
end
@@ -140,20 +128,8 @@ function EnzymeRules.reverse(
140128
TK.add_braid!(A.dval, C.dval, pΔA, ilevels, conj.val), One(), bavs...)
141129
end
142130
end
143-
Δαr = if !isnothing(Ap) && !isa(C, Const)
144-
project_scalar.val, inner(Ap, C.dval))
145-
elseif !isnothing(Ap)
146-
zero.val)
147-
else
148-
nothing
149-
end
150-
Δβr = if !isa(C, Const) && !isa(β, Const)
151-
pullback_dβ(C.dval, Cval, β)
152-
elseif !isa(β, Const)
153-
zero.val)
154-
else
155-
nothing
156-
end
131+
Δαr = pullback_dα(α, C, Ap)
132+
Δβr = pullback_dβ(β, C, Cval)
157133
!isa(C, Const) && pullback_dC!(C.dval, β.val)
158134
return nothing, nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
159135
end

ext/TensorKitEnzymeExt/linalg.jl

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
# Shared
22
# ------
3-
pullback_dC!(ΔC, β) = scale!(ΔC, conj(β))
4-
pullback_dβ(ΔC, C, β) = !isa(β, Const) ? project_scalar.val, inner(C, ΔC)) : nothing
5-
63
# Can Enzyme do this itself? Apparently not...
74
function EnzymeRules.augmented_primal(
85
config::EnzymeRules.RevConfigWidth{1},
@@ -54,20 +51,8 @@ function EnzymeRules.reverse(
5451

5552
!isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj.val))
5653
!isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj.val))
57-
Δαr = if !isnothing(AB) && !isa(C, Const)
58-
project_scalar.val, inner(AB, C.dval))
59-
elseif !isnothing(AB)
60-
zero.val)
61-
else
62-
nothing
63-
end
64-
Δβr = if !isa(β, Const) && !isa(C, Const)
65-
pullback_dβ(C.dval, Cval, β)
66-
elseif !isa(β, Const)
67-
zero.val)
68-
else
69-
nothing
70-
end
54+
Δαr = pullback_dα(α, C, AB)
55+
Δβr = pullback_dβ(β, C, Cval)
7156
!isa(C, Const) && pullback_dC!(C.dval, β.val)
7257

7358
return (nothing, nothing, nothing, Δαr, Δβr)

ext/TensorKitEnzymeExt/tensoroperations.jl

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ function EnzymeRules.reverse(
6262
Aval = something(cacheA, A.val)
6363
Bval = something(cacheB, B.val)
6464

65-
Δα = isnothing(AB) ? nothing : project_scalar.val, inner(AB, C.dval))
66-
Δβ = isa(β, Const) ? nothing : pullback_dβ(C.dval, Cval, β)
65+
Δα = pullback_dα(α, C, AB)
66+
Δβ = pullback_dβ(β, C, Cval)
6767

6868
if !isa(A, Const)
6969
blas_contract_pullback_ΔA!(
@@ -75,7 +75,7 @@ function EnzymeRules.reverse(
7575
B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val
7676
) # this typically returns nothing
7777
end
78-
pullback_dC!(C.dval, β.val) # this typically returns nothing
78+
!isa(C, Const) && pullback_dC!(C.dval, β.val) # this typically returns nothing
7979
return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing
8080
end
8181

@@ -179,20 +179,8 @@ function EnzymeRules.reverse(
179179
Aval = something(A_cache, A.val)
180180
Cval = something(C_cache, C.val)
181181
!isa(A, Const) && !isa(C, Const) && trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val)
182-
Δαr = if !isa(C, Const) && !isnothing(At)
183-
project_scalar.val, inner(At, C.dval))
184-
elseif !isnothing(At)
185-
zero.val)
186-
else
187-
nothing
188-
end
189-
Δβr = if !isa(β, Const) && !isa(C, Const)
190-
pullback_dβ(C.dval, Cval, β)
191-
elseif !isa(β, Const)
192-
zero.val)
193-
else
194-
nothing
195-
end
182+
Δαr = pullback_dα(α, C, At)
183+
Δβr = pullback_dβ(β, C, Cval)
196184
!isa(C, Const) && pullback_dC!(C.dval, β.val)
197185
return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing
198186
end

ext/TensorKitEnzymeExt/utility.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
# Projection
22
# ----------
3+
pullback_dα::Const, C::Const, A) = nothing
4+
pullback_dα::Const, C::Annotation, A) = nothing
5+
pullback_dα::Annotation, C::Const, A) = zero.val)
6+
pullback_dα::Annotation, C::Annotation, A) = project_scalar.val, inner(A, C.dval))
7+
8+
pullback_dβ::Const, C::Const, Ccache) = nothing
9+
pullback_dβ::Const, C::Annotation, Ccache) = nothing
10+
pullback_dβ::Annotation, C::Const, Ccache) = zero.val)
11+
pullback_dβ::Annotation, C::Annotation, Ccache) = project_scalar.val, inner(Ccache, C.dval))
12+
13+
pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β))
14+
315
"""
416
project_scalar(x::Number, dx::Number)
517

ext/TensorKitEnzymeExt/vectorinterface.jl

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,7 @@ function EnzymeRules.reverse(
2121
α::Annotation{<:Number},
2222
) where {RT}
2323
Cval = something(cache, C.val)
24-
Δα = if !isa(α, Const) && !isa(C, Const)
25-
project_scalar.val, inner(Cval, C.dval))
26-
elseif !isa(α, Const)
27-
zero.val)
28-
else
29-
nothing
30-
end
24+
Δα = pullback_dα(α, C, Cval)
3125
!isa(C, Const) && scale!(C.dval, conj.val))
3226
return (nothing, Δα)
3327
end
@@ -58,13 +52,7 @@ function EnzymeRules.reverse(
5852
α::Annotation{<:Number},
5953
) where {RT}
6054
Aval = something(cache, A.val)
61-
Δα = if !isa(α, Const) && !isa(C, Const)
62-
project_scalar.val, inner(Aval, C.dval))
63-
elseif !isa(α, Const)
64-
zero.val)
65-
else
66-
nothing
67-
end
55+
Δα = pullback_dα(α, C, Aval)
6856
!isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj.val))
6957
!isa(C, Const) && zerovector!(C.dval)
7058
return (nothing, nothing, Δα)
@@ -101,20 +89,8 @@ function EnzymeRules.reverse(
10189
A_cache, C_cache = cache
10290
Aval = something(A_cache, A.val)
10391
Cval = something(C_cache, C.val)
104-
Δα = if !isa(α, Const) && !isa(C, Const)
105-
project_scalar.val, inner(Aval, C.dval))
106-
elseif !isa(α, Const)
107-
zero.val)
108-
else
109-
nothing
110-
end
111-
Δβ = if !isa(β, Const) && !isa(C, Const)
112-
project_scalar.val, inner(Cval, C.dval))
113-
elseif !isa(β, Const)
114-
zero.val)
115-
else
116-
nothing
117-
end
92+
Δα = pullback_dα(α, C, Aval)
93+
Δβ = pullback_dβ(β, C, Cval)
11894
!isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj.val))
11995
!isa(C, Const) && scale!(C.dval, conj.val))
12096
return (nothing, nothing, Δα, Δβ)

0 commit comments

Comments
 (0)