Skip to content

Commit 349993f

Browse files
committed
More factorizations support
1 parent 35b8cf3 commit 349993f

7 files changed

Lines changed: 171 additions & 40 deletions

File tree

ext/TensorKitEnzymeExt/factorizations.jl

Lines changed: 105 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,5 @@
1-
function EnzymeRules.reverse(
2-
config::EnzymeRules.RevConfigWidth{1},
3-
func::Const{typeof(MatrixAlgebraKit.copy_input)},
4-
::Type{RT},
5-
cache,
6-
f::Annotation,
7-
A::Annotation{<:AbstractTensorMap}
8-
) where {RT}
9-
copy_shadow = cache
10-
if !isa(A, Const) && !isnothing(copy_shadow)
11-
add!(A.dval, copy_shadow)
12-
end
13-
return (nothing, nothing)
14-
end
15-
161
# need these due to Enzyme choking on blocks
2+
173
for f in (:project_hermitian, :project_antihermitian)
184
f! = Symbol(f, :!)
195
@eval begin
@@ -87,6 +73,10 @@ for (f, pb) in (
8773
(:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)),
8874
(:lq_compact, :(MatrixAlgebraKit.lq_pullback!)),
8975
(:qr_compact, :(MatrixAlgebraKit.qr_pullback!)),
76+
(:lq_full, :(MatrixAlgebraKit.lq_pullback!)),
77+
(:qr_full, :(MatrixAlgebraKit.qr_pullback!)),
78+
(:lq_null, :(MatrixAlgebraKit.lq_null_pullback!)),
79+
(:qr_null, :(MatrixAlgebraKit.qr_null_pullback!)),
9080
)
9181
@eval begin
9282
function EnzymeRules.augmented_primal(
@@ -116,6 +106,40 @@ for (f, pb) in (
116106
end
117107
end
118108

109+
for (f, f_full, pb) in (
110+
(:eig_vals, :eig_full, :(MatrixAlgebraKit.eig_vals_pullback!)),
111+
(:eigh_vals, :eigh_full, :(MatrixAlgebraKit.eigh_vals_pullback!)),
112+
)
113+
@eval begin
114+
function EnzymeRules.augmented_primal(
115+
config::EnzymeRules.RevConfigWidth{1},
116+
func::Const{typeof($f)},
117+
::Type{RT},
118+
A::Annotation{<:AbstractTensorMap},
119+
alg::Const,
120+
) where {RT}
121+
ret_full = $f_full(A.val, alg.val)
122+
ret = diagview(ret_full[1])
123+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
124+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
125+
cache = (ret, shadow, ret_full[2])
126+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
127+
end
128+
function EnzymeRules.reverse(
129+
config::EnzymeRules.RevConfigWidth{1},
130+
func::Const{typeof($f)},
131+
::Type{RT},
132+
cache,
133+
A::Annotation{<:AbstractTensorMap},
134+
alg::Const,
135+
) where {RT}
136+
D, dD, V = cache
137+
!isa(A, Const) && $pb(A.dval, A.val, (DiagonalTensorMap(D), V), dD)
138+
return (nothing, nothing)
139+
end
140+
end
141+
end
142+
119143
for f in (:svd_compact, :svd_full)
120144
@eval begin
121145
function EnzymeRules.augmented_primal(
@@ -172,32 +196,80 @@ for f in (:svd_compact, :svd_full)
172196
end=# #hmmmm
173197
end
174198

175-
# TODO
176-
#=
177199
function EnzymeRules.augmented_primal(
178200
config::EnzymeRules.RevConfigWidth{1},
179-
func::Const{typeof(svd_trunc)},
201+
func::Const{typeof(svd_trunc_no_error)},
180202
::Type{RT},
181203
A::Annotation{<:AbstractTensorMap},
182204
alg::Const,
183205
) where {RT}
184-
185206
USVᴴ = svd_compact(A.val, alg.val.alg)
186207
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
187-
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
188208
dUSVᴴtrunc = make_zero(USVᴴtrunc)
189-
cache = (USVᴴtrunc, dUSVᴴtrunc)
190-
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
209+
cache = (USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind)
210+
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
191211
end
192212
function EnzymeRules.reverse(
193-
config::EnzymeRules.RevConfigWidth{1},
194-
func::Const{typeof(svd_trunc)},
195-
::Type{RT},
196-
cache,
197-
A::Annotation{<:AbstractTensorMap},
198-
alg::Const,
199-
) where {RT}
200-
USVᴴ, dUSVᴴ = cache
201-
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ)
202-
return (nothing, nothing)
203-
end=#
213+
config::EnzymeRules.RevConfigWidth{1},
214+
func::Const{typeof(svd_trunc_no_error)},
215+
::Type{RT},
216+
cache,
217+
A::Annotation{<:AbstractTensorMap},
218+
alg::Const,
219+
) where {RT}
220+
USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind = cache
221+
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴtrunc, ind)
222+
return (nothing, nothing)
223+
end
224+
225+
function EnzymeRules.augmented_primal(
226+
config::EnzymeRules.RevConfigWidth{1},
227+
func::Const{typeof(eig_trunc_no_error)},
228+
::Type{RT},
229+
A::Annotation{<:AbstractTensorMap},
230+
alg::Const,
231+
) where {RT}
232+
DV = eig_full(A.val, alg.val.alg)
233+
DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV, alg.val.trunc)
234+
dDVtrunc = make_zero(DVtrunc)
235+
cache = (DV, DVtrunc, dDVtrunc, ind)
236+
return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache)
237+
end
238+
function EnzymeRules.reverse(
239+
config::EnzymeRules.RevConfigWidth{1},
240+
func::Const{typeof(eig_trunc_no_error)},
241+
::Type{RT},
242+
cache,
243+
A::Annotation{<:AbstractTensorMap},
244+
alg::Const,
245+
) where {RT}
246+
DV, DVtrunc, dDVtrunc, ind = cache
247+
MatrixAlgebraKit.eig_pullback!(A.dval, A.val, DV, dDVtrunc, ind)
248+
return (nothing, nothing)
249+
end
250+
251+
function EnzymeRules.augmented_primal(
252+
config::EnzymeRules.RevConfigWidth{1},
253+
func::Const{typeof(eigh_trunc_no_error)},
254+
::Type{RT},
255+
A::Annotation{<:AbstractTensorMap},
256+
alg::Const,
257+
) where {RT}
258+
DV = eigh_full(A.val, alg.val.alg)
259+
DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV, alg.val.trunc)
260+
dDVtrunc = make_zero(DVtrunc)
261+
cache = (DV, DVtrunc, dDVtrunc, ind)
262+
return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache)
263+
end
264+
function EnzymeRules.reverse(
265+
config::EnzymeRules.RevConfigWidth{1},
266+
func::Const{typeof(eigh_trunc_no_error)},
267+
::Type{RT},
268+
cache,
269+
A::Annotation{<:AbstractTensorMap},
270+
alg::Const,
271+
) where {RT}
272+
DV, DVtrunc, dDVtrunc, ind = cache
273+
MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DV, dDVtrunc, ind)
274+
return (nothing, nothing)
275+
end

