@@ -216,15 +216,15 @@ end
216216# TODO : contraction with either A or B a rank (1, 1) tensor does not require to
217217# permute the fusion tree and should therefore be special cased. This will speed
218218# up MPS algorithms
219- function contract! (C:: AbstractTensorMap{S,N₁,N₂ } ,
220- A:: AbstractTensorMap{S} , pA :: Index2Tuple ,
221- B:: AbstractTensorMap{S} , pB :: Index2Tuple ,
222- pAB :: Index2Tuple{N₁,N₂} ,
223- α:: Number , β:: Number , backend... ) where {S,N₁,N₂}
219+ function contract! (C:: AbstractTensorMap{S} ,
220+ A:: AbstractTensorMap{S} , (oindA, cindA) :: Index2Tuple{N₁,N₃} ,
221+ B:: AbstractTensorMap{S} , (cindB, oindB) :: Index2Tuple{N₃,N₂} ,
222+ (p₁, p₂) :: Index2Tuple ,
223+ α:: Number , β:: Number , backend... ) where {S,N₁,N₂,N₃ }
224224
225225 # find optimal contraction scheme
226226 hsp = has_shared_permute
227- ipC = TupleTools. invperm (linearize (pAB ))
227+ ipC = TupleTools. invperm ((p₁ ... , p₂ ... ))
228228 oindAinC = TupleTools. getindices (ipC, ntuple (n -> n, N₁))
229229 oindBinC = TupleTools. getindices (ipC, ntuple (n -> n + N₁, N₂))
230230
@@ -239,32 +239,32 @@ function contract!(C::AbstractTensorMap{S,N₁,N₂},
239239 dA, dB, dC = dim (A), dim (B), dim (C)
240240
241241 # keep order A en B, check possibilities for cind
242- memcost1 = memcost2 = dC * (! hsp (C, oindAinC, oindBinC))
243- memcost1 += dA * (! hsp (A, oindA, cindA′)) +
244- dB * (! hsp (B, cindB′, oindB))
245- memcost2 += dA * (! hsp (A, oindA, cindA′′)) +
246- dB * (! hsp (B, cindB′′, oindB))
242+ memcost1 = memcost2 = dC * (! hsp (C, ( oindAinC, oindBinC) ))
243+ memcost1 += dA * (! hsp (A, ( oindA, cindA′) )) +
244+ dB * (! hsp (B, ( cindB′, oindB) ))
245+ memcost2 += dA * (! hsp (A, ( oindA, cindA′′) )) +
246+ dB * (! hsp (B, ( cindB′′, oindB) ))
247247
248248 # reverse order A en B, check possibilities for cind
249- memcost3 = memcost4 = dC * (! hsp (C, oindBinC, oindAinC))
250- memcost3 += dB * (! hsp (B, oindB, cindB′)) +
251- dA * (! hsp (A, cindA′, oindA))
252- memcost4 += dB * (! hsp (B, oindB, cindB′′)) +
253- dA * (! hsp (A, cindA′′, oindA))
249+ memcost3 = memcost4 = dC * (! hsp (C, ( oindBinC, oindAinC) ))
250+ memcost3 += dB * (! hsp (B, ( oindB, cindB′) )) +
251+ dA * (! hsp (A, ( cindA′, oindA) ))
252+ memcost4 += dB * (! hsp (B, ( oindB, cindB′′) )) +
253+ dA * (! hsp (A, ( cindA′′, oindA) ))
254254
255255 if min (memcost1, memcost2) <= min (memcost3, memcost4)
256256 if memcost1 <= memcost2
257- return _contract! (α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂, syms )
257+ return _contract! (α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂)
258258 else
259- return _contract! (α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂, syms )
259+ return _contract! (α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂)
260260 end
261261 else
262262 p1′ = map (n -> ifelse (n > N₁, n - N₁, n + N₂), p₁)
263263 p2′ = map (n -> ifelse (n > N₁, n - N₁, n + N₂), p₂)
264264 if memcost3 <= memcost4
265- return _contract! (α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′, syms )
265+ return _contract! (α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′)
266266 else
267- return _contract! (α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′, syms )
267+ return _contract! (α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′)
268268 end
269269 end
270270end
@@ -273,8 +273,7 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
273273 β, C:: AbstractTensorMap{S} ,
274274 oindA:: IndexTuple{N₁} , cindA:: IndexTuple ,
275275 oindB:: IndexTuple{N₂} , cindB:: IndexTuple ,
276- p₁:: IndexTuple , p₂:: IndexTuple ,
277- syms:: Union{Nothing,NTuple{3,Symbol}} = nothing ) where {S,N₁,N₂}
276+ p₁:: IndexTuple , p₂:: IndexTuple ) where {S,N₁,N₂}
278277 if ! (BraidingStyle (sectortype (S)) isa SymmetricBraiding)
279278 throw (SectorMismatch (" only tensors with symmetric braiding rules can be contracted; try `@planar` instead" ))
280279 end
@@ -286,8 +285,8 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
286285 end
287286 end
288287 end
289- A′ = permute (A, oindA, cindA; copy= copyA)
290- B′ = permute (B, cindB, oindB)
288+ A′ = permute (A, ( oindA, cindA) ; copy= copyA)
289+ B′ = permute (B, ( cindB, oindB) )
291290 if BraidingStyle (sectortype (S)) isa Fermionic
292291 for i in domainind (A′)
293292 if ! isdual (space (A′, i))
@@ -298,12 +297,12 @@ function _contract!(α, A::AbstractTensorMap{S}, B::AbstractTensorMap{S},
298297 ipC = TupleTools. invperm ((p₁... , p₂... ))
299298 oindAinC = TupleTools. getindices (ipC, ntuple (n -> n, N₁))
300299 oindBinC = TupleTools. getindices (ipC, ntuple (n -> n + N₁, N₂))
301- if has_shared_permute (C, oindAinC, oindBinC)
302- C′ = permute (C, oindAinC, oindBinC)
300+ if has_shared_permute (C, ( oindAinC, oindBinC) )
301+ C′ = permute (C, ( oindAinC, oindBinC) )
303302 mul! (C′, A′, B′, α, β)
304303 else
305304 C′ = A′ * B′
306- add! (α , C′, β, C, p₁, p₂)
305+ add_permute! (C , C′, ( p₁, p₂), α, β )
307306 end
308307 return C
309308end
0 commit comments