-
Notifications
You must be signed in to change notification settings - Fork 60
Expand file tree
/
Copy pathTensorKitChainRulesCoreExt.jl
More file actions
790 lines (698 loc) · 28.1 KB
/
TensorKitChainRulesCoreExt.jl
File metadata and controls
790 lines (698 loc) · 28.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
module TensorKitChainRulesCoreExt
using TensorOperations
using TensorOperations: Backend, promote_contract
using TensorKit
using TensorKit: planaradd!, planarcontract!, planarcontract, _canonicalize
using VectorInterface
using ChainRulesCore
using LinearAlgebra
using TupleTools
using TupleTools: getindices
# Utility
# -------
# Flip a conjugation flag: `:C` maps to `:N`, every other symbol maps to `:C`.
function _conj(conjA::Symbol)
    if conjA == :C
        return :N
    else
        return :C
    end
end
# Return the identity index tuple `(1, 2, …, N)`.
function trivtuple(N)
    return ntuple(i -> i, N)
end
# Identity index tuple spanning all `N₁ + N₂` entries of an `Index2Tuple`.
function trivtuple(::Index2Tuple{N₁,N₂}) where {N₁,N₂}
    return trivtuple(N₁ + N₂)
end
# Split the index tuple `p` into its first `N₁` entries and the remaining entries.
# Throws an `ArgumentError` when `p` holds fewer than `N₁` entries.
function _repartition(p::IndexTuple, N₁::Int)
    if length(p) < N₁
        throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
    end
    head = p[1:N₁]
    tail = p[(N₁ + 1):end]
    return head, tail
end

# An `Index2Tuple` is first flattened with `linearize` and then split.
function _repartition(p::Index2Tuple, N₁::Int)
    flat = linearize(p)
    return _repartition(flat, N₁)
end

# The split position is taken from the first type parameter of an `Index2Tuple`.
_repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} =
    _repartition(p, N₁)

# The split position is taken from the second type parameter of the tensor map type.
_repartition(p::Union{IndexTuple,Index2Tuple},
             ::AbstractTensorMap{<:Any,N₁}) where {N₁} = _repartition(p, N₁)
# Taking a block of a zero-like tangent yields that same tangent. The method accepts any
# `AbstractZero` (both `ZeroTangent` and `NoTangent`), matching the `isa AbstractZero`
# checks that the factorization pullbacks in this file perform on the extracted blocks.
TensorKit.block(t::AbstractZero, c::Sector) = t
# Constructors
# ------------
# These TensorKit constructors are declared non-differentiable through ChainRulesCore's
# `@non_differentiable` macro, so no custom pullback is written for them.
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom)
@non_differentiable TensorKit.isomorphism(args...)
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)
# Reverse rule for building a `TensorMap` from a dense array `d` plus further constructor
# arguments `args`.
#
# Returns the constructed tensor and a pullback. The pullback converts the output
# cotangent back to an `Array` as the cotangent of `d`; the constructor itself and every
# entry of `args` receive `NoTangent()`.
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...)
    function TensorMap_pullback(Δt)
        # the incoming cotangent may be a thunk, so materialize it before converting
        ∂d = convert(Array, unthunk(Δt))
        # tuple-valued `map` (same pattern as the planar rules) instead of splatting a
        # freshly allocated vector
        ∂args = map(_ -> NoTangent(), args)
        return NoTangent(), ∂d, ∂args...
    end
    return TensorMap(d, args...), TensorMap_pullback
end
# Reverse rule for `convert(T, t)` with `T <: Array` and `t` a tensor map.
#
# Returns the converted array and a pullback. The pullback rebuilds a `TensorMap` from the
# array cotangent on `codomain(t)` and `domain(t)`; the function and the type argument
# receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(convert), T::Type{<:Array}, t::AbstractTensorMap)
    A = convert(T, t)
    function convert_pullback(ΔA)
        # the incoming cotangent may be a thunk, so materialize it before wrapping
        ∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t))
        return NoTangent(), NoTangent(), ∂t
    end
    return A, convert_pullback
end
# Reverse rule for `copy(t)`: the cotangent is passed through unchanged.
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
    duplicate = copy(t)
    function copy_pullback(Δt)
        return NoTangent(), Δt
    end
    return duplicate, copy_pullback
end
# A projector for tensor maps only records the tensor map type `T`.
function ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap}
    return ProjectTo{T}()
end

# Project `x` onto the tensor map type `T1`. When the types already coincide `x` is returned
# as is; otherwise a tensor similar to `x` with `scalartype(T1)` is filled block by block
# using the array projector of each destination block.
function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{S,N1,N2},
                                         T2<:AbstractTensorMap{S,N1,N2}}
    if T1 === T2
        return x
    end
    projected = similar(x, scalartype(T1))
    for (sector, dst) in blocks(projected)
        project_block = ProjectTo(dst)
        dst .= project_block(block(x, sector))
    end
    return projected
end
# Base Linear Algebra
# -------------------
# Reverse rule for `a + b`: both summands receive the output cotangent.
function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap)
    total = a + b
    function plus_pullback(Δc)
        return NoTangent(), Δc, Δc
    end
    return total, plus_pullback
end
# Reverse rule for unary minus: the cotangent is negated.
function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap)
    negate_pullback = Δc -> (NoTangent(), -Δc)
    return -a, negate_pullback
end
# Reverse rule for `a - b`: `a` receives the cotangent, `b` its negation.
function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap)
    difference = a - b
    function minus_pullback(Δc)
        return NoTangent(), Δc, -Δc
    end
    return difference, minus_pullback
