-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy patharraylayouts.jl
More file actions
101 lines (91 loc) · 3.49 KB
/
arraylayouts.jl
File metadata and controls
101 lines (91 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
using ArrayLayouts: ArrayLayouts, DenseColumnMajor, DualLayout, MemoryLayout, MulAdd
using BlockArrays: BlockLayout
using SparseArraysBase: SparseLayout
using TypeParameterAccessors: parenttype, similartype
function ArrayLayouts.MemoryLayout(arraytype::Type{<:AnyAbstractBlockSparseArray})
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
inner_layout = typeof(MemoryLayout(blocktype(arraytype)))
return BlockLayout{outer_layout, inner_layout}()
end
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
function ArrayLayouts.MemoryLayout(
arraytype::Type{<:Adjoint{<:Any, <:AbstractBlockSparseVector}}
)
return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
end
# TODO: Generalize to `BlockSparseVectorLike`/`AnyBlockSparseVector`.
function ArrayLayouts.MemoryLayout(
arraytype::Type{<:Transpose{<:Any, <:AbstractBlockSparseVector}}
)
return DualLayout{typeof(MemoryLayout(parenttype(arraytype)))}()
end
function Base.similar(
mul::MulAdd{
<:BlockLayout{<:SparseLayout, BlockLayoutA},
<:BlockLayout{<:SparseLayout, BlockLayoutB},
LayoutC,
T,
A,
B,
C,
},
elt::Type,
axes,
) where {BlockLayoutA, BlockLayoutB, LayoutC, T, A, B, C}
# TODO: Consider using this instead:
# ```julia
# blockmultype = MulAdd{BlockLayoutA,BlockLayoutB,LayoutC,T,blocktype(A),blocktype(B),C}
# output_blocktype = Base.promote_op(
# similar, blockmultype, Type{elt}, Tuple{eltype.(eachblockaxis.(axes))...}
# )
# ```
# The issue is that it in some cases it seems to lose some information about the block types.
# TODO: Maybe this should be:
# output_blocktype = Base.promote_op(
# mul!, blocktype(mul.A), blocktype(mul.B), blocktype(mul.C), typeof(mul.α), typeof(mul.β)
# )
output_blocktype = Base.promote_op(*, blocktype(mul.A), blocktype(mul.B))
output_blocktype′ =
!isconcretetype(output_blocktype) ? AbstractMatrix{elt} : output_blocktype
return similar(BlockSparseArray{elt, length(axes), output_blocktype′}, axes)
end
# BlockSparseMatrix * dense Vector → dense Vector
# Returns a plain Vector instead of a BlockedVector
function Base.similar(
mul::MulAdd{
<:BlockLayout{<:SparseLayout},
<:DenseColumnMajor,
<:Any,
},
elt::Type,
axes::Tuple{<:AbstractUnitRange},
)
# Convert blocked axes to plain axes to avoid creating BlockedVector
plain_axes = map(ax -> Base.OneTo(length(ax)), axes)
return similar(mul.B, elt, plain_axes)
end
# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
# TODO: Define `blocktype`/`blockstype` for `SubArray` wrapping `BlockSparseArray`.
# TODO: Use `similar`?
blocktype_a = blocktype(parent(a))
a_dest = BlockSparseArray{eltype(a), length(axes), blocktype_a}(undef, axes)
a_dest .= a
return a_dest
end
function _similar(arraytype::Type{<:AbstractArray}, size::Tuple)
return similar(arraytype, size)
end
function _similar(
::Type{<:SubArray{<:Any, <:Any, <:ArrayType}}, size::Tuple
) where {ArrayType}
return similar(ArrayType, size)
end
# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
)
a_dest = _similar(blocktype(a), length.(axes))
a_dest .= a
return a_dest
end