11function cached_permute (sym:: Symbol , t:: TensorMap{S} ,
2- p1:: IndexTuple{N₁} , p2:: IndexTuple{N₂} = ()) where {S, N₁, N₂}
2+ p1:: IndexTuple{N₁} , p2:: IndexTuple{N₂} = ();
3+ copy:: Bool = false ) where {S, N₁, N₂}
34 cod = ProductSpace {S, N₁} (map (n-> space (t, n), p1))
45 dom = ProductSpace {S, N₂} (map (n-> dual (space (t, n)), p2))
5-
66 # share data if possible
7- if p1 === codomainind (t) && p2 === domainind (t)
8- return t
9- elseif isa (t, TensorMap) && sectortype (S) === Trivial
10- stridet = i-> stride (t[], i)
11- sizet = i-> size (t[], i)
12- canfuse1, d1, s1 = TensorOperations. _canfuse (sizet .(p1), stridet .(p1))
13- canfuse2, d2, s2 = TensorOperations. _canfuse (sizet .(p2), stridet .(p2))
14- if canfuse1 && canfuse2 && s1 == 1 && (d2 == 1 || s2 == d1)
7+ if ! copy
8+ if p1 === codomainind (t) && p2 === domainind (t)
9+ return t
10+ elseif has_shared_permute (t, p1, p2)
1511 return TensorMap (reshape (t. data, dim (cod), dim (dom)), cod, dom)
1612 end
1713 end
@@ -23,11 +19,11 @@ function cached_permute(sym::Symbol, t::TensorMap{S},
2319end
2420
2521function cached_permute (sym:: Symbol , t:: AdjointTensorMap{S} ,
26- p1:: IndexTuple{N₁} , p2:: IndexTuple{N₂} = ()) where {S, N₁, N₂}
27-
22+ p1:: IndexTuple , p2:: IndexTuple = ();
23+ copy :: Bool = false ) where {S, N₁, N₂}
2824 p1′ = adjointtensorindices (t, p2)
2925 p2′ = adjointtensorindices (t, p1)
30- adjoint (cached_permute (sym, adjoint (t), p1′, p2′))
26+ adjoint (cached_permute (sym, adjoint (t), p1′, p2′; copy = copy ))
3127end
3228
3329scalar (t:: AbstractTensorMap{S} ) where {S<: IndexSpace } =
@@ -254,21 +250,21 @@ function contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
254250 memcost4 += dB* (! hsp (B, oindB, cindB′′)) +
255251 dA* (! hsp (A, cindA′′, oindA))
256252
257- if min (memcost1, memcost2) <= min (memcost3, memcost4)
253+ # if min(memcost1, memcost2) <= min(memcost3, memcost4)
258254 if memcost1 <= memcost2
259255 return _contract! (α, A, B, β, C, oindA, cindA′, oindB, cindB′, p1, p2, syms)
260256 else
261257 return _contract! (α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p1, p2, syms)
262258 end
263- else
264- p1′ = map (n-> ifelse (n> N₁, n- N₁, n+ N₂), p1)
265- p2′ = map (n-> ifelse (n> N₁, n- N₁, n+ N₂), p2)
266- if memcost3 <= memcost4
267- return _contract! (α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′, syms)
268- else
269- return _contract! (α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′, syms)
270- end
271- end
259+ # else
260+ # p1′ = map(n->ifelse(n>N₁, n-N₁, n+N₂), p1)
261+ # p2′ = map(n->ifelse(n>N₁, n-N₁, n+N₂), p2)
262+ # if memcost3 <= memcost4
263+ # return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′, syms)
264+ # else
265+ # return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′, syms)
266+ # end
267+ # end
272268end
273269
274270function _contract! (α, A:: AbstractTensorMap{S} , B:: AbstractTensorMap{S} ,
@@ -278,13 +274,31 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
278274 p1:: IndexTuple , p2:: IndexTuple ,
279275 syms:: Union{Nothing, NTuple{3, Symbol}} = nothing ) where {S, N₁, N₂}
280276
277+ if ! (BraidingStyle (sectortype (S)) isa SymmetricBraiding)
278+ throw (SectorMismatch (" only tensors with symmetric braiding rules can be contracted; try `@planar` instead" ))
279+ end
280+ copyA = false
281+ if BraidingStyle (sectortype (S)) isa Fermionic
282+ for i in cindA
283+ if ! isdual (space (A, i))
284+ copyA = true
285+ end
286+ end
287+ end
281288 if syms === nothing
282- A′ = permute (A, oindA, cindA)
289+ A′ = permute (A, oindA, cindA; copy = copyA )
283290 B′ = permute (B, cindB, oindB)
284291 else
285- A′ = cached_permute (syms[1 ], A, oindA, cindA)
292+ A′ = cached_permute (syms[1 ], A, oindA, cindA; copy = copyA )
286293 B′ = cached_permute (syms[2 ], B, cindB, oindB)
287294 end
295+ if BraidingStyle (sectortype (S)) isa Fermionic
296+ for i in domainind (A′)
297+ if ! isdual (space (A′, i))
298+ A′ = twist! (A′, i)
299+ end
300+ end
301+ end
288302 ipC = TupleTools. invperm ((p1... , p2... ))
289303 oindAinC = TupleTools. getindices (ipC, ntuple (n-> n, N₁))
290304 oindBinC = TupleTools. getindices (ipC, ntuple (n-> n+ N₁, N₂))
0 commit comments