end
# Reverse rule for the tensor-map product `a * b`; both cotangents are lazy thunks.
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap)
    product = a * b
    function times_pullback(Δc)
        ∂a = @thunk(Δc * b')
        ∂b = @thunk(a' * Δc)
        return NoTangent(), ∂a, ∂b
    end
    return product, times_pullback
end
# Reverse rule for scaling a tensor map from the right by a number; both cotangents are
# lazy thunks, the scalar one being the inner product `dot(a, Δc)`.
function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number)
    scaled = a * b
    function times_pullback(Δc)
        ∂a = @thunk(Δc * b')
        ∂b = @thunk(dot(a, Δc))
        return NoTangent(), ∂a, ∂b
    end
    return scaled, times_pullback
end
# Reverse rule for scaling a tensor map from the left by a number; both cotangents are
# lazy thunks, the scalar one being the inner product `dot(b, Δc)`.
function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap)
    scaled = a * b
    function times_pullback(Δc)
        ∂a = @thunk(dot(b, Δc))
        ∂b = @thunk(a' * Δc)
        return NoTangent(), ∂a, ∂b
    end
    return scaled, times_pullback
end
# Reverse rule for the tensor product `C = A ⊗ B`.
# Returns `C` and a pullback producing lazily evaluated (thunked) cotangents for `A` and
# `B`, each obtained by a planar contraction of the output cotangent with the conjugate of
# the other factor and then projected with `ProjectTo` of the corresponding input.
function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap)
C = A ⊗ B
projectA = ProjectTo(A)
projectB = ProjectTo(B)
function otimes_pullback(ΔC_)
ΔC = unthunk(ΔC_)
# Index partition of ΔC: the first group collects the positions built from A's codomain
# indices and A's domain indices shifted by numout(B); the second group collects B's
# codomain indices shifted by numout(A) and B's domain indices shifted by
# numin(A) + numout(A).
pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...),
((codomainind(B) .+ numout(A))...,
(domainind(B) .+ (numin(A) + numout(A)))...))
# Cotangent of A: contract ΔC (flag :N) with B (flag :C) over all indices of B, writing
# into a zero tensor shaped like A with the promoted scalar type.
dA_ = @thunk begin
ipA = (codomainind(A), domainind(A))
pB = (allind(B), ())
dA = zerovector(A,
promote_contract(scalartype(ΔC), scalartype(B)))
dA = planarcontract!(dA, ΔC, pΔC, :N, B, pB, :C, ipA, One(), Zero())
return projectA(dA)
end
# Cotangent of B: contract A (flag :C) with ΔC (flag :N) over all indices of A, writing
# into a zero tensor shaped like B with the promoted scalar type.
dB_ = @thunk begin
ipB = (codomainind(B), domainind(B))
pA = ((), allind(A))
dB = zerovector(B,
promote_contract(scalartype(ΔC), scalartype(A)))
dB = planarcontract!(dB, A, pA, :C, ΔC, pΔC, :N, ipB, One(), Zero())
return projectB(dB)
end
return NoTangent(), dA_, dB_
end
return C, otimes_pullback
end
# Reverse rule for `permute(tsrc, p)`. The pullback applies the inverse permutation,
# canonicalized against `tsrc`, to the output cotangent.
# The `copy` keyword is accepted, but both the forward call and the pullback always pass
# `copy=true`.
function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple;
                              copy::Bool=false)
    tdst = permute(tsrc, p; copy=true)
    function permute_pullback(Δtdst)
        backperm = _canonicalize(TupleTools.invperm(linearize(p)), tsrc)
        ∂tsrc = permute(unthunk(Δtdst), backperm; copy=true)
        return NoTangent(), ∂tsrc, NoTangent()
    end
    return tdst, permute_pullback
end
# LinearAlgebra
# -------------
# Reverse rule for `tr(A)`: the cotangent of `A` is the scalar cotangent times the
# identity map on `domain(A)`.
function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap)
    function tr_pullback(Δtr)
        ∂A = Δtr * id(domain(A))
        return NoTangent(), ∂A
    end
    return tr(A), tr_pullback
end
# Reverse rule for `adjoint(A)`: the cotangent of `A` is the adjoint of the (unthunked)
# output cotangent.
function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap)
    function adjoint_pullback(Δadjoint)
        ∂A = adjoint(unthunk(Δadjoint))
        return NoTangent(), ∂A
    end
    return adjoint(A), adjoint_pullback
