Skip to content

Commit dc1159f

Browse files
committed
fix exponentiation and add test, rename tensor concatenation
1 parent d829ceb commit dc1159f

3 files changed

Lines changed: 46 additions & 18 deletions

File tree

src/tensors/linalg.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,19 @@ function Base.:-(t1::AbstractTensorMap, t2::AbstractTensorMap)
1212
return axpy!(-one(T), t2, copyto!(similar(t1, T), t1))
1313
end
1414

15-
Base.:*(t::AbstractTensorMap, α::Number) = mul!(similar(t, promote_type(eltype(t), typeof(α))), t, α)
16-
Base.:*::Number, t::AbstractTensorMap) = mul!(similar(t, promote_type(eltype(t), typeof(α))), α, t)
15+
Base.:*(t::AbstractTensorMap, α::Number) =
16+
mul!(similar(t, promote_type(eltype(t), typeof(α))), t, α)
17+
Base.:*::Number, t::AbstractTensorMap) =
18+
mul!(similar(t, promote_type(eltype(t), typeof(α))), α, t)
1719
Base.:/(t::AbstractTensorMap, α::Number) = *(t, one(α)/α)
1820
Base.:\::Number, t::AbstractTensorMap) = *(t, one(α)/α)
1921

2022
LinearAlgebra.normalize!(t::AbstractTensorMap, p::Real = 2) = rmul!(t, inv(norm(t, p)))
21-
LinearAlgebra.normalize(t::AbstractTensorMap, p::Real = 2) = mul!(similar(t), t, inv(norm(t, p)))
23+
LinearAlgebra.normalize(t::AbstractTensorMap, p::Real = 2) =
24+
mul!(similar(t), t, inv(norm(t, p)))
2225

23-
Base.:*(t1::AbstractTensorMap, t2::AbstractTensorMap) = mul!(similar(t1, promote_type(eltype(t1),eltype(t2)), codomain(t1)domain(t2)), t1, t2)
26+
Base.:*(t1::AbstractTensorMap, t2::AbstractTensorMap) =
27+
mul!(similar(t1, promote_type(eltype(t1),eltype(t2)), codomain(t1)domain(t2)), t1, t2)
2428
Base.exp(t::AbstractTensorMap) = exp!(copy(t))
2529

2630
# Special purpose constructors
@@ -159,15 +163,26 @@ end
159163

160164
# TensorMap exponentation:
161165
function exp!(t::TensorMap)
162-
domain(t) == codomain(t) || error("Exponentional of a tensor only exist when domain == codomain.")
166+
domain(t) == codomain(t) ||
167+
error("Exponentional of a tensor only exist when domain == codomain.")
163168
for (c,b) in blocks(t)
164-
copyto!(b, exp!(b))
169+
copyto!(b, LinearAlgebra.exp!(b))
165170
end
166171
return t
167172
end
168173

169-
# hcat and vcat of tensors
170-
function Base.hcat(t1::AbstractTensorMap{S,N₁,1}, t2::AbstractTensorMap{S,N₁,1}) where {S,N₁}
174+
# # concatenate tensors
175+
# function concatenate(t1::AbstractTensorMap{S}, ts::AbstractTensorMap{S}...; direction::Symbol) where {S}
176+
# if direction = :domain
177+
# numin(t1) == 1 || throw(SpaceMismatch("concatenation along domain"))
178+
# for t2 in ts
179+
# domain(t1) == domain(t2) || throw(SpaceMismatch())
180+
# end
181+
# t = t1
182+
# for t2 in ts
183+
# cat(t1)
184+
185+
function catdomain(t1::AbstractTensorMap{S,N₁,1}, t2::AbstractTensorMap{S,N₁,1}) where {S,N₁}
171186
codomain(t1) == codomain(t2) || throw(SpaceMismatch())
172187

