Skip to content

Commit 697e306

Browse files
committed
Incremental work on Enzyme support
1 parent ec7af8f commit 697e306

38 files changed

Lines changed: 2939 additions & 2 deletions

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jobs:
7575
directories: 'src,ext'
7676
- uses: codecov/codecov-action@v6
7777
with:
78-
files: lcov.info
78+
files: lcov.info
7979
token: ${{ secrets.CODECOV_TOKEN }}
8080
fail_ci_if_error: false
8181

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2222
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
2323
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2424
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
25+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
26+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
2527
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2628
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2729
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
@@ -30,6 +32,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
3032
TensorKitAdaptExt = "Adapt"
3133
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3234
TensorKitChainRulesCoreExt = "ChainRulesCore"
35+
TensorKitEnzymeExt = "Enzyme"
36+
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
3337
TensorKitFiniteDifferencesExt = "FiniteDifferences"
3438
TensorKitMooncakeExt = "Mooncake"
3539

@@ -41,6 +45,8 @@ Adapt = "4"
4145
CUDA = "5.9"
4246
ChainRulesCore = "1"
4347
Dictionaries = "0.4"
48+
Enzyme = "0.13.134"
49+
EnzymeTestUtils = "0.2.5"
4450
FiniteDifferences = "0.12"
4551
LRUCache = "1.0.2"
4652
LinearAlgebra = "1"
@@ -52,7 +58,7 @@ Random = "1"
5258
ScopedValues = "1.3.0"
5359
Strided = "2"
5460
TensorKitSectors = "0.3.6"
55-
TensorOperations = "5.1"
61+
TensorOperations = "5.5.2"
5662
TupleTools = "1.5"
5763
VectorInterface = "0.4.8, 0.5"
5864
cuTENSOR = "2"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module TensorKitEnzymeExt
2+
3+
using Enzyme
4+
using TensorKit
5+
import TensorKit as TK
6+
using VectorInterface
7+
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
8+
import TensorOperations as TO
9+
using MatrixAlgebraKit
10+
using TupleTools
11+
using Random: AbstractRNG
12+
13+
include("utility.jl")
14+
include("linalg.jl")
15+
include("vectorinterface.jl")
16+
include("tensoroperations.jl")
17+
include("factorizations.jl")
18+
include("indexmanipulations.jl")
19+
#include("planaroperations.jl")
20+
21+
end
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
# need these due to Enzyme choking on blocks
2+
3+
for f in (:project_hermitian, :project_antihermitian)
4+
f! = Symbol(f, :!)
5+
@eval begin
6+
function EnzymeRules.augmented_primal(
7+
config::EnzymeRules.RevConfigWidth{1},
8+
func::Const{typeof($f!)},
9+
::Type{RT},
10+
A::Annotation{<:AbstractTensorMap},
11+
arg::Annotation{<:AbstractTensorMap},
12+
alg::Const,
13+
) where {RT}
14+
$f!(A.val, arg.val, alg.val)
15+
primal = EnzymeRules.needs_primal(config) ? arg.val : nothing
16+
shadow = EnzymeRules.needs_shadow(config) ? arg.dval : nothing
17+
cache = nothing
18+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
19+
end
20+
function EnzymeRules.reverse(
21+
config::EnzymeRules.RevConfigWidth{1},
22+
func::Const{typeof($f!)},
23+
::Type{RT},
24+
cache,
25+
A::Annotation{<:AbstractTensorMap},
26+
arg::Annotation{<:AbstractTensorMap},
27+
alg::Const,
28+
) where {RT}
29+
if !isa(A, Const)
30+
$f!(arg.dval, arg.dval, alg.val)
31+
if A.dval !== arg.dval
32+
A.dval .+= arg.dval
33+
make_zero!(arg.dval)
34+
end
35+
end
36+
return (nothing, nothing, nothing)
37+
end
38+
function EnzymeRules.augmented_primal(
39+
config::EnzymeRules.RevConfigWidth{1},
40+
func::Const{typeof($f)},
41+
::Type{RT},
42+
A::Annotation{<:AbstractTensorMap},
43+
alg::Const,
44+
) where {RT}
45+
ret = $f(A.val, alg.val)
46+
dret = make_zero(ret)
47+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
48+
shadow = EnzymeRules.needs_shadow(config) ? dret : nothing
49+
cache = dret
50+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
51+
end
52+
function EnzymeRules.reverse(
53+
config::EnzymeRules.RevConfigWidth{1},
54+
func::Const{typeof($f)},
55+
::Type{RT},
56+
cache,
57+
A::Annotation{<:AbstractTensorMap},
58+
alg::Const,
59+
) where {RT}
60+
dret = cache
61+
if !isa(A, Const)
62+
$f!(dret, dret, alg.val)
63+
add!(A.dval, dret)
64+
end
65+
make_zero!(dret)
66+
return (nothing, nothing)
67+
end
68+
end
69+
end
70+
71+
for (f, pb) in (
72+
(:eig_full, :(MatrixAlgebraKit.eig_pullback!)),
73+
(:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)),
74+
(:lq_compact, :(MatrixAlgebraKit.lq_pullback!)),
75+
(:qr_compact, :(MatrixAlgebraKit.qr_pullback!)),
76+
(:lq_full, :(MatrixAlgebraKit.lq_pullback!)),
77+
(:qr_full, :(MatrixAlgebraKit.qr_pullback!)),
78+
(:lq_null, :(MatrixAlgebraKit.lq_null_pullback!)),
79+
(:qr_null, :(MatrixAlgebraKit.qr_null_pullback!)),
80+
)
81+
@eval begin
82+
function EnzymeRules.augmented_primal(
83+
config::EnzymeRules.RevConfigWidth{1},
84+
func::Const{typeof($f)},
85+
::Type{RT},
86+
A::Annotation{<:AbstractTensorMap},
87+
alg::Const,
88+
) where {RT}
89+
ret = $f(A.val, alg.val)
90+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
91+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
92+
cache = (ret, shadow)
93+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
94+
end
95+
function EnzymeRules.reverse(
96+
config::EnzymeRules.RevConfigWidth{1},
97+
func::Const{typeof($f)},
98+
::Type{RT},
99+
cache,
100+
A::Annotation{<:AbstractTensorMap},
101+
alg::Const,
102+
) where {RT}
103+
!isa(A, Const) && $pb(A.dval, A.val, cache...)
104+
return (nothing, nothing)
105+
end
106+
end
107+
end
108+
109+
for (f, f_full, pb) in (
110+
(:eig_vals, :eig_full, :(MatrixAlgebraKit.eig_vals_pullback!)),
111+
(:eigh_vals, :eigh_full, :(MatrixAlgebraKit.eigh_vals_pullback!)),
112+
)
113+
@eval begin
114+
function EnzymeRules.augmented_primal(
115+
config::EnzymeRules.RevConfigWidth{1},
116+
func::Const{typeof($f)},
117+
::Type{RT},
118+
A::Annotation{<:AbstractTensorMap},
119+
alg::Const,
120+
) where {RT}
121+
ret_full = $f_full(A.val, alg.val)
122+
ret = diagview(ret_full[1])
123+
primal = EnzymeRules.needs_primal(config) ? ret : nothing
124+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing
125+
cache = (ret, shadow, ret_full[2])
126+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
127+
end
128+
function EnzymeRules.reverse(
129+
config::EnzymeRules.RevConfigWidth{1},
130+
func::Const{typeof($f)},
131+
::Type{RT},
132+
cache,
133+
A::Annotation{<:AbstractTensorMap},
134+
alg::Const,
135+
) where {RT}
136+
D, dD, V = cache
137+
!isa(A, Const) && $pb(A.dval, A.val, (DiagonalTensorMap(D), V), dD)
138+
return (nothing, nothing)
139+
end
140+
end
141+
end
142+
143+
for f in (:svd_compact, :svd_full)
144+
@eval begin
145+
function EnzymeRules.augmented_primal(
146+
config::EnzymeRules.RevConfigWidth{1},
147+
func::Const{typeof($f)},
148+
::Type{RT},
149+
A::Annotation{<:AbstractTensorMap},
150+
alg::Const,
151+
) where {RT}
152+
USVᴴ = $f(A.val, alg.val)
153+
primal = EnzymeRules.needs_primal(config) ? USVᴴ : nothing
154+
shadow = EnzymeRules.needs_shadow(config) ? make_zero(USVᴴ) : nothing
155+
cache = (USVᴴ, shadow)
156+
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
157+
end
158+
function EnzymeRules.reverse(
159+
config::EnzymeRules.RevConfigWidth{1},
160+
func::Const{typeof($f)},
161+
::Type{RT},
162+
cache,
163+
A::Annotation{<:AbstractTensorMap},
164+
alg::Const,
165+
) where {RT}
166+
!isa(A, Const) && MatrixAlgebraKit.svd_pullback!(A.dval, A.val, cache...)
167+
return (nothing, nothing)
168+
end
169+
end
170+
171+
# mutating version is not guaranteed to actually mutate
172+
# so we can simply use the non-mutating version instead
173+
f! = Symbol(f, :!)
174+
#=@eval begin
175+
function EnzymeRules.augmented_primal(
176+
config::EnzymeRules.RevConfigWidth{1},
177+
func::Const{typeof($f!)},
178+
::Type{RT},
179+
A::Annotation{<:AbstractTensorMap},
180+
USVᴴ::Annotation,
181+
alg::Const,
182+
) where {RT}
183+
EnzymeRules.augmented_primal(func, RT, A, alg)
184+
end
185+
function EnzymeRules.reverse(
186+
config::EnzymeRules.RevConfigWidth{1},
187+
func::Const{typeof($f!)},
188+
::Type{RT},
189+
cache,
190+
A::Annotation{<:AbstractTensorMap},
191+
USVᴴ::Annotation,
192+
alg::Const,
193+
) where {RT}
194+
EnzymeRules.reverse(func, RT, A, alg)
195+
end
196+
end=# #hmmmm
197+
end
198+
199+
function EnzymeRules.augmented_primal(
200+
config::EnzymeRules.RevConfigWidth{1},
201+
func::Const{typeof(svd_trunc_no_error)},
202+
::Type{RT},
203+
A::Annotation{<:AbstractTensorMap},
204+
alg::Const,
205+
) where {RT}
206+
USVᴴ = svd_compact(A.val, alg.val.alg)
207+
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
208+
dUSVᴴtrunc = make_zero(USVᴴtrunc)
209+
cache = (USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind)
210+
return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache)
211+
end
212+
function EnzymeRules.reverse(
213+
config::EnzymeRules.RevConfigWidth{1},
214+
func::Const{typeof(svd_trunc_no_error)},
215+
::Type{RT},
216+
cache,
217+
A::Annotation{<:AbstractTensorMap},
218+
alg::Const,
219+
) where {RT}
220+
USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind = cache
221+
MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴtrunc, ind)
222+
return (nothing, nothing)
223+
end
224+
225+
function EnzymeRules.augmented_primal(
226+
config::EnzymeRules.RevConfigWidth{1},
227+
func::Const{typeof(eig_trunc_no_error)},
228+
::Type{RT},
229+
A::Annotation{<:AbstractTensorMap},
230+
alg::Const,
231+
) where {RT}
232+
DV = eig_full(A.val, alg.val.alg)
233+
DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV, alg.val.trunc)
234+
dDVtrunc = make_zero(DVtrunc)
235+
cache = (DV, DVtrunc, dDVtrunc, ind)
236+
return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache)
237+
end
238+
function EnzymeRules.reverse(
239+
config::EnzymeRules.RevConfigWidth{1},
240+
func::Const{typeof(eig_trunc_no_error)},
241+
::Type{RT},
242+
cache,
243+
A::Annotation{<:AbstractTensorMap},
244+
alg::Const,
245+
) where {RT}
246+
DV, DVtrunc, dDVtrunc, ind = cache
247+
MatrixAlgebraKit.eig_pullback!(A.dval, A.val, DV, dDVtrunc, ind)
248+
return (nothing, nothing)
249+
end
250+
251+
function EnzymeRules.augmented_primal(
252+
config::EnzymeRules.RevConfigWidth{1},
253+
func::Const{typeof(eigh_trunc_no_error)},
254+
::Type{RT},
255+
A::Annotation{<:AbstractTensorMap},
256+
alg::Const,
257+
) where {RT}
258+
DV = eigh_full(A.val, alg.val.alg)
259+
DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV, alg.val.trunc)
260+
dDVtrunc = make_zero(DVtrunc)
261+
cache = (DV, DVtrunc, dDVtrunc, ind)
262+
return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache)
263+
end
264+
function EnzymeRules.reverse(
265+
config::EnzymeRules.RevConfigWidth{1},
266+
func::Const{typeof(eigh_trunc_no_error)},
267+
::Type{RT},
268+
cache,
269+
A::Annotation{<:AbstractTensorMap},
270+
alg::Const,
271+
) where {RT}
272+
DV, DVtrunc, dDVtrunc, ind = cache
273+
MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DV, dDVtrunc, ind)
274+
return (nothing, nothing)
275+
end

0 commit comments

Comments
 (0)