Skip to content

Commit 5060bd7

Browse files
committed
part2b - tests are working?
1 parent 1166f24 commit 5060bd7

6 files changed

Lines changed: 84 additions & 65 deletions

File tree

src/planar/planaroperations.jl

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
11
# planar versions of tensor operations add!, trace! and contract!
2-
function planaradd!(C::AbstractTensorMap{S,N₁,N₂}, p::Index2Tuple{N₁,N₂},
2+
function planaradd!(C::AbstractTensorMap{S,N₁,N₂},
33
A::AbstractTensorMap{S},
4-
α, β, backend::Backend...) where {S,N₁,N₂}
4+
p::Index2Tuple{N₁,N₂},
5+
α,
6+
β,
7+
backend::Backend...) where {S,N₁,N₂}
58
return add_transpose!(C, A, p, α, β, backend...)
69
end
710

8-
function planartrace!(C::AbstractTensorMap{S,N₁,N₂}, p::Index2Tuple{N₁,N₂},
9-
A::AbstractTensorMap{S}, q::Index2Tuple{N₃,N₃},
10-
α, β, backend::Backend...) where {S,N₁,N₂,N₃}
11+
function planartrace!(C::AbstractTensorMap{S,N₁,N₂},
12+
A::AbstractTensorMap{S},
13+
p::Index2Tuple{N₁,N₂},
14+
q::Index2Tuple{N₃,N₃},
15+
α,
16+
β,
17+
backend::Backend...) where {S,N₁,N₂,N₃}
1118
if BraidingStyle(sectortype(S)) == Bosonic()
1219
return trace_permute!(C, A, p, q, α, β, backend...)
1320
end
1421

