|
1 | | -function EnzymeRules.reverse( |
2 | | - config::EnzymeRules.RevConfigWidth{1}, |
3 | | - func::Const{typeof(MatrixAlgebraKit.copy_input)}, |
4 | | - ::Type{RT}, |
5 | | - cache, |
6 | | - f::Annotation, |
7 | | - A::Annotation{<:AbstractTensorMap} |
8 | | - ) where {RT} |
9 | | - copy_shadow = cache |
10 | | - if !isa(A, Const) && !isnothing(copy_shadow) |
11 | | - add!(A.dval, copy_shadow) |
12 | | - end |
13 | | - return (nothing, nothing) |
14 | | -end |
15 | | - |
16 | 1 | # need these due to Enzyme choking on blocks |
| 2 | + |
17 | 3 | for f in (:project_hermitian, :project_antihermitian) |
18 | 4 | f! = Symbol(f, :!) |
19 | 5 | @eval begin |
@@ -87,6 +73,10 @@ for (f, pb) in ( |
87 | 73 | (:eigh_full, :(MatrixAlgebraKit.eigh_pullback!)), |
88 | 74 | (:lq_compact, :(MatrixAlgebraKit.lq_pullback!)), |
89 | 75 | (:qr_compact, :(MatrixAlgebraKit.qr_pullback!)), |
| 76 | + (:lq_full, :(MatrixAlgebraKit.lq_pullback!)), |
| 77 | + (:qr_full, :(MatrixAlgebraKit.qr_pullback!)), |
| 78 | + (:lq_null, :(MatrixAlgebraKit.lq_null_pullback!)), |
| 79 | + (:qr_null, :(MatrixAlgebraKit.qr_null_pullback!)), |
90 | 80 | ) |
91 | 81 | @eval begin |
92 | 82 | function EnzymeRules.augmented_primal( |
@@ -116,6 +106,40 @@ for (f, pb) in ( |
116 | 106 | end |
117 | 107 | end |
118 | 108 |
|
| 109 | +for (f, f_full, pb) in ( |
| 110 | + (:eig_vals, :eig_full, :(MatrixAlgebraKit.eig_vals_pullback!)), |
| 111 | + (:eigh_vals, :eigh_full, :(MatrixAlgebraKit.eigh_vals_pullback!)), |
| 112 | + ) |
| 113 | + @eval begin |
| 114 | + function EnzymeRules.augmented_primal( |
| 115 | + config::EnzymeRules.RevConfigWidth{1}, |
| 116 | + func::Const{typeof($f)}, |
| 117 | + ::Type{RT}, |
| 118 | + A::Annotation{<:AbstractTensorMap}, |
| 119 | + alg::Const, |
| 120 | + ) where {RT} |
| 121 | + ret_full = $f_full(A.val, alg.val) |
| 122 | + ret = diagview(ret_full[1]) |
| 123 | + primal = EnzymeRules.needs_primal(config) ? ret : nothing |
| 124 | + shadow = EnzymeRules.needs_shadow(config) ? make_zero(ret) : nothing |
| 125 | + cache = (ret, shadow, ret_full[2]) |
| 126 | + return EnzymeRules.AugmentedReturn(primal, shadow, cache) |
| 127 | + end |
| 128 | + function EnzymeRules.reverse( |
| 129 | + config::EnzymeRules.RevConfigWidth{1}, |
| 130 | + func::Const{typeof($f)}, |
| 131 | + ::Type{RT}, |
| 132 | + cache, |
| 133 | + A::Annotation{<:AbstractTensorMap}, |
| 134 | + alg::Const, |
| 135 | + ) where {RT} |
| 136 | + D, dD, V = cache |
| 137 | + !isa(A, Const) && $pb(A.dval, A.val, (DiagonalTensorMap(D), V), dD) |
| 138 | + return (nothing, nothing) |
| 139 | + end |
| 140 | + end |
| 141 | +end |
| 142 | + |
119 | 143 | for f in (:svd_compact, :svd_full) |
120 | 144 | @eval begin |
121 | 145 | function EnzymeRules.augmented_primal( |
@@ -172,32 +196,80 @@ for f in (:svd_compact, :svd_full) |
172 | 196 | end=# #hmmmm |
173 | 197 | end |
174 | 198 |
|
175 | | -# TODO |
176 | | -#= |
177 | 199 | function EnzymeRules.augmented_primal( |
178 | 200 | config::EnzymeRules.RevConfigWidth{1}, |
179 | | - func::Const{typeof(svd_trunc)}, |
| 201 | + func::Const{typeof(svd_trunc_no_error)}, |
180 | 202 | ::Type{RT}, |
181 | 203 | A::Annotation{<:AbstractTensorMap}, |
182 | 204 | alg::Const, |
183 | 205 | ) where {RT} |
184 | | -
|
185 | 206 | USVᴴ = svd_compact(A.val, alg.val.alg) |
186 | 207 | USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc) |
187 | | - ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) |
188 | 208 | dUSVᴴtrunc = make_zero(USVᴴtrunc) |
189 | | - cache = (USVᴴtrunc, dUSVᴴtrunc) |
190 | | - return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache) |
| 209 | + cache = (USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind) |
| 210 | + return EnzymeRules.AugmentedReturn(USVᴴtrunc, dUSVᴴtrunc, cache) |
191 | 211 | end |
192 | 212 | function EnzymeRules.reverse( |
193 | | - config::EnzymeRules.RevConfigWidth{1}, |
194 | | - func::Const{typeof(svd_trunc)}, |
195 | | - ::Type{RT}, |
196 | | - cache, |
197 | | - A::Annotation{<:AbstractTensorMap}, |
198 | | - alg::Const, |
199 | | -) where {RT} |
200 | | - USVᴴ, dUSVᴴ = cache |
201 | | - MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴ) |
202 | | - return (nothing, nothing) |
203 | | -end=# |
| 213 | + config::EnzymeRules.RevConfigWidth{1}, |
| 214 | + func::Const{typeof(svd_trunc_no_error)}, |
| 215 | + ::Type{RT}, |
| 216 | + cache, |
| 217 | + A::Annotation{<:AbstractTensorMap}, |
| 218 | + alg::Const, |
| 219 | + ) where {RT} |
| 220 | + USVᴴ, USVᴴtrunc, dUSVᴴtrunc, ind = cache |
| 221 | + MatrixAlgebraKit.svd_pullback!(A.dval, A.val, USVᴴ, dUSVᴴtrunc, ind) |
| 222 | + return (nothing, nothing) |
| 223 | +end |
| 224 | + |
| 225 | +function EnzymeRules.augmented_primal( |
| 226 | + config::EnzymeRules.RevConfigWidth{1}, |
| 227 | + func::Const{typeof(eig_trunc_no_error)}, |
| 228 | + ::Type{RT}, |
| 229 | + A::Annotation{<:AbstractTensorMap}, |
| 230 | + alg::Const, |
| 231 | + ) where {RT} |
| 232 | + DV = eig_full(A.val, alg.val.alg) |
| 233 | + DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV, alg.val.trunc) |
| 234 | + dDVtrunc = make_zero(DVtrunc) |
| 235 | + cache = (DV, DVtrunc, dDVtrunc, ind) |
| 236 | + return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache) |
| 237 | +end |
| 238 | +function EnzymeRules.reverse( |
| 239 | + config::EnzymeRules.RevConfigWidth{1}, |
| 240 | + func::Const{typeof(eig_trunc_no_error)}, |
| 241 | + ::Type{RT}, |
| 242 | + cache, |
| 243 | + A::Annotation{<:AbstractTensorMap}, |
| 244 | + alg::Const, |
| 245 | + ) where {RT} |
| 246 | + DV, DVtrunc, dDVtrunc, ind = cache |
| 247 | + MatrixAlgebraKit.eig_pullback!(A.dval, A.val, DV, dDVtrunc, ind) |
| 248 | + return (nothing, nothing) |
| 249 | +end |
| 250 | + |
| 251 | +function EnzymeRules.augmented_primal( |
| 252 | + config::EnzymeRules.RevConfigWidth{1}, |
| 253 | + func::Const{typeof(eigh_trunc_no_error)}, |
| 254 | + ::Type{RT}, |
| 255 | + A::Annotation{<:AbstractTensorMap}, |
| 256 | + alg::Const, |
| 257 | + ) where {RT} |
| 258 | + DV = eigh_full(A.val, alg.val.alg) |
| 259 | + DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV, alg.val.trunc) |
| 260 | + dDVtrunc = make_zero(DVtrunc) |
| 261 | + cache = (DV, DVtrunc, dDVtrunc, ind) |
| 262 | + return EnzymeRules.AugmentedReturn(DVtrunc, dDVtrunc, cache) |
| 263 | +end |
| 264 | +function EnzymeRules.reverse( |
| 265 | + config::EnzymeRules.RevConfigWidth{1}, |
| 266 | + func::Const{typeof(eigh_trunc_no_error)}, |
| 267 | + ::Type{RT}, |
| 268 | + cache, |
| 269 | + A::Annotation{<:AbstractTensorMap}, |
| 270 | + alg::Const, |
| 271 | + ) where {RT} |
| 272 | + DV, DVtrunc, dDVtrunc, ind = cache |
| 273 | + MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, DV, dDVtrunc, ind) |
| 274 | + return (nothing, nothing) |
| 275 | +end |
0 commit comments