src/factorizations/pullbacks.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,25 @@ for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!)
4444
end
4545
end
4646

47+
for pullback! in (:eig_vals_pullback!, :eigh_vals_pullback!)
48+
@eval function MAK.$pullback!(
49+
Δt::AbstractTensorMap, t::AbstractTensorMap, DV, ΔD, inds = _notrunc_ind(t);
50+
kwargs...
51+
)
52+
D, V = DV
53+
foreachblock(Δt, t) do c, (Δb, b)
54+
haskey(inds, c) || return nothing
55+
ind = inds[c]
56+
Dc = block(D, c)
57+
Vc = block(V, c)
58+
ΔDc = block(ΔD, c)
59+
MAK.$pullback!(Δb, b, (Dc, Vc), ΔDc, ind; kwargs...)
60+
return nothing
61+
end
62+
return Δt
63+
end
64+
end
65+
4766
for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_pullback!)
4867
@eval function MAK.$pullback_trunc!(
4968
Δt::AbstractTensorMap, t::AbstractTensorMap, F, ΔF; kwargs...

test/enzyme/factorizations/eig.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,15 @@ end
5959
ΔDV = EnzymeTestUtils.rand_tangent(DV)
6060
remove_eiggauge_dependence!(ΔDV[2], DV...)
6161
EnzymeTestUtils.test_reverse(eig_full, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol)
62+
63+
#D = eig_vals(t)
64+
#EnzymeTestUtils.test_reverse(eig_vals, Duplicated, (t, Duplicated); atol, rtol)
65+
66+
V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
67+
trunc = truncspace(V_trunc)
68+
alg = MatrixAlgebraKit.select_algorithm(eig_trunc_no_error, t, nothing; trunc)
69+
DVtrunc = eig_trunc_no_error(t, alg)
70+
ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc)
71+
remove_eiggauge_dependence!(ΔDVtrunc[2], DVtrunc...)
72+
EnzymeTestUtils.test_reverse(eig_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol)
6273
end

