Skip to content

Commit 006cae4

Browse files
committed
More progress
1 parent 0082d25 commit 006cae4

2 files changed

Lines changed: 27 additions & 20 deletions

File tree

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,11 @@ function TensorAlgebra.matricize(
4141
key(I) = Block(Tuple(I))
4242
value(I) = matricize(reshaped_blocks_a[I], length_codomain)
4343
Is = eachstoredindex(reshaped_blocks_a)
44-
bs = if isempty(Is)
45-
# Catch empty case and make sure the type is constrained properly.
46-
# This seems to only be necessary in Julia versions below v1.11,
47-
# try removing it when we drop support for those versions.
48-
keytype = Base.promote_op(key, eltype(Is))
49-
valtype = Base.promote_op(value, eltype(Is))
50-
valtype′ = !isconcretetype(valtype) ? AbstractMatrix{eltype(a)} : valtype
51-
Dict{keytype, valtype′}()
52-
else
53-
Dict(key(I) => value(I) for I in Is)
54-
end
44+
# Constrain key/value types explicitly so empty cases are still typed.
45+
keytype = Base.promote_op(key, eltype(Is))
46+
valtype = Base.promote_op(value, eltype(Is))
47+
valtype′ = !isconcretetype(valtype) ? AbstractMatrix{eltype(a)} : valtype
48+
bs = Dict{keytype, valtype′}(key(I) => value(I) for I in Is)
5549
return blocksparse(bs, ax)
5650
end
5751

@@ -73,7 +67,12 @@ function TensorAlgebra.unmatricize(
7367
)
7468
return unmatricize(reshaped_blocks_m[I], block_axes_I)
7569
end
76-
bs = Dict(key(I) => value(I) for I in eachstoredindex(reshaped_blocks_m))
70+
Is = eachstoredindex(reshaped_blocks_m)
71+
# Constrain key/value types explicitly so empty cases are still typed.
72+
keytype = Base.promote_op(key, eltype(Is))
73+
valtype = Base.promote_op(value, eltype(Is))
74+
valtype′ = !isconcretetype(valtype) ? AbstractArray{eltype(m), length(ax)} : valtype
75+
bs = Dict{keytype, valtype′}(key(I) => value(I) for I in Is)
7776
return blocksparse(bs, ax)
7877
end
7978

src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,27 +62,35 @@ function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, ax
6262
# TODO: Use `similar`?
6363
blocktype_a = blocktype(parent(a))
6464
a_dest = BlockSparseArray{eltype(a), length(axes), blocktype_a}(undef, axes)
65-
for I in SparseArraysBase.eachstoredindex(blocks(a))
66-
b = Block(Tuple(I))
67-
a_dest[b] = copy(blocks(a)[Tuple(I)...])
68-
end
65+
a_dest .= a
6966
return a_dest
7067
end
7168

72-
function _similar(arraytype::Type{<:AbstractArray}, size::Tuple)
73-
return similar(arraytype, size)
69+
function _similar(arraytype::Type{<:AbstractArray{T, N}}, size::Tuple) where {T, N}
70+
if isconcretetype(arraytype)
71+
try
72+
return similar(arraytype, size)
73+
catch err
74+
if !(err isa MethodError)
75+
rethrow()
76+
end
77+
end
78+
end
79+
return similar(Array{T, N}, size)
7480
end
7581
function _similar(
7682
::Type{<:SubArray{<:Any, <:Any, <:ArrayType}}, size::Tuple
7783
) where {ArrayType}
78-
return similar(ArrayType, size)
84+
return _similar(ArrayType, size)
7985
end
8086

8187
# Materialize a SubArray view.
8288
function ArrayLayouts.sub_materialize(
8389
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
8490
)
8591
a_dest = _similar(blocktype(a), length.(axes))
86-
a_dest .= a
92+
for I in CartesianIndices(a_dest)
93+
@inbounds a_dest[I] = a[I]
94+
end
8795
return a_dest
8896
end

0 commit comments

Comments
 (0)