Skip to content

Commit a0eeaac

Browse files
committed
type stability improvements
1 parent 5b83bf4 commit a0eeaac

1 file changed

Lines changed: 28 additions & 20 deletions

File tree

src/factorizations/truncation.jl

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -265,37 +265,45 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace)
265265
return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values))
266266
end
267267

268+
# The implementations below assume that the `SectorDict` always contains an entry for every block sector
269+
# for example, if a block gets fully truncated, inds[c] = Int[].
270+
# This is always the case in the implementations above.
271+
268272
function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersection)
269273
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
270-
return SectorDict(
271-
c => mapreduce(
272-
Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds
273-
) for c in intersect(map(keys, inds)...)
274-
)
274+
@assert allequal(keys, inds) "missing blocks are not supported right now"
275+
sectors = keys(first(inds))
276+
vals = map(keys(first(inds))) do c
277+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds)
278+
end
279+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
275280
end
276281
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationIntersection)
277282
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
278-
return SectorDict(
279-
c => mapreduce(
280-
Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds
281-
) for c in intersect(map(keys, inds)...)
282-
)
283+
@assert allequal(keys, inds) "missing blocks are not supported right now"
284+
sectors = keys(first(inds))
285+
vals = map(keys(first(inds))) do c
286+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds)
287+
end
288+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
283289
end
284290
function MAK.findtruncated(values::SectorVector, strategy::TruncationUnion)
285291
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
286-
return SectorDict(
287-
c => reduce(
288-
MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)]
289-
) for c in union(map(keys, inds)...)
290-
)
292+
@assert allequal(keys, inds) "missing blocks are not supported right now"
293+
sectors = keys(first(inds))
294+
vals = map(keys(first(inds))) do c
295+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds)
296+
end
297+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
291298
end
292299
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationUnion)
293300
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
294-
return SectorDict(
295-
c => reduce(
296-
MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)]
297-
) for c in union(map(keys, inds)...)
298-
)
301+
@assert allequal(keys, inds) "missing blocks are not supported right now"
302+
sectors = keys(first(inds))
303+
vals = map(keys(first(inds))) do c
304+
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds)
305+
end
306+
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
299307
end
300308

301309
# Truncation error

0 commit comments

Comments
 (0)