test/enzyme/factorizations/eigh.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,16 @@ end
6060
ΔDV = EnzymeTestUtils.rand_tangent(DV)
6161
remove_eighgauge_dependence!(ΔDV[2], DV...)
6262
EnzymeTestUtils.test_reverse(eigh_full project_hermitian, Duplicated, (t, Duplicated); output_tangent = ΔDV, atol, rtol)
63+
64+
#D = eigh_vals(th)
65+
#EnzymeTestUtils.test_reverse(eigh_vals ∘ project_hermitian, Duplicated, (t, Duplicated); atol, rtol)
66+
67+
V_trunc = spacetype(th)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
68+
trunc = truncspace(V_trunc)
69+
alg = MatrixAlgebraKit.select_algorithm(eigh_trunc_no_error, th, nothing; trunc)
70+
DVtrunc = eigh_trunc_no_error(th, alg)
71+
ΔDVtrunc = EnzymeTestUtils.rand_tangent(DVtrunc)
72+
remove_eighgauge_dependence!(ΔDVtrunc[2], DVtrunc...)
73+
proj_eigh(t, alg) = eigh_trunc_no_error(project_hermitian(t), alg)
74+
EnzymeTestUtils.test_reverse(proj_eigh, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔDVtrunc, atol, rtol)
6375
end

test/enzyme/factorizations/lq.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,29 @@ function remove_lqgauge_dependence!(ΔQ, t, Q)
4747
return ΔQ
4848
end
4949

50+
function remove_lq_null_gauge_dependence!(ΔNᴴ, Q)
51+
for (c, b) in blocks(ΔNᴴ)
52+
Qc = block(Q, c)
53+
ΔNᴴQᴴ = b * Qc'
54+
mul!(b, ΔNᴴQᴴ, Qc)
55+
end
56+
return ΔNᴴ
57+
end
58+
5059
@timedtestset "Enzyme - Factorizations (LQ): $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes, A in (randn(T, V[1] V[2] V[1] V[2]), randn(T, V[1] V[2] V[1]))
5160
atol = default_tol(T)
5261
rtol = default_tol(T)
5362
EnzymeTestUtils.test_reverse(lq_compact, Duplicated, (A, Duplicated); atol, rtol)
5463

5564
# lq_full/lq_null requires being careful with gauges
56-
#=LQ = lq_full(A)
65+
LQ = lq_full(A)
5766
ΔLQ = EnzymeTestUtils.rand_tangent(LQ)
5867
remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2])
59-
EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol)=#
60-
#EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); atol, rtol)
68+
EnzymeTestUtils.test_reverse(lq_full, Duplicated, (A, Duplicated); output_tangent = ΔLQ, atol, rtol)
69+
70+
Nᴴ = lq_null(A)
71+
Q = lq_compact(A)[2]
72+
ΔNᴴ = EnzymeTestUtils.rand_tangent(Nᴴ)
73+
remove_lq_null_gauge_dependence!(ΔNᴴ, Q)
74+
EnzymeTestUtils.test_reverse(lq_null, Duplicated, (A, Duplicated); output_tangent = ΔNᴴ, atol, rtol)
6175
end

test/enzyme/factorizations/qr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ end
5454
EnzymeTestUtils.test_reverse(qr_compact, Duplicated, (A, Duplicated); atol, rtol)
5555

5656
# qr_full/qr_null requires being careful with gauges
57-
#=QR = qr_full(A)
57+
QR = qr_full(A)
5858
ΔQR = EnzymeTestUtils.rand_tangent(QR)
5959
remove_qrgauge_dependence!(ΔQR[1], A, QR[1])
60-
EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol)=#
60+
EnzymeTestUtils.test_reverse(qr_full, Duplicated, (A, Duplicated); output_tangent = ΔQR, atol, rtol)
6161
#EnzymeTestUtils.test_reverse(qr_null, Duplicated, (A, Duplicated); atol, rtol)
6262
end

test/enzyme/factorizations/svd.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ end
5858
atol = default_tol(T)
5959
rtol = default_tol(T)
6060

61+
#S = svd_vals(t)
62+
#EnzymeTestUtils.test_reverse(svd_vals, Duplicated, (t, Duplicated); atol, rtol)
63+
6164
USVᴴ = svd_compact(t)
6265
ΔUSVᴴ = EnzymeTestUtils.rand_tangent(USVᴴ)
6366
remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
@@ -68,11 +71,11 @@ end
6871
remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
6972
EnzymeTestUtils.test_reverse(svd_full, Duplicated, (t, Duplicated); output_tangent = ΔUSVᴴ, atol, rtol)=#
7073

71-
#=V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
74+
V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
7275
trunc = truncspace(V_trunc)
7376
alg = MatrixAlgebraKit.select_algorithm(svd_trunc_no_error, t, nothing; trunc)
7477
USVᴴtrunc = svd_trunc_no_error(t, alg)
7578
ΔUSVᴴtrunc = EnzymeTestUtils.rand_tangent(USVᴴtrunc)
7679
remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], USVᴴtrunc...)
77-
EnzymeTestUtils.test_reverse(svd_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol)=#
80+
EnzymeTestUtils.test_reverse(svd_trunc_no_error, Duplicated, (t, Duplicated), (alg, Const); output_tangent = ΔUSVᴴtrunc, atol, rtol)
7881
end

0 commit comments

Comments
 (0)