@@ -786,6 +786,20 @@ similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a))
786786# TODO : Make use of both arguments to determine the output, maybe
787787# using `LinearAlgebra.matprod_dest(factors(a)..., elt)`?
788788similar_mul (a:: AbstractArray , elt:: Type , ax) = similar (last (factors (a)), elt, ax)
789+ function mul_getindex (a1:: AbstractMatrix , a2:: AbstractMatrix , i:: Int , j:: Int )
790+ return transpose (view (a1, i, :)) * view (a2, :, j)
791+ end
792+ function mul_getindex (a1:: AbstractMatrix , a2:: AbstractVector , i:: Int )
793+ return transpose (view (a1, i, :)) * a2
794+ end
795+ function mul_getindex (a1:: AbstractVector , a2:: AbstractMatrix , j:: Int )
796+ return transpose (a1) * view (a2, :, j)
797+ end
798+ function getindex_mul (a:: AbstractArray , i:: Int )
799+ I = Tuple (CartesianIndices (axes (a))[i])
800+ return getindex_mul (a, I... )
801+ end
802+ getindex_mul (a:: AbstractArray , I:: Vararg{Int} ) = mul_getindex (factors (a)... , I... )
789803copyto!_mul (dest:: AbstractArray , src:: AbstractArray ) = add! (dest, src, true , false )
790804show_mul (io:: IO , a:: AbstractArray ) = show_lazy (io, a)
791805show_mul (io:: IO , mime:: MIME"text/plain" , a:: AbstractArray ) = show_lazy (io, mime, a)
@@ -843,6 +857,12 @@ macro mularray_type(MulArray, AbstractArray = :AbstractArray)
843857 )
844858end
845859
860+ function copy_permuteddims_mul (a:: PermutedDimsArray{<:Any, 2, perm} ) where {perm}
861+ perm == (1 , 2 ) && return copy (parent (a))
862+ perm == (2 , 1 ) && return copy (transpose (parent (a)))
863+ throw (ArgumentError (" Unsupported permutation $perm " ))
864+ end
865+
846866macro mularray_base (MulArray, AbstractArray = :AbstractArray )
847867 return esc (
848868 quote
@@ -864,6 +884,9 @@ macro mularray_base(MulArray, AbstractArray = :AbstractArray)
864884 function Base. similar (a:: $MulArray , elt:: Type , ax:: Dims )
865885 return $ TensorAlgebra. similar_mul (a, elt, ax)
866886 end
887+ Base. @propagate_inbounds function Base. getindex (a:: $MulArray , I... )
888+ return $ TensorAlgebra. getindex_mul (a, I... )
889+ end
867890 function Base. copyto! (dest:: $AbstractArray , src:: $MulArray )
868891 return $ TensorAlgebra. copyto!_mul (dest, src)
869892 end
@@ -926,6 +949,9 @@ macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray)
926949 $ TensorAlgebra. iscall (a:: $MulArray ) = $ TensorAlgebra. iscall_mul (a)
927950 $ TensorAlgebra. operation (a:: $MulArray ) = $ TensorAlgebra. operation_mul (a)
928951 $ TensorAlgebra. arguments (a:: $MulArray ) = $ TensorAlgebra. arguments_mul (a)
952+ function Base. copy (a:: PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray} )
953+ return $ TensorAlgebra. copy_permuteddims_mul (a)
954+ end
929955 end
930956 )
931957end
0 commit comments