Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ function truncspace(space::ElementarySpace; by = abs, rev::Bool = true)
return TruncationSpace(space, by, rev)
end

TensorKit.spacetype(::Type{<:TruncationSpace{S}}) where {S} = S

# truncate!
# ---------
_blocklength(d::Integer, ind) = _blocklength(Base.OneTo(d), ind)
Expand Down Expand Up @@ -257,10 +259,12 @@ MAK.findtruncated_svd(values::SectorVector, strategy::TruncationByError) =
MAK.findtruncated(values, strategy)

function MAK.findtruncated(values::SectorVector, strategy::TruncationSpace)
@assert spacetype(values) == spacetype(strategy)
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
return SectorDict(c => MAK.findtruncated(d, blockstrategy(c)) for (c, d) in pairs(values))
end
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace)
@assert spacetype(values) == spacetype(strategy)
blockstrategy(c) = truncrank(dim(strategy.space, c); strategy.by, strategy.rev)
return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values))
end
Expand Down
2 changes: 2 additions & 0 deletions test/factorizations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ for V in spacelist
@test ϵ1 ϵ2

trunc = truncspace(space(S2, 1))
@test spacetype(typeof(trunc)) == spacetype(V)
@test sectortype(trunc) == sectortype(V)
U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(t; trunc)
@test t * Vᴴ3' U3 * S3
@test isisometric(U3)
Expand Down
Loading