Skip to content

Commit 59c5d97

Browse files
authored
transpose specialization for DiagonalTensorMap (#335)
* add `transpose` specialization for `DiagonalTensorMap` * add test * slight test fixes
1 parent cb40722 commit 59c5d97

2 files changed

Lines changed: 38 additions & 9 deletions

File tree

src/tensors/diagonal.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,29 @@ function permute(
211211
end
212212
return d′
213213
else
214-
throw(ArgumentError("invalid permutation $((p₁, p₂)) for tensor in space $(space(d))"))
214+
throw(ArgumentError(lazy"invalid permutation $((p₁, p₂)) for tensor in space $(space(d))"))
215+
end
216+
end
217+
218+
function LinearAlgebra.transpose(
219+
d::DiagonalTensorMap, (p₁, p₂)::Index2Tuple{1, 1}; copy::Bool = false
220+
)
221+
if p₁ === (1,) && p₂ === (2,)
222+
return copy ? Base.copy(d) : d
223+
elseif p₁ === (2,) && p₂ === (1,) # transpose
224+
if has_shared_permute(d, (p₁, p₂)) # tranpose for bosonic sectors
225+
return DiagonalTensorMap(copy ? Base.copy(d.data) : d.data, dual(d.domain))
226+
end
227+
d′ = typeof(d)(undef, dual(d.domain))
228+
for (c, b) in blocks(d)
229+
f = only(fusiontrees(codomain(d), c))
230+
((f′, _), coeff) = only(transpose(f, f, p₁, p₂))
231+
c′ = f′.coupled
232+
scale!(block(d′, c′), b, coeff)
233+
end
234+
return d′
235+
else
236+
throw(ArgumentError(lazy"invalid transposition $((p₁, p₂)) for tensor in space $(space(d))"))
215237
end
216238
end
217239

test/tensors/diagonal.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,19 +119,26 @@ diagspacelist = (
119119
if BraidingStyle(I) isa SymmetricBraiding
120120
@timedtestset "Permutations" begin
121121
t = DiagonalTensorMap(randn(ComplexF64, reduceddim(V)), V)
122+
t_tm = convert(TensorMap, t)
123+
124+
# preserving diagonal
122125
t1 = @constinferred permute(t, $(((2,), (1,))))
123-
if BraidingStyle(sectortype(V)) isa Bosonic
124-
@test t1 transpose(t)
125-
end
126-
@test convert(TensorMap, t1) == permute(convert(TensorMap, t), (((2,), (1,))))
126+
@test t1 isa DiagonalTensorMap
127+
@test convert(TensorMap, t1) == permute(t_tm, (((2,), (1,))))
128+
t1′ = @constinferred transpose(t)
129+
@test t1′ isa DiagonalTensorMap
130+
@test convert(TensorMap, t1′) == transpose(t_tm)
131+
BraidingStyle(I) isa Bosonic && @test t1 t1′
132+
133+
# not preserving diagonal
127134
t2 = @constinferred permute(t, $(((1, 2), ())))
128-
@test convert(TensorMap, t2) == permute(convert(TensorMap, t), (((1, 2), ())))
135+
@test convert(TensorMap, t2) == permute(t_tm, (((1, 2), ())))
129136
t3 = @constinferred permute(t, $(((2, 1), ())))
130-
@test convert(TensorMap, t3) == permute(convert(TensorMap, t), (((2, 1), ())))
137+
@test convert(TensorMap, t3) == permute(t_tm, (((2, 1), ())))
131138
t4 = @constinferred permute(t, $(((), (1, 2))))
132-
@test convert(TensorMap, t4) == permute(convert(TensorMap, t), (((), (1, 2))))
139+
@test convert(TensorMap, t4) == permute(t_tm, (((), (1, 2))))
133140
t5 = @constinferred permute(t, $(((), (2, 1))))
134-
@test convert(TensorMap, t5) == permute(convert(TensorMap, t), (((), (2, 1))))
141+
@test convert(TensorMap, t5) == permute(t_tm, (((), (2, 1))))
135142
end
136143
end
137144
@timedtestset "Trace, Multiplication and inverse" begin

0 commit comments

Comments
 (0)