-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathBlockSparseArraysTensorAlgebraExt.jl
More file actions
90 lines (84 loc) · 3.51 KB
/
BlockSparseArraysTensorAlgebraExt.jl
File metadata and controls
90 lines (84 loc) · 3.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
module BlockSparseArraysTensorAlgebraExt
using BlockArrays: AbstractBlockArray, Block, blocklength, blocks, eachblockaxes1
using BlockSparseArrays: AbstractBlockSparseArray, AbstractBlockSparseMatrix,
BlockUnitRange, blockrange, blocksparse
using SparseArraysBase: eachstoredindex
using TensorAlgebra: TensorAlgebra, BlockedTuple, FusionStyle, matricize, matricize_axes,
tensor_product_axis, unmatricize
const BlockReshapeFusion = typeof(FusionStyle(AbstractBlockArray))
function TensorAlgebra.tensor_product_axis(
style::BlockReshapeFusion, side::Val{:codomain}, r1::BlockUnitRange, r2::BlockUnitRange
)
return tensor_product_blockrange(style, side, r1, r2)
end
function TensorAlgebra.tensor_product_axis(
style::BlockReshapeFusion, side::Val{:domain}, r1::BlockUnitRange, r2::BlockUnitRange
)
return tensor_product_blockrange(style, side, r1, r2)
end
function tensor_product_blockrange(
::BlockReshapeFusion, side::Val, r1::BlockUnitRange, r2::BlockUnitRange
)
(isone(first(r1)) && isone(first(r2))) ||
throw(ArgumentError("Only one-based axes are supported"))
blockaxpairs = Iterators.product(eachblockaxes1(r1), eachblockaxes1(r2))
blockaxs = map(blockaxpairs) do (b1, b2)
# TODO: Store a FusionStyle for the blocks in `BlockReshapeFusion`
# and use that here.
return tensor_product_axis(side, b1, b2)
end
return blockrange(vec(blockaxs))
end
function TensorAlgebra.matricize(
style::BlockReshapeFusion, a::AbstractBlockSparseArray, length_codomain::Val
)
ax = matricize_axes(style, a, length_codomain)
reshaped_blocks_a = reshape(blocks(a), blocklength.(ax))
key(I) = Block(Tuple(I))
value(I) = matricize(reshaped_blocks_a[I], length_codomain)
Is = eachstoredindex(reshaped_blocks_a)
bs = if isempty(Is)
# Catch empty case and make sure the type is constrained properly.
# This seems to only be necessary in Julia versions below v1.11,
# try removing it when we drop support for those versions.
keytype = Base.promote_op(key, eltype(Is))
valtype = Base.promote_op(value, eltype(Is))
valtype′ = !isconcretetype(valtype) ? AbstractMatrix{eltype(a)} : valtype
Dict{keytype, valtype′}()
else
Dict(key(I) => value(I) for I in Is)
end
return blocksparse(bs, ax)
end
function TensorAlgebra.unmatricize(
::BlockReshapeFusion,
m::AbstractBlockSparseMatrix,
codomain_axes::Tuple{Vararg{AbstractUnitRange}},
domain_axes::Tuple{Vararg{AbstractUnitRange}}
)
ax = (codomain_axes..., domain_axes...)
reshaped_blocks_m = reshape(blocks(m), blocklength.(ax))
key(I) = Block(Tuple(I))
function value(I)
block_axes_I = BlockedTuple(
map(ntuple(identity, length(ax))) do i
return Base.axes1(ax[i][Block(I[i])])
end,
(length(codomain_axes), length(domain_axes))
)
return unmatricize(reshaped_blocks_m[I], block_axes_I)
end
Is = eachstoredindex(reshaped_blocks_m)
bs = if isempty(Is)
# Constrain key/value types explicitly so empty cases are still typed.
keytype = Base.promote_op(key, eltype(Is))
valtype = Base.promote_op(value, eltype(Is))
valtype′ =
!isconcretetype(valtype) ? AbstractArray{eltype(m), length(ax)} : valtype
bs = Dict{keytype, valtype′}()
else
Dict(key(I) => value(I) for I in eachstoredindex(reshaped_blocks_m))
end
return blocksparse(bs, ax)
end
end