1522
@boundscheck begin
16-
all(i -> space(A, pC[1][i]) == space(C, i), 1:N₁) ||
23+
all(i -> space(A, p[1][i]) == space(C, i), 1:N₁) ||
1724
throw(SpaceMismatch("trace: A = $(codomain(A))$(domain(A)),
1825
C = $(codomain(C))$(domain(C)), p1 = $(p1), p2 = $(p2)"))
19-
all(i -> space(A, pC[2][i]) == space(C, N₁ + i), 1:N₂) ||
26+
all(i -> space(A, p[2][i]) == space(C, N₁ + i), 1:N₂) ||
2027
throw(SpaceMismatch("trace: A = $(codomain(A))$(domain(A)),
2128
C = $(codomain(C))$(domain(C)), p1 = $(p1), p2 = $(p2)"))
22-
all(i -> space(A, pA[1][i]) == dual(space(A, pA[2][i])), 1:N₃) ||
29+
all(i -> space(A, q[1][i]) == dual(space(A, q[2][i])), 1:N₃) ||
2330
throw(SpaceMismatch("trace: A = $(codomain(A))$(domain(A)),
2431
q1 = $(q1), q2 = $(q2)"))
2532
end
@@ -29,19 +36,24 @@ function planartrace!(C::AbstractTensorMap{S,N₁,N₂}, p::Index2Tuple{N₁,N
2936
elseif !isone(β)
3037
rmul!(C, β)
3138
end
32-
pdata = linearize(pC)
3339
for (f₁, f₂) in fusiontrees(A)
3440
for ((f₁′, f₂′), coeff) in planar_trace(f₁, f₂, p..., q...)
35-
TO.tensortrace!(C[f₁′, f₂′], p, A[f₁, f₂], q, α * coeff, true, backend...)
41+
TO.tensortrace!(C[f₁′, f₂′], p, A[f₁, f₂], q, :N, α * coeff, true, backend...)
3642
end
3743
end
3844
return C
3945
end
4046

4147

42-
function planarcontract!(C::AbstractTensorMap{S,N₁,N₂}, pAB::Index2Tuple{N₁,N₂},
43-
A::AbstractTensorMap{S}, pA::Index2Tuple, B::AbstractTensorMap{S},
44-
pB::Index2Tuple, α, β, backend::Backend...) where {S,N₁,N₂}
48+
function planarcontract!(C::AbstractTensorMap{S,N₁,N₂},
49+
A::AbstractTensorMap{S},
50+
pA::Index2Tuple,
51+
B::AbstractTensorMap{S},
52+
pB::Index2Tuple,
53+
pAB::Index2Tuple{N₁,N₂},
54+
α,
55+
β,
56+
backend::Backend...) where {S,N₁,N₂}
4557
codA, domA = codomainind(A), domainind(A)
4658
codB, domB = codomainind(B), domainind(B)
4759
oindA, cindA = pA

src/planar/preprocessors.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,16 @@ function _decompose_planar_contractions(ex::Expr, temporaries)
356356
lhs, rhs = TO.getlhs(ex), TO.getrhs(ex)
357357
if TO.istensorexpr(rhs)
358358
pre = Vector{Any}()
359-
rhs = _extract_contraction_pairs(rhs, lhs, pre, temporaries)
360-
return Expr(:block, pre..., Expr(ex.head, lhs, rhs))
359+
if TO.istensor(lhs)
360+
rhs = _extract_contraction_pairs(rhs, lhs, pre, temporaries)
361+
return Expr(:block, pre..., Expr(ex.head, lhs, rhs))
362+
else
363+
lhssym = gensym(string(lhs))
364+
lhstensor = Expr(:typed_vcat, lhssym, Expr(:tuple), Expr(:tuple))
365+
rhs = _extract_contraction_pairs(rhs, lhstensor, pre, temporaries)
366+
push!(temporaries, lhssym)
367+
return Expr(:block, pre..., Expr(:(:=), lhstensor, rhs), Expr(:(=), lhs, lhstensor))
368+
end
361369
else
362370
return ex
363371
end

src/tensors/indexmanipulations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function has_shared_permute(t::TensorMap, (p₁, p₂)::Index2Tuple)
6868
return false
6969
end
7070
end
71-
function has_shared_permute(t::AdjointTensorMap, p₁, p₂)
71+
function has_shared_permute(t::AdjointTensorMap, (p₁, p₂)::Index2Tuple)
7272
p₁′ = adjointtensorindices(t, p₂)
7373
p₂′ = adjointtensorindices(t, p₁)
7474
return has_shared_permute(t', (p₁′, p₂′))

src/tensors/tensoroperations.jl

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,15 @@ end
216216
# TODO: contraction with either A or B a rank (1, 1) tensor does not require to
217217
# permute the fusion tree and should therefore be special cased. This will speed
218218
# up MPS algorithms
219-
function contract!(C::AbstractTensorMap{S,N₁,N₂},
220-
A::AbstractTensorMap{S}, pA::Index2Tuple,
221-
B::AbstractTensorMap{S}, pB::Index2Tuple,
222-
pAB::Index2Tuple{N₁,N₂},
223-
α::Number, β::Number, backend...) where {S,N₁,N₂}
219+
function contract!(C::AbstractTensorMap{S},
220+
A::AbstractTensorMap{S}, (oindA, cindA)::Index2Tuple{N₁,N₃},
221+
B::AbstractTensorMap{S}, (cindB, oindB)::Index2Tuple{N₃,N₂},
222+
(p₁, p₂)::Index2Tuple,
223+
α::Number, β::Number, backend...) where {S,N₁,N₂,N₃}
224224

225225
# find optimal contraction scheme
226226
hsp = has_shared_permute
227-
ipC = TupleTools.invperm(linearize(pAB))
227+
ipC = TupleTools.invperm((p₁..., p₂...))
228228
oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁))
229229
oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂))
230230

@@ -239,32 +239,32 @@ function contract!(C::AbstractTensorMap{S,N₁,N₂},
239239
dA, dB, dC = dim(A), dim(B), dim(C)
240240

241241
# keep order A en B, check possibilities for cind
242-
memcost1 = memcost2 = dC * (!hsp(C, oindAinC, oindBinC))
243-
memcost1 += dA * (!hsp(A, oindA, cindA′)) +
244-
dB * (!hsp(B, cindB′, oindB))
245-
memcost2 += dA * (!hsp(A, oindA, cindA′′)) +
246-
dB * (!hsp(B, cindB′′, oindB))
242+
memcost1 = memcost2 = dC * (!hsp(C, (oindAinC, oindBinC)))
243+
memcost1 += dA * (!hsp(A, (oindA, cindA′))) +
244+
dB * (!hsp(B, (cindB′, oindB)))
245+
memcost2 += dA * (!hsp(A, (oindA, cindA′′))) +
246+
dB * (!hsp(B, (cindB′′, oindB)))
247247

248248
# reverse order A en B, check possibilities for cind
249-
memcost3 = memcost4 = dC * (!hsp(C, oindBinC, oindAinC))
250-
memcost3 += dB * (!hsp(B, oindB, cindB′)) +
251-
dA * (!hsp(A, cindA′, oindA))
252-
memcost4 += dB * (!hsp(B, oindB, cindB′′)) +
253-
dA * (!hsp(A, cindA′′, oindA))
249+
memcost3 = memcost4 = dC * (!hsp(C, (oindBinC, oindAinC)))
250+
memcost3 += dB * (!hsp(B, (oindB, cindB′))) +
251+
dA * (!hsp(A, (cindA′, oindA)))
252+
memcost4 += dB * (!hsp(B, (oindB, cindB′′))) +
253+
dA * (!hsp(A, (cindA′′, oindA)))
254254

255255
if min(memcost1, memcost2) <= min(memcost3, memcost4)
256256
if memcost1 <= memcost2
257-
return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂, syms)
257+
return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂)
258258
else
259-
return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂, syms)
259+
return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂)
260260
end
261261
else
262262
p1′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₁)
263263
p2′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₂)
264264
if memcost3 <= memcost4
265-
return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′, syms)
265+
return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′)
266266
else
267-
return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′, syms)
267+
return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′)
268268
end
269269
end
270270
end
@@ -273,8 +273,7 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
273273
β, C::AbstractTensorMap{S},
274274
oindA::IndexTuple{N₁}, cindA::IndexTuple,
275275
oindB::IndexTuple{N₂}, cindB::IndexTuple,
276-
p₁::IndexTuple, p₂::IndexTuple,
277-
syms::Union{Nothing,NTuple{3,Symbol}}=nothing) where {S,N₁,N₂}
276+
p₁::IndexTuple, p₂::IndexTuple) where {S,N₁,N₂}
278277
if !(BraidingStyle(sectortype(S)) isa SymmetricBraiding)
279278
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
280279
end
@@ -286,8 +285,8 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
286285
end
287286
end
288287
end
289-
A′ = permute(A, oindA, cindA; copy=copyA)
290-
B′ = permute(B, cindB, oindB)
288+
A′ = permute(A, (oindA, cindA); copy=copyA)
289+
B′ = permute(B, (cindB, oindB))
291290
if BraidingStyle(sectortype(S)) isa Fermionic
292291
for i in domainind(A′)
293292
if !isdual(space(A′, i))
@@ -298,12 +297,12 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
298297
ipC = TupleTools.invperm((p₁..., p₂...))
299298
oindAinC = TupleTools.getindices(ipC, ntuple(n -> n, N₁))
300299
oindBinC = TupleTools.getindices(ipC, ntuple(n -> n + N₁, N₂))
301-
if has_shared_permute(C, oindAinC, oindBinC)
302-
C′ = permute(C, oindAinC, oindBinC)
300+
if has_shared_permute(C, (oindAinC, oindBinC))
301+
C′ = permute(C, (oindAinC, oindBinC))
303302
mul!(C′, A′, B′, α, β)
304303
else
305304
C′ = A′ * B′
306-
add!, C′, β, C, p₁, p₂)
305+
add_permute!(C, C′, (p₁, p₂), α, β)
307306
end
308307
return C
309308
end

