11module StridedGPUArraysExt
22
3- using Strided, GPUArrays
3+ using Strided, GPUArrays, LinearAlgebra
44using GPUArrays: Adapt, KernelAbstractions
55using GPUArrays. KernelAbstractions: @kernel , @index
66
@@ -20,6 +20,14 @@ function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS
2020 return dst
2121end
2222
23+ function Base. copyto! (dest:: StridedView{T, N, <:AnyGPUArray{T}} , bc:: Base.Broadcast.Broadcasted{Strided.StridedArrayStyle{N}} ) where {T <: Number , N}
24+ dims = size (dest)
25+ any (isequal (0 ), dims) && return dest
26+
27+ GPUArrays. _copyto! (dest, bc)
28+ return dest
29+ end
30+
2331# lifted from GPUArrays.jl
2432function Base. fill! (A:: StridedView{T, N, TA, F} , x) where {T, N, TA <: AbstractGPUArray{T} , F <: ALL_FS }
2533 isempty (A) && return A
@@ -34,7 +42,7 @@ function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractG
3442 return A
3543end
3644
37- function Strided . __mul ! (
45+ function LinearAlgebra . mul ! (
3846 C:: StridedView{TC, 2, <:AnyGPUArray{TC}} ,
3947 A:: StridedView{TA, 2, <:AnyGPUArray{TA}} ,
4048 B:: StridedView{TB, 2, <:AnyGPUArray{TB}} ,
0 commit comments