Skip to content

Commit 6c57abd

Browse files
committed
fix fermionic tensor trace
1 parent 2aa35fd commit 6c57abd

2 files changed

Lines changed: 30 additions & 15 deletions

File tree

src/tensors/tensoroperations.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,10 @@ const _add_kernels = (_add_trivial_kernel!, _add_abelian_kernel!, _add_general_k
163163
function trace!(α, tsrc::AbstractTensorMap{S}, β, tdst::AbstractTensorMap{S, N₁, N₂},
164164
p1::IndexTuple{N₁}, p2::IndexTuple{N₂},
165165
q1::IndexTuple{N₃}, q2::IndexTuple{N₃}) where {S, N₁, N₂, N₃}
166-
# TODO: check Frobenius-Schur indicators!, and add fermions!
166+
167+
if !(BraidingStyle(sectortype(S)) isa SymmetricBraiding)
168+
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
169+
end
167170
@boundscheck begin
168171
all(i->space(tsrc, p1[i]) == space(tdst, i), 1:N₁) ||
169172
throw(SpaceMismatch("trace: tsrc = $(codomain(tsrc))$(domain(tsrc)),
@@ -203,6 +206,11 @@ function trace!(α, tsrc::AbstractTensorMap{S}, β, tdst::AbstractTensorMap{S, N
203206
f2′′, g2 = split(f2′, N₂)
204207
if g1 == g2
205208
coeff *= dim(g1.coupled)/dim(g1.uncoupled[1])
209+
for i = 2:length(g1.uncoupled)
210+
if !(g1.isdual[i])
211+
coeff *= twist(g1.uncoupled[i])
212+
end
213+
end
206214
TO._trace!*coeff, tsrc[f1, f2], true, tdst[f1′′, f2′′], pdata, q1, q2)
207215
end
208216
end
@@ -250,21 +258,21 @@ function contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
250258
memcost4 += dB*(!hsp(B, oindB, cindB′′)) +
251259
dA*(!hsp(A, cindA′′, oindA))
252260

253-
# if min(memcost1, memcost2) <= min(memcost3, memcost4)
261+
if min(memcost1, memcost2) <= min(memcost3, memcost4)
254262
if memcost1 <= memcost2
255263
return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p1, p2, syms)
256264
else
257265
return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p1, p2, syms)
258266
end
259-
# else
260-
# p1′ = map(n->ifelse(n>N₁, n-N₁, n+N₂), p1)
261-
# p2′ = map(n->ifelse(n>N₁, n-N₁, n+N₂), p2)
262-
# if memcost3 <= memcost4
263-
# return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′, syms)
264-
# else
265-
# return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′, syms)
266-
# end
267-
# end
267+
else
268+
p1′ = map(n->ifelse(n>N₁, n-N₁, n+N₂), p1)
269+
p2′ = map(n->ifelse(n>N₁, n-N₁, n+N₂), p2)
270+
if memcost3 <= memcost4
271+
return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′, syms)
272+
else
273+
return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′, syms)
274+
end
275+
end
268276
end
269277

270278
function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},

test/tensors.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ VSU₃ = (ℂ[SU3Irrep]((0,0,0)=>3, (1,0,0)=>1),
4949
ℂ[SU3Irrep]((1,0,0)=>1, (2,0,0)=>1),
5050
ℂ[SU3Irrep]((0,0,0)=>1, (1,0,0)=>1, (1,1,0)=>1)')
5151

52-
for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂)#, VSU₃)
52+
for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VSU₃)
5353
V1, V2, V3, V4, V5 = V
5454
@assert V3 * V4 * V2 V1' * V5' # necessary for leftorth tests
5555
@assert V3 * V4 V1' * V2' * V5' # necessary for rightorth tests
5656
end
5757

58-
for V in (Vtr, Vℤ₂, Vℤ₃, VU₁, VCU₁, VSU₂, VSU₃)
58+
for V in (Vtr, Vℤ₂, Vfℤ₂, Vℤ₃, VU₁, VfU₁, VCU₁, VSU₂, VfSU₂, VSU₃)
5959
I = sectortype(first(V))
6060
Istr = TensorKit.type_repr(I)
6161
println("---------------------------------------")
@@ -205,11 +205,18 @@ for V in (Vtr, Vℤ₂, Vℤ₃, VU₁, VCU₁, VSU₂, VSU₃)
205205
t2 = permute(t, (1,2), (4,3))
206206
s = @constinferred tr(t2)
207207
@test conj(s) tr(t2')
208+
if !isdual(V1)
209+
t2 = twist!(t2, 1)
210+
end
211+
if isdual(V2)
212+
t2 = twist!(t2, 2)
213+
end
214+
ss = tr(t2)
208215
@tensor s2 = t[a,b,b,a]
209216
@tensor t3[a,b] := t[a,c,c,b]
210217
@tensor s3 = t3[a,a]
211-
@test s s2
212-
@test s s3
218+
@test ss s2
219+
@test ss s3
213220
end
214221
@timedtestset "Partial trace: test self-consistency" begin
215222
t = Tensor(rand, ComplexF64, V1 V2' V3 V2 V1' V3')

0 commit comments

Comments
 (0)