173188
V1, = domain(t1)
@@ -182,7 +197,7 @@ function Base.hcat(t1::AbstractTensorMap{S,N₁,1}, t2::AbstractTensorMap{S,N₁
182197
end
183198
return t
184199
end
185-
function Base.vcat(t1::AbstractTensorMap{S,1,N₂}, t2::AbstractTensorMap{S,1,N₂}) where {S,N₂}
200+
function catcodomain(t1::AbstractTensorMap{S,1,N₂}, t2::AbstractTensorMap{S,1,N₂}) where {S,N₂}
186201
domain(t1) == domain(t2) || throw(SpaceMismatch())
187202

188203
V1, = codomain(t1)

src/tensors/tensoroperations.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function add!(α, tsrc::AbstractTensorMap{S}, β, tdst::AbstractTensorMap{S,N₁
131131
end
132132
end
133133
else # debugging is easier this way
134-
@inbounds for (f1,f2) in fusiontrees(tsrc)
134+
for (f1,f2) in fusiontrees(tsrc)
135135
_addabelianblock!(α, tsrc, β, tdst, p1, p2, f1, f2)
136136
end
137137
end
@@ -145,46 +145,49 @@ function add!(α, tsrc::AbstractTensorMap{S}, β, tdst::AbstractTensorMap{S,N₁
145145
elseif β != 1
146146
mul!(tdst, β, tdst)
147147
end
148-
@inbounds for (f1,f2) in fusiontrees(tsrc)
148+
for (f1,f2) in fusiontrees(tsrc)
149149
for ((f1′,f2′), coeff) in permute(f1, f2, p1, p2)
150-
for i in p2
150+
@inbounds for i in p2
151151
if i <= n && !isdual(cod[i])
152152
b = f1.uncoupled[i]
153153
coeff *= frobeniusschur(b) #*fermionparity(b)
154154
end
155155
end
156-
for i in p1
156+
@inbounds for i in p1
157157
if i > n && isdual(dom[i-n])
158158
b = f2.uncoupled[i-n]
159159
coeff /= frobeniusschur(b) #*fermionparity(b)
160160
end
161161
end
162-
axpy!*coeff, permutedims(tsrc[f1,f2], pdata), tdst[f1′,f2′])
162+
@inbounds axpy!*coeff, permutedims(tsrc[f1,f2], pdata), tdst[f1′,f2′])
163163
end
164164
end
165165
end
166166
return tdst
167167
end
168168

169-
@inbounds function _addabelianblock!(α, tsrc::AbstractTensorMap{S}, β, tdst::AbstractTensorMap{S,N₁,N₂}, p1::IndexTuple{N₁}, p2::IndexTuple{N₂}, f1::FusionTree, f2::FusionTree) where {S,N₁,N₂}
169+
function _addabelianblock!(α, tsrc::AbstractTensorMap{S},
170+
β, tdst::AbstractTensorMap{S,N₁,N₂},
171+
p1::IndexTuple{N₁}, p2::IndexTuple{N₂},
172+
f1::FusionTree, f2::FusionTree) where {S,N₁,N₂}
170173
cod = codomain(tsrc)
171174
dom = domain(tsrc)
172175
n = length(cod)
173176
(f1′,f2′), coeff = first(permute(f1, f2, p1, p2))
174-
for i in p2
177+
@inbounds for i in p2
175178
if i <= n && !isdual(cod[i])
176179
b = f1.uncoupled[i]
177180
coeff *= frobeniusschur(b) #*fermionparity(b)
178181
end
179182
end
180-
for i in p1
183+
@inbounds for i in p1
181184
if i > n && isdual(dom[i-n])
182185
b = f2.uncoupled[i-n]
183186
coeff /= frobeniusschur(b) #*fermionparity(b)
184187
end
185188
end
186189
pdata = (p1...,p2...)
187-
axpby!*coeff, permutedims(tsrc[f1,f2], pdata), β, tdst[f1′,f2′])
190+
@inbounds axpby!*coeff, permutedims(tsrc[f1,f2], pdata), β, tdst[f1′,f2′])
188191
end
189192

190193

test/tensors.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,14 @@ VSU₂ = (ℂ[SU₂](0=>1, 1//2=>1, 1=>2),
159159
end
160160
end
161161
end
162+
@testset "Exponentiation" begin
163+
W = V1 V2 V3
164+
for T in (Float32, Float64, ComplexF32, ComplexF64)
165+
t = TensorMap(rand, T, W, W)
166+
s = dim(W)
167+
expt = @inferred exp(t)
168+
@test reshape(convert(Array, expt), (s,s))
169+
exp(reshape(convert(Array, t), (s,s)))
170+
end
171+
end
162172
end

0 commit comments

Comments
 (0)