Skip to content

Commit add6826

Browse files
committed
Scalar indexing of lazy mul
1 parent ff0ac3f commit add6826

2 files changed

Lines changed: 29 additions & 1 deletion

File tree

src/lazyarrays.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,20 @@ similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a))
786786
# TODO: Make use of both arguments to determine the output, maybe
787787
# using `LinearAlgebra.matprod_dest(factors(a)..., elt)`?
788788
similar_mul(a::AbstractArray, elt::Type, ax) = similar(last(factors(a)), elt, ax)
789+
function mul_getindex(a1::AbstractMatrix, a2::AbstractMatrix, i::Int, j::Int)
790+
return transpose(view(a1, i, :)) * view(a2, :, j)
791+
end
792+
function mul_getindex(a1::AbstractMatrix, a2::AbstractVector, i::Int)
793+
return transpose(view(a1, i, :)) * a2
794+
end
795+
function mul_getindex(a1::AbstractVector, a2::AbstractMatrix, j::Int)
796+
return transpose(a1) * view(a2, :, j)
797+
end
798+
function getindex_mul(a::AbstractArray, i::Int)
799+
I = Tuple(CartesianIndices(axes(a))[i])
800+
return getindex_mul(a, I...)
801+
end
802+
getindex_mul(a::AbstractArray, I::Vararg{Int}) = mul_getindex(factors(a)..., I...)
789803
copyto!_mul(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
790804
show_mul(io::IO, a::AbstractArray) = show_lazy(io, a)
791805
show_mul(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
@@ -843,6 +857,12 @@ macro mularray_type(MulArray, AbstractArray = :AbstractArray)
843857
)
844858
end
845859

860+
function copy_permuteddims_mul(a::PermutedDimsArray{<:Any, 2, perm}) where {perm}
861+
perm == (1, 2) && return copy(parent(a))
862+
perm == (2, 1) && return copy(transpose(parent(a)))
863+
throw(ArgumentError("Unsupported permutation $perm"))
864+
end
865+
846866
macro mularray_base(MulArray, AbstractArray = :AbstractArray)
847867
return esc(
848868
quote
@@ -864,6 +884,9 @@ macro mularray_base(MulArray, AbstractArray = :AbstractArray)
864884
function Base.similar(a::$MulArray, elt::Type, ax::Dims)
865885
return $TensorAlgebra.similar_mul(a, elt, ax)
866886
end
887+
Base.@propagate_inbounds function Base.getindex(a::$MulArray, I...)
888+
return $TensorAlgebra.getindex_mul(a, I...)
889+
end
867890
function Base.copyto!(dest::$AbstractArray, src::$MulArray)
868891
return $TensorAlgebra.copyto!_mul(dest, src)
869892
end
@@ -926,6 +949,9 @@ macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray)
926949
$TensorAlgebra.iscall(a::$MulArray) = $TensorAlgebra.iscall_mul(a)
927950
$TensorAlgebra.operation(a::$MulArray) = $TensorAlgebra.operation_mul(a)
928951
$TensorAlgebra.arguments(a::$MulArray) = $TensorAlgebra.arguments_mul(a)
952+
function Base.copy(a::PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray})
953+
return $TensorAlgebra.copy_permuteddims_mul(a)
954+
end
929955
end
930956
)
931957
end

test/test_lazy.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ using Test: @test, @test_broken, @test_throws, @testset
9393

9494
x = FI.permuteddims(a *ₗ b, perm)
9595
@test x PermutedDimsArray(a *ₗ b, perm)
96-
@test_broken copy(x) permutedims(a * b, perm)
96+
@test copy(x) permutedims(a * b, perm)
9797
end
9898
@testset "linear broadcast lowering" begin
9999
a = randn(ComplexF64, 2, 2)
@@ -117,5 +117,7 @@ using Test: @test, @test_broken, @test_throws, @testset
117117
@test (2 *ₗ a)[1, 2] == 2 * a[1, 2]
118118
@test conjed(a)[2, 1] == conj(a[2, 1])
119119
@test (a +ₗ b)[2, 2] == a[2, 2] + b[2, 2]
120+
@test (a *ₗ b)[1, 2] (a * b)[1, 2]
121+
@test (a *ₗ b)[3] (a * b)[3]
120122
end
121123
end

0 commit comments

Comments
 (0)