Skip to content

Commit eabf1e1

Browse files
committed
bug fixes and tests for inv and pinv
1 parent 37e3c38 commit eabf1e1

2 files changed

Lines changed: 96 additions & 46 deletions

File tree

src/tensors/linalg.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,9 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::A
162162
end
163163

164164
# TensorMap inverse
165-
function Base.inv(t::TensorMap)
165+
function Base.inv(t::AbstractTensorMap)
166166
domain(t) == codomain(t) ||
167-
SpaceMismatch("Inverse of a tensor only exist when domain == codomain; check pinv")
167+
throw(SpaceMismatch("Inverse of a tensor only exist when domain == codomain; check pinv"))
168168
if sectortype(t) === Trivial
169169
return TensorMap(inv(block(t, Trivial())), domain(t)codomain(t))
170170
else
@@ -175,7 +175,7 @@ function Base.inv(t::TensorMap)
175175
return TensorMap(data, domain(t)codomain(t))
176176
end
177177
end
178-
function LinearAlgebra.pinv(t::TensorMap; kwargs...)
178+
function LinearAlgebra.pinv(t::AbstractTensorMap; kwargs...)
179179
if sectortype(t) === Trivial
180180
return TensorMap(pinv(block(t, Trivial()); kwargs...), domain(t)codomain(t))
181181
else
@@ -186,21 +186,21 @@ function LinearAlgebra.pinv(t::TensorMap; kwargs...)
186186
return TensorMap(data, domain(t)codomain(t))
187187
end
188188
end
189-
function Base.:(\)(t1::TensorMap, t2::TensorMap)
189+
function Base.:(\)(t1::AbstractTensorMap, t2::AbstractTensorMap)
190190
codomain(t1) == codomain(t2) ||
191-
SpaceMismatch("non-matching codomains in t1 \\ t2")
191+
throw(SpaceMismatch("non-matching codomains in t1 \\ t2"))
192192
if sectortype(t1) === Trivial
193193
data = block(t1, Trivial()) \ block(t2, Trivial())
194194
return TensorMap(data, domain(t1)domain(t2))
195195
else
196196
cod = codomain(t1)
197-
data = SectorDict(c=>block(t1,c) / block(t2,c) for c in blocksectors(codomain(t1)))
197+
data = SectorDict(c=>block(t1,c) \ block(t2,c) for c in blocksectors(codomain(t1)))
198198
return TensorMap(data, domain(t1)domain(t2))
199199
end
200200
end
201-
function Base.:(/)(t1::TensorMap, t2::TensorMap)
201+
function Base.:(/)(t1::AbstractTensorMap, t2::AbstractTensorMap)
202202
domain(t1) == domain(t2) ||
203-
SpaceMismatch("non-matching domains in t1 / t2")
203+
throw(SpaceMismatch("non-matching domains in t1 / t2"))
204204
if sectortype(t1) === Trivial
205205
data = block(t1, Trivial()) / block(t2, Trivial())
206206
return TensorMap(data, codomain(t1)codomain(t2))

test/tensors.jl