test/planar.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ function force_planar(V::GradedSpace)
1313
end
1414
force_planar(V::ProductSpace) = mapreduce(force_planar, , V)
1515
function force_planar(tsrc::TensorMap{ComplexSpace})
16-
tdst = TensorMap(undef, eltype(tsrc),
16+
tdst = TensorMap(undef, scalartype(tsrc),
1717
force_planar(codomain(tsrc)) force_planar(domain(tsrc)))
1818
copyto!(blocks(tdst)[PlanarTrivial()], blocks(tsrc)[Trivial()])
1919
return tdst
2020
end
2121
function force_planar(tsrc::TensorMap{<:GradedSpace})
22-
tdst = TensorMap(undef, eltype(tsrc),
22+
tdst = TensorMap(undef, scalartype(tsrc),
2323
force_planar(codomain(tsrc)) force_planar(domain(tsrc)))
2424
for (c, b) in blocks(tsrc)
2525
copyto!(blocks(tdst)[c PlanarTrivial()], b)
@@ -33,22 +33,22 @@ end
3333
C = TensorMap(randn, (ℂ^5)' (ℂ^6)' ^4 (ℂ^3)' (ℂ^2)')
3434
A′ = force_planar(A)
3535
C′ = force_planar(C)
36-
pC = ((4, 3), (5, 2, 1))
36+
p = ((4, 3), (5, 2, 1))
3737

38-
@test force_planar(tensoradd!(C, pC, A, :N, true, true))
39-
planaradd!(C′, pC, A′, true, true)
38+
@test force_planar(tensoradd!(C, p, A, :N, true, true))
39+
planaradd!(C′, A′, p, true, true)
4040
end
4141

4242
@testset "planartrace" begin
4343
A = TensorMap(randn, ℂ^2 ^3 ^2 ^5 ^4)
4444
C = TensorMap(randn, (ℂ^5)' ^3 ^4)
4545
A′ = force_planar(A)
4646
C′ = force_planar(C)
47-
pA = ((1,), (3,))
48-
pC = ((4, 2), (5,))
47+
p = ((4, 2), (5,))
48+
q = ((1,), (3,))
4949

50-
@test force_planar(tensortrace!(C, pC, A, pA, :N, true, true))
51-
planartrace!(C′, pC, A′, pA, true, true)
50+
@test force_planar(tensortrace!(C, p, A, q, :N, true, true))
51+
planartrace!(C′, A′, p, q, true, true)
5252
end
5353

5454
@testset "planarcontract" begin
@@ -62,10 +62,10 @@ end
6262

6363
pA = ((1, 3, 4), (5, 2))
6464
pB = ((2, 4), (1, 3))
65-
pC = ((3, 2, 1), (4, 5))
65+
pAB = ((3, 2, 1), (4, 5))
6666

67-
@test force_planar(tensorcontract!(C, pC, A, pA, :N, B, pB, :N, true, true))
68-
planarcontract!(C′, pC, A′, pA, B′, pB, true, true)
67+
@test force_planar(tensorcontract!(C, pAB, A, pA, :N, B, pB, :N, true, true))
68+
planarcontract!(C′, A′, pA, B′, pB, pAB, true, true)
6969
end
7070
end
7171

test/tensors.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,9 @@ for V in spacelist
193193
for p in permutations(1:5)
194194
p1 = ntuple(n -> p[n], k)
195195
p2 = ntuple(n -> p[k + n], 5 - k)
196-
t2 = @constinferred permute(t, p1, p2)
196+
t2 = @constinferred permute(t, (p1, p2))
197197
@test norm(t2) norm(t)
198-
t2′ = permute(t′, p1, p2)
198+
t2′ = permute(t′, (p1, p2))
199199
@test dot(t2′, t2) dot(t′, t) dot(transpose(t2′), transpose(t2))
200200
end
201201
end
@@ -208,7 +208,7 @@ for V in spacelist
208208
for p in permutations(1:5)
209209
p1 = ntuple(n -> p[n], k)
210210
p2 = ntuple(n -> p[k + n], 5 - k)
211-
t2 = permute(t, p1, p2)
211+
t2 = permute(t, (p1, p2))
212212
a2 = convert(Array, t2)
213213
@test a2 permutedims(convert(Array, t), (p1..., p2...))
214214
@test convert(Array, transpose(t2)) permutedims(a2, (5, 4, 3, 2, 1))
@@ -218,7 +218,7 @@ for V in spacelist
218218
end
219219
@timedtestset "Full trace: test self-consistency" begin
220220
t = Tensor(rand, ComplexF64, V1 V2' V2 V1')
221-
t2 = permute(t, (1, 2), (4, 3))
221+
t2 = permute(t, ((1, 2), (4, 3)))
222222
s = @constinferred tr(t2)
223223
@test conj(s) tr(t2')
224224
if !isdual(V1)
@@ -345,7 +345,7 @@ for V in spacelist
345345
Q, R = @constinferred leftorth(t, (3, 4, 2), (1, 5); alg=alg)
346346
QdQ = Q' * Q
347347
@test QdQ one(QdQ)
348-
@test Q * R permute(t, (3, 4, 2), (1, 5))
348+
@test Q * R permute(t, ((3, 4, 2), (1, 5)))
349349
if alg isa Polar
350350
@test isposdef(R)
351351
@test domain(R) == codomain(R) == space(t, 1)' space(t, 5)'
@@ -356,7 +356,7 @@ for V in spacelist
356356
N = @constinferred leftnull(t, (3, 4, 2), (1, 5); alg=alg)
357357
NdN = N' * N
358358
@test NdN one(NdN)
359-
@test norm(N' * permute(t, (3, 4, 2), (1, 5))) < 100 * eps(norm(t))
359+
@test norm(N' * permute(t, ((3, 4, 2), (1, 5)))) < 100 * eps(norm(t))
360360
end
361361
@testset "rightorth with $alg" for alg in
362362
(TensorKit.RQ(), TensorKit.RQpos(),
@@ -366,7 +366,7 @@ for V in spacelist
366366
L, Q = @constinferred rightorth(t, (3, 4), (2, 1, 5); alg=alg)
367367
QQd = Q * Q'
368368
@test QQd one(QQd)
369-
@test L * Q permute(t, (3, 4), (2, 1, 5))
369+
@test L * Q permute(t, ((3, 4), (2, 1, 5)))
370370
if alg isa Polar
371371
@test isposdef(L)
372372
@test domain(L) == codomain(L) == space(t, 3) space(t, 4)
@@ -377,15 +377,15 @@ for V in spacelist
377377
M = @constinferred rightnull(t, (3, 4), (2, 1, 5); alg=alg)
378378
MMd = M * M'
379379
@test MMd one(MMd)
380-
@test norm(permute(t, (3, 4), (2, 1, 5)) * M') < 100 * eps(norm(t))
380+
@test norm(permute(t, ((3, 4), (2, 1, 5))) * M') < 100 * eps(norm(t))
381381
end
382382
@testset "tsvd with $alg" for alg in (TensorKit.SVD(), TensorKit.SDD())
383383
U, S, V = @constinferred tsvd(t, (3, 4, 2), (1, 5); alg=alg)
384384
UdU = U' * U
385385
@test UdU one(UdU)
386386
VVd = V * V'
387387
@test VVd one(VVd)
388-
@test U * S * V permute(t, (3, 4, 2), (1, 5))
388+
@test U * S * V permute(t, ((3, 4, 2), (1, 5)))
389389
end
390390
end
391391
@testset "empty tensor" begin
@@ -436,7 +436,7 @@ for V in spacelist
436436
VdV = V' * V
437437
VdV = (VdV + VdV') / 2
438438
@test isposdef(VdV)
439-
t2 = permute(t, (1, 3), (2, 4))
439+
t2 = permute(t, ((1, 3), (2, 4)))
440440
@test t2 * V V * D
441441
@test !isposdef(t2) # unlikely for non-hermitian map
442442
t2 = (t2 + t2')

0 commit comments

Comments
 (0)