end
# Reverse rule for the inner product `dot(a, b)`; both cotangents are lazy thunks.
function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap)
    inner = dot(a, b)
    function dot_pullback(Δd)
        ∂a = @thunk(b * Δd')
        ∂b = @thunk(a * Δd)
        return NoTangent(), ∂a, ∂b
    end
    return inner, dot_pullback
end
# Reverse rule for `norm(a, p)`. Only `p == 2` is supported; any other value raises an
# error. The order `p` itself receives `NoTangent()`.
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
    if p != 2
        error("currently only implemented for p = 2")
    end
    n = norm(a, p)
    function norm_pullback(Δn)
        ∂a = a * (Δn' + Δn) / (n * 2)
        return NoTangent(), ∂a, NoTangent()
    end
    return n, norm_pullback
end
# Factorizations
# --------------
# Reverse rule for `TensorKit.tsvd!`.
# The forward pass always computes the SVD without truncation and, when a truncation
# scheme is requested and `t` has at least one block sector, truncates afterwards. The
# untruncated `U`, `Σ`, `V` are captured by the pullback, while the (possibly truncated)
# `U′`, `Σ′`, `V′` together with `ϵ` are returned as the primal output.
function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap;
trunc::TensorKit.TruncationScheme=TensorKit.NoTruncation(),
p::Real=2,
alg::Union{TensorKit.SVD,TensorKit.SDD}=TensorKit.SDD())
U, Σ, V, ϵ = tsvd(t; trunc=TensorKit.NoTruncation(), p=p, alg=alg)
if !(trunc isa TensorKit.NoTruncation) && !isempty(blocksectors(t))
# collect the per-sector singular values (diagonals of Σ) and their counts
Σddata = TensorKit.SectorDict(c => diag(b) for (c, b) in blocks(Σ))
dims = TensorKit.SectorDict(c => length(b) for (c, b) in Σddata)
Σddata, ϵ = TensorKit._truncate!(Σddata, trunc, p)
Udata = TensorKit.SectorDict(c => b for (c, b) in blocks(U))
Vdata = TensorKit.SectorDict(c => b for (c, b) in blocks(V))
Udata′, Σddata′, Vdata′, dims′ = TensorKit._implement_svdtruncation!(t,
Udata,
Σddata,
Vdata,
dims)
W = spacetype(t)(dims′)
# reuse the existing space object when the truncated space is isomorphic to it
if W ≅ domain(Σ)
W = domain(Σ)
end
U′, Σ′, V′ = TensorKit._create_svdtensors(t, Udata′, Σddata′, Vdata′, W)
else
U′, Σ′, V′ = U, Σ, V
end
# Pullback: per block, pass the full factors and the cotangent blocks to
# `svd_pullback!`; only the diagonals of Σ and ΔΣ are used. `Δϵ` is not used.
function tsvd!_pullback((ΔU, ΔΣ, ΔV, Δϵ))
Δt = similar(t)
for (c, b) in blocks(Δt)
Uc, Σc, Vc = block(U, c), block(Σ, c), block(V, c)
ΔUc, ΔΣc, ΔVc = block(ΔU, c), block(ΔΣ, c), block(ΔV, c)
Σdc = view(Σc, diagind(Σc))
ΔΣdc = (ΔΣc isa AbstractZero) ? ΔΣc : view(ΔΣc, diagind(ΔΣc))
svd_pullback!(b, Uc, Σdc, Vc, ΔUc, ΔΣdc, ΔVc)
end
return NoTangent(), Δt
end
# NOTE(review): this shortcut matches a 3-tuple of `ZeroTangent`s, while the rule returns
# four outputs and the method above destructures four cotangents — confirm whether a
# 4-element all-zero cotangent is meant to be caught here.
function tsvd!_pullback(::Tuple{ZeroTangent,ZeroTangent,ZeroTangent})
return NoTangent(), ZeroTangent()
end
return (U′, Σ′, V′, ϵ), tsvd!_pullback
end
# Reverse rule for `TensorKit.eig!`. The forward pass calls `eig(t; kwargs...)`; the
# pullback fills a tensor similar to `t` block by block through `eig_pullback!`.
function ChainRulesCore.rrule(::typeof(TensorKit.eig!), t::AbstractTensorMap; kwargs...)
    D, V = eig(t; kwargs...)

    function eig!_pullback(ΔDV)
        ΔD, ΔV = ΔDV
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Dblock = block(D, sector)
            Vblock = block(V, sector)
            ΔDblock = block(ΔD, sector)
            ΔVblock = block(ΔV, sector)
            # only the diagonals of D and ΔD enter the matrix-level pullback
            Ddiag = view(Dblock, diagind(Dblock))
            ΔDdiag = if ΔDblock isa AbstractZero
                ΔDblock
            else
                view(ΔDblock, diagind(ΔDblock))
            end
            eig_pullback!(Δblock, Ddiag, Vblock, ΔDdiag, ΔVblock)
        end
        return NoTangent(), Δt
    end
    # shortcut when both cotangents are `ZeroTangent`
    eig!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = (NoTangent(), ZeroTangent())

    return (D, V), eig!_pullback
end
# Reverse rule for `TensorKit.eigh!`. The forward pass calls `eigh(t; kwargs...)`; the
# pullback fills a tensor similar to `t` block by block through `eigh_pullback!`.
function ChainRulesCore.rrule(::typeof(TensorKit.eigh!), t::AbstractTensorMap; kwargs...)
    D, V = eigh(t; kwargs...)

    function eigh!_pullback(ΔDV)
        ΔD, ΔV = ΔDV
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Dblock = block(D, sector)
            Vblock = block(V, sector)
            ΔDblock = block(ΔD, sector)
            ΔVblock = block(ΔV, sector)
            # only the diagonals of D and ΔD enter the matrix-level pullback
            Ddiag = view(Dblock, diagind(Dblock))
            ΔDdiag = if ΔDblock isa AbstractZero
                ΔDblock
            else
                view(ΔDblock, diagind(ΔDblock))
            end
            eigh_pullback!(Δblock, Ddiag, Vblock, ΔDdiag, ΔVblock)
        end
        return NoTangent(), Δt
    end
    # shortcut when both cotangents are `ZeroTangent`
    eigh!_pullback(::Tuple{ZeroTangent,ZeroTangent}) = (NoTangent(), ZeroTangent())

    return (D, V), eigh!_pullback
end
# Reverse rule for `leftorth!`. Only the `QR` and `QRpos` algorithms are accepted; any
# other `alg` raises an error. The pullback fills a tensor similar to `t` block by block
# through `qr_pullback!`.
function ChainRulesCore.rrule(::typeof(leftorth!), t::AbstractTensorMap; alg=QRpos())
    if !(alg isa TensorKit.QR || alg isa TensorKit.QRpos)
        error("only `alg=QR()` and `alg=QRpos()` are supported")
    end
    Q, R = leftorth(t; alg)

    function leftorth!_pullback(ΔQR)
        ΔQ, ΔR = ΔQR
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Qblock, Rblock = block(Q, sector), block(R, sector)
            ΔQblock, ΔRblock = block(ΔQ, sector), block(ΔR, sector)
            qr_pullback!(Δblock, Qblock, Rblock, ΔQblock, ΔRblock)
        end
        return NoTangent(), Δt
    end
    # shortcut when both cotangents are `ZeroTangent`
    function leftorth!_pullback(::Tuple{ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent()
    end

    return (Q, R), leftorth!_pullback
end
# Reverse rule for `rightorth!`. Only the `LQ` and `LQpos` algorithms are accepted; any
# other `alg` raises an error. The pullback fills a tensor similar to `t` block by block
# through `lq_pullback!`.
function ChainRulesCore.rrule(::typeof(rightorth!), t::AbstractTensorMap; alg=LQpos())
    if !(alg isa TensorKit.LQ || alg isa TensorKit.LQpos)
        error("only `alg=LQ()` and `alg=LQpos()` are supported")
    end
    L, Q = rightorth(t; alg)

    function rightorth!_pullback(ΔLQ)
        ΔL, ΔQ = ΔLQ
        Δt = similar(t)
        for (sector, Δblock) in blocks(Δt)
            Lblock, Qblock = block(L, sector), block(Q, sector)
            ΔLblock, ΔQblock = block(ΔL, sector), block(ΔQ, sector)
            lq_pullback!(Δblock, Lblock, Qblock, ΔLblock, ΔQblock)
        end
        return NoTangent(), Δt
    end
    # shortcut when both cotangents are `ZeroTangent`
    function rightorth!_pullback(::Tuple{ZeroTangent,ZeroTangent})
        return NoTangent(), ZeroTangent()
    end

    return (L, Q), rightorth!_pullback
end
# Corresponding matrix factorisations: implemented as mutating methods
# ---------------------------------------------------------------------
# helper routines
# Inverse of `a`, except that values with `abs(a) < tol` are mapped to `zero(a)`.
function safe_inv(a, tol)
    abs(a) < tol && return zero(a)
    return inv(a)
end
# Return the linear (column-major) indices of the strictly lower-triangular entries
# `A[i, j]` with `i > j` of the `m × n` matrix `A`, ordered column by column.
#
# The result has `k * m - k * (k + 1) / 2` entries with `k = min(m, n)`, which is valid
# for square, wide and tall matrices alike (the earlier length formula was only correct
# for square input).
function lowertriangularind(A::AbstractMatrix)
    m, n = size(A)
    k = min(m, n)
    I = Vector{Int}(undef, k * m - div(k * (k + 1), 2))
    offset = 0
    for j in 1:k
        r = (j + 1):m                          # rows strictly below the diagonal
        I[offset .- j .+ r] = (j - 1) * m .+ r # linear indices within column j
        offset += length(r)
    end
    return I
end
# Return the linear (column-major) indices of the strictly upper-triangular entries
# `A[i, j]` with `i < j` of the `m × n` matrix `A`, ordered row by row. For a square
# matrix the k-th entry is therefore the transposed position of the k-th entry of
# `lowertriangularind`.
#
# The result has `k * n - k * (k + 1) / 2` entries with `k = min(m, n)`, which is valid
# for square, wide and tall matrices alike (the earlier length formula failed for
# `m > n`).
function uppertriangularind(A::AbstractMatrix)
    m, n = size(A)
    k = min(m, n)
    I = Vector{Int}(undef, k * n - div(k * (k + 1), 2))
    offset = 0
    for i in 1:k
        r = (i + 1):n                            # columns strictly right of the diagonal
        I[offset .- i .+ r] = i .+ m .* (r .- 1) # linear indices within row i
        offset += length(r)
    end
    return I
end
# SVD_pullback: pullback implementation for general (possibly truncated) SVD
#
# Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as
# cotangent ΔU, ΔS, ΔVd variables of truncated SVD
#
# Checks whether the cotangent variables are such that they would couple to gauge-dependent
# degrees of freedom (phases of singular vectors), and prints a warning if this is the case
#
# An implementation that only uses U, S, and Vd from truncated SVD is also possible, but
# requires solving a Sylvester equation, which does not seem to be supported on GPUs.
#
# Other implementation considerations for GPU compatibility:
# no scalar indexing, lots of broadcasting and views
#
# Matrix-level SVD pullback; overwrites and returns `ΔA`.
#   U, S, Vd     : factors of the full SVD (`S` is a vector of singular values)
#   ΔU, ΔS, ΔVd  : cotangents; each may be an `AbstractZero`
#   atol, rtol   : absolute / relative tolerance; when `atol > 0` it is used directly,
#                  otherwise the tolerance is `rtol` times the first singular value
# Throws `DimensionMismatch` when the sizes of the factors or cotangents are inconsistent.
function svd_pullback!(ΔA::AbstractMatrix, U::AbstractMatrix, S::AbstractVector,
Vd::AbstractMatrix, ΔU, ΔS, ΔVd;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(eltype(S))^(3 / 4))
# Basic size checks and determination
m, n = size(U, 1), size(Vd, 2)
size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
# p: number of columns/rows/entries of the cotangents, taken from whichever of ΔU, ΔVd,
# ΔS is not an AbstractZero; it stays -1 when all three are zero
p = -1
if !(ΔU isa AbstractZero)
m == size(ΔU, 1) || throw(DimensionMismatch())
p = size(ΔU, 2)
end
if !(ΔVd isa AbstractZero)
n == size(ΔVd, 2) || throw(DimensionMismatch())
if p == -1
p = size(ΔVd, 1)
else
p == size(ΔVd, 1) || throw(DimensionMismatch())
end
end
if !(ΔS isa AbstractZero)
if p == -1
p = length(ΔS)
else
p == length(ΔS) || throw(DimensionMismatch())
end
end
# leading p singular vectors / values
Up = view(U, :, 1:p)
Vp = view(Vd, 1:p, :)'
Sp = view(S, 1:p)
# tolerance and rank
tol = atol > 0 ? atol : rtol * S[1, 1]
# r: position of the last singular value that is >= tol
r = findlast(>=(tol), S)
# compute antihermitian part of projection of ΔU and ΔV onto U and V
# also already subtract this projection from ΔU and ΔV
if !(ΔU isa AbstractZero)
UΔU = Up' * ΔU
aUΔU = rmul!(UΔU - UΔU', 1 / 2)
if m > p
ΔU -= Up * UΔU
end
else
aUΔU = fill!(similar(U, (p, p)), 0)
end
if !(ΔVd isa AbstractZero)
VΔV = Vp' * ΔVd'
aVΔV = rmul!(VΔV - VΔV', 1 / 2)
if n > p
ΔVd -= VΔV' * Vp'
end
else
aVΔV = fill!(similar(Vd, (p, p)), 0)
end
# check whether cotangents arise from a gauge-invariant objective function
# (mask selects pairs of singular values that differ by less than tol)
mask = abs.(Sp' .- Sp) .< tol
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
if p > r
rprange = (r + 1):p
Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
end
Δgauge < tol ||
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
# safe_inv zeroes the inverse of differences / sums smaller than tol
UdΔAV = (aUΔU .+ aVΔV) .* safe_inv.(Sp' .- Sp, tol) .+
(aUΔU .- aVΔV) .* safe_inv.(Sp' .+ Sp, tol)
# NOTE(review): this test uses `ZeroTangent` whereas the checks above use
# `AbstractZero` — confirm whether other zero tangents can reach this point.
if !(ΔS isa ZeroTangent)
UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
# in principle, ΔS is real, but maybe not if coming from an anyonic tensor
end
mul!(ΔA, Up, UdΔAV * Vp')
if r > p # contribution from truncation
# singular vectors / values p+1:r, i.e. those above tol that are not among the first p
Ur = view(U, :, (p + 1):r)
Vr = view(Vd, (p + 1):r, :)'
Sr = view(S, (p + 1):r)
if !(ΔU isa AbstractZero)
UrΔU = Ur' * ΔU
if m > r
ΔU -= Ur * UrΔU # subtract this part from ΔU
end
else
UrΔU = fill!(similar(U, (r - p, p)), 0)
end
if !(ΔVd isa AbstractZero)
VrΔV = Vr' * ΔVd'
if n > r
ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
end
else
VrΔV = fill!(similar(Vd, (r - p, p)), 0)
end
X = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .+
(UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
Y = (1 // 2) .* ((UrΔU .+ VrΔV) .* safe_inv.(Sp' .- Sr, tol) .-
(UrΔU .- VrΔV) .* safe_inv.(Sp' .+ Sr, tol))
# ΔA += Ur * X * Vp' + Up * Y' * Vr'
mul!(ΔA, Ur, X * Vp', 1, 1)
mul!(ΔA, Up * Y', Vr', 1, 1)
end
if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
# ΔA += (ΔU .* safe_inv.(Sp', tol)) * Vp'
mul!(ΔA, ΔU .* safe_inv.(Sp', tol), Vp', 1, 1)
end
if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
# ΔA += U * (safe_inv.(Sp, tol) .* ΔVd)
mul!(ΔA, Up, safe_inv.(Sp, tol) .* ΔVd, 1, 1)
end
return ΔA
end
# Matrix-level pullback for the eigendecomposition; overwrites and returns `ΔA`.
#   D       : vector of eigenvalues
#   V       : square matrix of eigenvectors (checked with `checksquare`)
#   ΔD, ΔV  : cotangents; each may be an `AbstractZero`
#   atol, rtol : when `atol > 0` it is the tolerance, otherwise the tolerance is `rtol`
#                times the largest eigenvalue magnitude
# Throws `DimensionMismatch` when `length(D)` differs from the size of `V`.
function eig_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))
# Basic size checks and determination
n = LinearAlgebra.checksquare(V)
n == length(D) || throw(DimensionMismatch())
# tolerance and rank
tol = atol > 0 ? atol : rtol * maximum(abs, D)
if !(ΔV isa AbstractZero)
VdΔV = V' * ΔV
# entries belonging to eigenvalue pairs closer than tol are inspected for the gauge
# warning below
mask = abs.(transpose(D) .- D) .< tol
Δgauge = norm(view(VdΔV, mask), Inf)
Δgauge < tol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
# safe_inv zeroes the inverse of eigenvalue differences smaller than tol
VdΔV .*= conj.(safe_inv.(transpose(D) .- D, tol))
if !(ΔD isa AbstractZero)
view(VdΔV, diagind(VdΔV)) .+= ΔD
end
PΔV = V' \ VdΔV
# for a real output buffer the product is formed first and only its real part is stored
if eltype(ΔA) <: Real
ΔAc = mul!(VdΔV, PΔV, V') # recycle VdΔV memory
ΔA .= real.(ΔAc)
else
mul!(ΔA, PΔV, V')
end
else
# only the eigenvalue cotangent contributes
PΔV = V' \ Diagonal(ΔD)
if eltype(ΔA) <: Real
ΔAc = PΔV * V'
ΔA .= real.(ΔAc)
else
mul!(ΔA, PΔV, V')
end
end
return ΔA
end
# Matrix-level pullback for the Hermitian eigendecomposition; overwrites and returns `ΔA`.
#   D       : vector of eigenvalues
#   V       : square matrix of eigenvectors (checked with `checksquare`)
#   ΔD, ΔV  : cotangents; `ΔV` may be an `AbstractZero`, and `ΔD` may be one when `ΔV`
#             is not
#   atol, rtol : when `atol > 0` it is the tolerance, otherwise the tolerance is `rtol`
#                times the largest eigenvalue magnitude
# Throws `DimensionMismatch` when `length(D)` differs from the size of `V`.
function eigh_pullback!(ΔA::AbstractMatrix, D::AbstractVector, V::AbstractMatrix, ΔD, ΔV;
atol::Real=0,
rtol::Real=atol > 0 ? 0 : eps(real(eltype(D)))^(3 / 4))
# Basic size checks and determination
n = LinearAlgebra.checksquare(V)
n == length(D) || throw(DimensionMismatch())
# tolerance and rank
tol = atol > 0 ? atol : rtol * maximum(abs, D)
if !(ΔV isa AbstractZero)
VdΔV = V' * ΔV
# antihermitian part of V' * ΔV
aVdΔV = rmul!(VdΔV - VdΔV', 1 / 2)
# entries belonging to eigenvalue pairs closer than tol are inspected for the gauge
# warning below
mask = abs.(D' .- D) .< tol
Δgauge = norm(view(aVdΔV, mask))
Δgauge < tol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
# safe_inv zeroes the inverse of eigenvalue differences smaller than tol
aVdΔV .*= safe_inv.(D' .- D, tol)
if !(ΔD isa AbstractZero)
view(aVdΔV, diagind(aVdΔV)) .+= real.(ΔD)
# in principle, ΔD is real, but maybe not if coming from an anyonic tensor
end
# recycle VdΔV space
mul!(ΔA, mul!(VdΔV, V, aVdΔV), V')
else
# only the eigenvalue cotangent contributes
mul!(ΔA, V * Diagonal(ΔD), V')
end
return ΔA
end
# Matrix-level pullback for the QR decomposition; overwrites and returns `ΔA`.
#   Q, R       : factors of the decomposition
#   ΔQ, ΔR     : cotangents; each may be an `AbstractZero`
#   atol, rtol : when `atol > 0` it is the tolerance, otherwise the tolerance is `rtol`
#                times the largest magnitude on the diagonal of `R`
# The numerical rank `p` is the position of the last diagonal entry of `R` whose magnitude
# is at least the tolerance (0 when there is none). A warning is emitted when `R` is rank
# deficient and `ΔQ` has a component that depends on the gauge choice.
function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔQ, ΔR;
                      atol::Real=0,
                      rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4))
    Rd = view(R, diagind(R))
    # `tol` must live in function scope: it is needed again for the gauge check below
    tol = atol > 0 ? atol : rtol * maximum(abs, Rd)
    p = something(findlast(x -> abs(x) >= tol, Rd), 0)
    m, n = size(R)

    Q1 = view(Q, :, 1:p)
    R1 = view(R, 1:p, :)
    R11 = view(R, 1:p, 1:p)

    ΔA1 = view(ΔA, :, 1:p)
    ΔQ1 = view(ΔQ, :, 1:p)
    ΔR1 = view(ΔR, 1:p, :)

    # M = ΔR1 * R1' - Q1' * ΔQ1, skipping the terms whose cotangent is zero; the Boolean
    # `β` of the second `mul!` overwrites `M` when the first product was skipped
    M = similar(R, (p, p))
    ΔR isa AbstractZero || mul!(M, ΔR1, R1')
    ΔQ isa AbstractZero || mul!(M, Q1', ΔQ1, -1, !(ΔR isa AbstractZero))
    # make M Hermitian by copying the conjugated upper triangle into the lower triangle
    view(M, lowertriangularind(M)) .= conj.(view(M, uppertriangularind(M)))
    if eltype(M) <: Complex
        Md = view(M, diagind(M))
        Md .= real.(Md)
    end

    ΔA1 .= ΔQ1
    mul!(ΔA1, Q1, M, +1, 1)

    if n > p
        R12 = view(R, 1:p, (p + 1):n)
        ΔA2 = view(ΔA, :, (p + 1):n)
        ΔR12 = view(ΔR, 1:p, (p + 1):n)

        if ΔR isa AbstractZero
            ΔA2 .= zero(eltype(ΔA))
        else
            mul!(ΔA2, Q1, ΔR12)
            mul!(ΔA1, ΔA2, R12', -1, 1)
        end
    end
    if m > p && !(ΔQ isa AbstractZero) # case where R is not full rank
        Q2 = view(Q, :, (p + 1):m)
        ΔQ2 = view(ΔQ, :, (p + 1):m)
        Q1dΔQ2 = Q1' * ΔQ2
        # part of ΔQ2 outside the range of Q1 couples to gauge-dependent directions
        Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
        Δgauge < tol ||
            @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
        mul!(ΔA1, Q2, Q1dΔQ2', -1, 1)
    end
    rdiv!(ΔA1, UpperTriangular(R11)')
    return ΔA
end
# Matrix-level pullback for the LQ decomposition; overwrites and returns `ΔA`.
#   L, Q       : factors of the decomposition
#   ΔL, ΔQ     : cotangents; each may be an `AbstractZero`
#   atol, rtol : when `atol > 0` it is the tolerance, otherwise the tolerance is `rtol`
#                times the largest magnitude on the diagonal of `L`
# The numerical rank `p` is the position of the last diagonal entry of `L` whose magnitude
# is at least the tolerance (0 when there is none). A warning is emitted when `L` is rank
# deficient and `ΔQ` has a component that depends on the gauge choice.
function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔL, ΔQ;
                      atol::Real=0,
                      rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4))
    Ld = view(L, diagind(L))
    # `tol` must live in function scope: it is needed again for the gauge check below
    tol = atol > 0 ? atol : rtol * maximum(abs, Ld)
    p = something(findlast(x -> abs(x) >= tol, Ld), 0)
    m, n = size(L)

    L1 = view(L, :, 1:p)
    L11 = view(L, 1:p, 1:p)
    Q1 = view(Q, 1:p, :)

    ΔA1 = view(ΔA, 1:p, :)
    ΔQ1 = view(ΔQ, 1:p, :)
    ΔL1 = view(ΔL, :, 1:p)

    # M = L1' * ΔL1 - ΔQ1 * Q1', skipping the terms whose cotangent is zero; the Boolean
    # `β` of the second `mul!` overwrites `M` when the first product was skipped
    M = similar(L, (p, p))
    ΔL isa AbstractZero || mul!(M, L1', ΔL1)
    ΔQ isa AbstractZero || mul!(M, ΔQ1, Q1', -1, !(ΔL isa AbstractZero))
    # make M Hermitian by copying the conjugated lower triangle into the upper triangle
    view(M, uppertriangularind(M)) .= conj.(view(M, lowertriangularind(M)))
    if eltype(M) <: Complex
        Md = view(M, diagind(M))
        Md .= real.(Md)
    end

    ΔA1 .= ΔQ1
    mul!(ΔA1, M, Q1, +1, 1)

    if m > p
        L21 = view(L, (p + 1):m, 1:p)
        ΔA2 = view(ΔA, (p + 1):m, :)
        ΔL21 = view(ΔL, (p + 1):m, 1:p)

        if ΔL isa AbstractZero
            ΔA2 .= zero(eltype(ΔA))
        else
            mul!(ΔA2, ΔL21, Q1)
            mul!(ΔA1, L21', ΔA2, -1, 1)
        end
    end
    if n > p && !(ΔQ isa AbstractZero) # case where L is not full rank
        Q2 = view(Q, (p + 1):n, :)
        ΔQ2 = view(ΔQ, (p + 1):n, :)
        ΔQ2Q1d = ΔQ2 * Q1'
        # part of ΔQ2 outside the row space of Q1 couples to gauge-dependent directions
        Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1d, Q1, -1, 1))
        Δgauge < tol ||
            @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
        mul!(ΔA1, ΔQ2Q1d', Q2, -1, 1)
    end
    ldiv!(LowerTriangular(L11)', ΔA1)
    return ΔA
end
# Planar rrules
# --------------
# Reverse rule for `TensorKit.planaradd!`.
# The forward pass runs `planaradd!` on a copy of `C`, so the caller's `C` is not mutated
# by this rule. The pullback returns thunked cotangents for `C`, `A`, `α` and `β`, and
# `NoTangent()` for the function, `pA`, `conjA` and every backend argument.
function ChainRulesCore.rrule(::typeof(TensorKit.planaradd!),
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol,
α::Number, β::Number,
backend::Backend...)
C′ = planaradd!(copy(C), A, pA, conjA, α, β, backend...)
projectA = ProjectTo(A)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)
function planaradd_pullback(ΔC′)
ΔC = unthunk(ΔC′)
# cotangent of C: ΔC scaled by conj(β)
dC = @thunk projectC(scale(ΔC, conj(β)))
# cotangent of A: ΔC added with the inverse of permutation pA into a zero tensor shaped
# like A; α is conjugated only when conjA == :N
dA = @thunk begin
ip = _canonicalize(invperm(linearize(pA)), A)
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
_dA = planaradd!(_dA, ΔC, ip, conjA, conjA == :N ? conj(α) : α, Zero(),
backend...)
return projectA(_dA)
end
# cotangent of α: full contraction of A (with flipped conjugation flag) with ΔC
dα = @thunk begin
_dα = tensorscalar(planarcontract(A, ((), linearize(pA)), _conj(conjA),
ΔC, (trivtuple(pA), ()), :N,
((), ()), One(), backend...))
return projectα(_dα)
end
# cotangent of β: full contraction of C (flag :C) with ΔC
dβ = @thunk begin
_dβ = tensorscalar(planarcontract(C,
((), trivtuple(TensorOperations.numind(pA))),
:C,
ΔC, (trivtuple(pA), ()), :N,
((), ()), One(), backend...))
return projectβ(_dβ)
end
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dbackend...
end
return C′, planaradd_pullback
end
# Reverse rule for `TensorKit.planarcontract!`.
# The forward pass runs `planarcontract!` on a copy of `C`, so the caller's `C` is not
# mutated by this rule. The pullback returns thunked cotangents for `C`, `A`, `B`, `α` and
# `β`, and `NoTangent()` for the function, the index tuples, the conjugation flags and
# every backend argument.
function ChainRulesCore.rrule(::typeof(TensorKit.planarcontract!),
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple, conjA::Symbol,
B::AbstractTensorMap, pB::Index2Tuple, conjB::Symbol,
pAB::Index2Tuple,
α::Number, β::Number, backend::Backend...)
# indA = (codomainind(A), reverse(domainind(A)))
# indB = (codomainind(B), reverse(domainind(B)))
# pA, pB, pAB = TensorKit.reorder_planar_indices(indA, pA, indB, pB, pAB)
C′ = planarcontract!(copy(C), A, pA, conjA, B, pB, conjB, pAB, α, β, backend...)
projectA = ProjectTo(A)
projectB = ProjectTo(B)
projectC = ProjectTo(C)
projectα = ProjectTo(α)
projectβ = ProjectTo(β)
function planarcontract_pullback(ΔC′)
ΔC = unthunk(ΔC′)
# Inverse of the output permutation, split into the positions coming from the first
# length(pA[1]) entries and the following length(pB[2]) entries.
ipAB = invperm(linearize(pAB))
pΔC = (getindices(ipAB, trivtuple(length(pA[1]))),
getindices(ipAB, length(pA[1]) .+ trivtuple(length(pB[2]))))
# cotangent of C: ΔC scaled by conj(β)
dC = @thunk projectC(scale(ΔC, conj(β)))
# cotangent of A: contract ΔC with B (index groups of pB swapped) into a zero tensor
# shaped like A; conjugation flags and α are chosen depending on conjA
dA = @thunk begin
ipA = _canonicalize(invperm(linearize(pA)), A)
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
_dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α)))
_dA = planarcontract!(_dA, ΔC, pΔC, conjΔC, B, reverse(pB), conjB′, ipA,
conjA == :C ? α : conj(α), Zero(), backend...)
return projectA(_dA)
end
# cotangent of B: contract A (index groups of pA swapped) with ΔC into a zero tensor
# shaped like B; conjugation flags and α are chosen depending on conjB
dB = @thunk begin
ipB = _canonicalize((invperm(linearize(pB)), ()), B)
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
_dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α)))
_dB = planarcontract!(_dB,
A, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
ipB, conjB == :C ? α : conj(α), Zero(), backend...)
return projectB(_dB)
end
# cotangent of α: full contraction of the unscaled product of A and B (flag :C) with ΔC
dα = @thunk begin
_dα = tensorscalar(planarcontract(planarcontract(A, pA, conjA,
B, pB, conjB,
pAB, One(), backend...),
((), trivtuple(TensorOperations.numind(pAB))),
:C,
ΔC,
(trivtuple(TensorOperations.numind(pAB)), ()),
:N,
((), ()), One(), backend...))
return projectα(_dα)
end
# cotangent of β: full contraction of C' with ΔC
# NOTE(review): unlike the other `planarcontract` calls in this rule, this call passes
# no conjugation flags — confirm that this argument list matches an existing method.
dβ = @thunk begin
p′ = TensorKit.adjointtensorindices(C, trivtuple(pAB))
_dβ = tensorscalar(planarcontract(C', ((), p′),
ΔC, (trivtuple(pAB), ()), ((), ()),
One(), backend...))
return projectβ(_dβ)
end
dbackend = map(x -> NoTangent(), backend)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(),
NoTangent(),
dα, dβ, dbackend...
end
return C′, planarcontract_pullback
end
# Reverse rule for `TensorKit.planartrace!`.
# The forward pass runs `planartrace!` on a copy of `C`, so the caller's `C` is not mutated
# by this rule. `planartrace!` is called with its module prefix because it is not part of
# the names imported with `using TensorKit: ...` at the top of this module.
#
# The pullback returns one tangent per argument, as ChainRulesCore requires:
#   - `C` and `β` are handled exactly as in the `planaradd!` rule (scaling by `conj(β)`,
#     and the full contraction of `C` with the output cotangent, respectively);
#   - the cotangents of `A` and `α` are not derived yet and are marked with
#     `@not_implemented`, which raises an informative error only if they are actually used;
#   - the function, `p`, `q`, `conjA` and every backend argument receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(TensorKit.planartrace!),
                              C::AbstractTensorMap,
                              A::AbstractTensorMap,
                              p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
                              α::Number, β::Number, backend::Backend...)
    C′ = TensorKit.planartrace!(copy(C), A, p, q, conjA, α, β, backend...)

    projectC = ProjectTo(C)
    projectβ = ProjectTo(β)

    function planartrace_pullback(ΔC′)
        ΔC = unthunk(ΔC′)
        dC = @thunk projectC(scale(ΔC, conj(β)))
        dA = ChainRulesCore.@not_implemented("the cotangent of `A` in `planartrace!` is not implemented")
        dα = ChainRulesCore.@not_implemented("the cotangent of `α` in `planartrace!` is not implemented")
        dβ = @thunk begin
            _dβ = tensorscalar(planarcontract(C,
                                              ((), trivtuple(TensorOperations.numind(p))),
                                              :C,
                                              ΔC, (trivtuple(p), ()), :N,
                                              ((), ()), One(), backend...))
            return projectβ(_dβ)
        end
        dbackend = map(x -> NoTangent(), backend)
        return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ,
               dbackend...
    end

    return C′, planartrace_pullback
end
# Convert rrules
#----------------
# Reverse rule for `convert(Dict, t)`. The pullback only uses the `:data` entry of the
# cotangent: it is substituted into a copy of the forward dictionary, which is converted
# back to a `TensorMap`. Without a `:data` entry the cotangent of `t` is `zero(t)`.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
    out = convert(Dict, t)

    function convert_pullback(c)
        # :data is the only entry for which a dual makes sense; `zero(t)` is returned
        # instead of `ZeroTangent()`, which would be type unstable
        haskey(c, :data) || return (NoTangent(), NoTangent(), zero(t))
        dual = copy(out)
        dual[:data] = c[:data]
        ∂t = convert(TensorMap, dual)
        return (NoTangent(), NoTangent(), ∂t)
    end

    return out, convert_pullback
end
# Reverse rule for `convert(TensorMap, t)` with `t` a `Dict{Symbol,Any}`: the cotangent is
# converted back to a `Dict`.
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},
                              t::Dict{Symbol,Any})
    function convert_pullback(v)
        return NoTangent(), NoTangent(), convert(Dict, v)
    end
    return convert(TensorMap, t), convert_pullback
end
end