Lines changed: 88 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -105,59 +105,76 @@ for (G,V) in ((Trivial, Vtr), (ℤ₂, Vℤ₂), (ℤ₃, Vℤ₃), (U₁, VU₁
105105
end
106106
end
107107
end
108-
@testset "Tensor product: test via norm preservation" begin
109-
for T in (Float32, Float64, ComplexF32, ComplexF64)
110-
t1 = TensorMap(rand, T, V2 V3 V1, V1 V2)
111-
t2 = TensorMap(rand, T, V2 V1 V3, V1 V1)
112-
t = @inferred (t1 t2)
113-
@test norm(t) norm(t1) * norm(t2)
114-
end
115-
end
116-
@testset "Tensor product: test via conversion" begin
117-
for T in (Float32, Float64, ComplexF32, ComplexF64)
118-
t1 = TensorMap(rand, T, V2 V3 V1, V1)
119-
t2 = TensorMap(rand, T, V2 V1 V3, V2)
120-
t = @inferred (t1 t2)
121-
d1 = dim(codomain(t1))
122-
d2 = dim(codomain(t2))
123-
d3 = dim(domain(t1))
124-
d4 = dim(domain(t2))
125-
At = convert(Array, t)
126-
@show sizeof(At)
127-
@test reshape(At, (d1, d2, d3, d4))
128-
reshape(convert(Array, t1), (d1, 1, d3, 1)) .*
129-
reshape(convert(Array, t2), (1, d2, 1, d4))
130-
end
131-
end
132-
@testset "Tensor product: test via tensor contraction" begin
133-
for T in (Float32, Float64, ComplexF32, ComplexF64)
134-
t1 = Tensor(rand, T, V2 V3 V1)
135-
t2 = Tensor(rand, T, V2 V1 V3)
136-
t = @inferred (t1 t2)
137-
@tensor t′[1, 2, 3, 4, 5, 6] := t1[1,2,3]*t2[4,5,6]
138-
@test t t′
139-
end
140-
end
141-
142108
@testset "Tensor contraction: test via conversion" begin
143-
A1 = TensorMap(randn, ComplexF64, V1*V2, V3)
109+
A1 = TensorMap(randn, ComplexF64, V1'*V2', V3')
144110
A2 = TensorMap(randn, ComplexF64, V3*V4, V5)
145111
rhoL = TensorMap(randn, ComplexF64, V1, V1)
146-
rhoR = TensorMap(randn, ComplexF64, V5, V5)
112+
rhoR = TensorMap(randn, ComplexF64, V5, V5)' # test adjoint tensor
147113
H = TensorMap(randn, ComplexF64, V2*V4, V2*V4)
148-
@tensor HrA12[a, s1, s2, c] := rhoL[a, a'] * A1[a', t1, b] *
114+
@tensor HrA12[a, s1, s2, c] := rhoL[a, a'] * conj(A1[a', t1, b]) *
149115
A2[b, t2, c'] * rhoR[c', c] * H[s1, s2, t1, t2]
150116

151117
@tensor HrA12array[a, s1, s2, c] := convert(Array, rhoL)[a, a'] *
152-
convert(Array, A1)[a', t1, b] *
118+
conj(convert(Array, A1)[a', t1, b]) *
153119
convert(Array, A2)[b, t2, c'] *
154120
convert(Array, rhoR)[c', c] *
155121
convert(Array, H)[s1, s2, t1, t2]
156122

157123
@test HrA12array convert(Array, HrA12)
158124
end
125+
@testset "Multiplication and inverse: test compatibility" begin
126+
W1 = V1 V2 V3
127+
W2 = V4 V5
128+
for T in (Float64, ComplexF64)
129+
t1 = TensorMap(rand, T, W1, W1)
130+
t2 = TensorMap(rand, T, W2, W2)
131+
t = TensorMap(rand, T, W1, W2)
132+
@test t1*(t1\t) t
133+
@test (t/t2)*t2 t
134+
@test t1\one(t1) inv(t1)
135+
@test one(t1)/t1 pinv(t1)
136+
@test_throws SpaceMismatch inv(t)
137+
@test_throws SpaceMismatch t2\t
138+
@test_throws SpaceMismatch t/t1
139+
tp = pinv(t)*t
140+
@test tp tp*tp
141+
end
142+
end
143+
@testset "Multiplication and inverse: test via conversion" begin
144+
W1 = V1 V2 V3
145+
W2 = V4 V5
146+
for T in (Float32, Float64, ComplexF32, ComplexF64)
147+
t1 = TensorMap(rand, T, W1, W1)
148+
t2 = TensorMap(rand, T, W2, W2)
149+
t = TensorMap(rand, T, W1, W2)
150+
d1 = dim(W1)
151+
d2 = dim(W2)
152+
At1 = reshape(convert(Array, t1), d1, d1)
153+
At2 = reshape(convert(Array, t2), d2, d2)
154+
At = reshape(convert(Array, t), d1, d2)
155+
@test reshape(convert(Array, t1*t), d1, d2) At1*At
156+
@test reshape(convert(Array, t1'*t), d1, d2) At1'*At
157+
@test reshape(convert(Array, t2*t'), d2, d1) At2*At'
158+
@test reshape(convert(Array, t2'*t'), d2, d1) At2'*At'
159159

160+
@test reshape(convert(Array, inv(t1)), d1, d1) inv(At1)
161+
@test reshape(convert(Array, pinv(t)), d2, d1) pinv(At)
160162

163+
if T == Float32 || T == ComplexF32
164+
continue
165+
end
166+
167+
@test reshape(convert(Array, t1\t), d1, d2) At1\At
168+
@test reshape(convert(Array, t1'\t), d1, d2) At1'\At
169+
@test reshape(convert(Array, t2\t'), d2, d1) At2\At'
170+
@test reshape(convert(Array, t2'\t'), d2, d1) At2'\At'
171+
172+
@test reshape(convert(Array, t2/t), d2, d1) At2/At
173+
@test reshape(convert(Array, t2'/t), d2, d1) At2'/At
174+
@test reshape(convert(Array, t1/t'), d1, d2) At1/At'
175+
@test reshape(convert(Array, t1'/t'), d1, d2) At1'/At'
176+
end
177+
end
161178
@testset "Factorization" begin
162179
W = V1 V2 V3 V4 V5
163180
for T in (Float32, Float64, ComplexF32, ComplexF64)
@@ -225,4 +242,37 @@ for (G,V) in ((Trivial, Vtr), (ℤ₂, Vℤ₂), (ℤ₃, Vℤ₃), (U₁, VU₁
225242
exp(reshape(convert(Array, t), (s,s)))
226243
end
227244
end
245+
@testset "Tensor product: test via norm preservation" begin
246+
for T in (Float32, Float64, ComplexF32, ComplexF64)
247+
t1 = TensorMap(rand, T, V2 V3 V1, V1 V2)
248+
t2 = TensorMap(rand, T, V2 V1 V3, V1 V1)
249+
t = @inferred (t1 t2)
250+
@test norm(t) norm(t1) * norm(t2)
251+
end
252+
end
253+
@testset "Tensor product: test via conversion" begin
254+
for T in (Float32, Float64, ComplexF32, ComplexF64)
255+
t1 = TensorMap(rand, T, V2 V3 V1, V1)
256+
t2 = TensorMap(rand, T, V2 V1 V3, V2)
257+
t = @inferred (t1 t2)
258+
d1 = dim(codomain(t1))
259+
d2 = dim(codomain(t2))
260+
d3 = dim(domain(t1))
261+
d4 = dim(domain(t2))
262+
At = convert(Array, t)
263+
@show sizeof(At)
264+
@test reshape(At, (d1, d2, d3, d4))
265+
reshape(convert(Array, t1), (d1, 1, d3, 1)) .*
266+
reshape(convert(Array, t2), (1, d2, 1, d4))
267+
end
268+
end
269+
@testset "Tensor product: test via tensor contraction" begin
270+
for T in (Float32, Float64, ComplexF32, ComplexF64)
271+
t1 = Tensor(rand, T, V2 V3 V1)
272+
t2 = Tensor(rand, T, V2 V1 V3)
273+
t = @inferred (t1 t2)
274+
@tensor t′[1, 2, 3, 4, 5, 6] := t1[1,2,3]*t2[4,5,6]
275+
@test t t′
276+
end
277+
end
228278
end

0 commit comments

Comments
 (0)