Skip to content

Commit bc72d6f

Browse files
authored
Slow but working copyto (#52)
* Slow but working copyto
* Formatter
* Another copyto
* Remove unneeded copyto
* Test copyto with Broadcasted
* Better test
* GPU tests too
1 parent 7612f6d commit bc72d6f

4 files changed

Lines changed: 28 additions & 2 deletions

File tree

ext/StridedGPUArraysExt.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module StridedGPUArraysExt
22

3-
using Strided, GPUArrays
3+
using Strided, GPUArrays, LinearAlgebra
44
using GPUArrays: Adapt, KernelAbstractions
55
using GPUArrays.KernelAbstractions: @kernel, @index
66

@@ -20,6 +20,14 @@ function Base.copy!(dst::AbstractArray{TD, ND}, src::StridedView{TS, NS, TAS, FS
2020
return dst
2121
end
2222

23+
# Write the result of a Strided-style broadcast `bc` into `dest`, a `StridedView`
# whose parent array type is an `AnyGPUArray{T}`, and return `dest`.
#
# Dispatch is restricted to numeric element types (`T <: Number`) and to
# broadcasts whose style is `Strided.StridedArrayStyle{N}` with the same `N`
# as `dest`.
#
# NOTE(review): the commit title calls this "slow but working"; the actual
# performance characteristics are not visible from here.
function Base.copyto!(dest::StridedView{T, N, <:AnyGPUArray{T}}, bc::Base.Broadcast.Broadcasted{Strided.StridedArrayStyle{N}}) where {T <: Number, N}
24+
dims = size(dest)
25+
any(isequal(0), dims) && return dest  # a zero-length dimension: return `dest` without calling `_copyto!`
26+
27+
GPUArrays._copyto!(dest, bc)  # NOTE(review): underscore-prefixed GPUArrays helper, presumably internal; confirm it is stable across GPUArrays versions
28+
return dest
29+
end
30+
2331
# lifted from GPUArrays.jl
2432
function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractGPUArray{T}, F <: ALL_FS}
2533
isempty(A) && return A
@@ -34,7 +42,7 @@ function Base.fill!(A::StridedView{T, N, TA, F}, x) where {T, N, TA <: AbstractG
3442
return A
3543
end
3644

37-
function Strided.__mul!(
45+
function LinearAlgebra.mul!(
3846
C::StridedView{TC, 2, <:AnyGPUArray{TC}},
3947
A::StridedView{TA, 2, <:AnyGPUArray{TA}},
4048
B::StridedView{TB, 2, <:AnyGPUArray{TB}},

test/amd.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
1616
axes(f1(A1)) == axes(f2(A2)) || continue
1717
@test collect(ROCMatrix(copy!(f2(A2), f1(A1)))) == AMDGPU.Adapt.adapt(Vector{T}, copy!(B2, B1))
1818
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
19+
A3 = ROCArray(randn(T, (m1, m2)))
20+
A3c = copy(A3)
21+
B3 = f1(StridedView(A3c))
22+
@. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted
23+
@. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted
24+
@test AMDGPU.Adapt.adapt(Vector{T}, f1(A1)) == AMDGPU.Adapt.adapt(Vector{T}, B1)
1925
x = rand(T)
2026
@test f1(StridedView(AMDGPU.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == AMDGPU.Adapt.adapt(Vector{T}, fill!(B1, x))
2127
end

test/cuda.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ for T in (Float32, Float64, Complex{Float32}, Complex{Float64})
1212
axes(f1(A1)) == axes(f2(A2)) || continue
1313
@test collect(CuMatrix(copy!(f2(A2), f1(A1)))) == CUDA.Adapt.adapt(Vector{T}, copy!(B2, B1))
1414
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
15+
A3 = CuArray(randn(T, (m1, m2)))
16+
A3c = copy(A3)
17+
B3 = f1(StridedView(A3c))
18+
@. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted
19+
@. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted
20+
@test CUDA.Adapt.adapt(Vector{T}, f1(A1)) == CUDA.Adapt.adapt(Vector{T}, B1)
1521
x = rand(T)
1622
@test f1(StridedView(CUDA.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == CUDA.Adapt.adapt(Vector{T}, fill!(B1, x))
1723
end

test/jlarrays.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
axes(f1(A1)) == axes(f2(A2)) || continue
1313
@test collect(Matrix(copy!(f2(A2), f1(A1)))) == JLArrays.Adapt.adapt(Vector{T}, copy!(B2, B1))
1414
@test copy!(zA1, f1(A1)) == copy!(zA2, B1)
15+
A3 = JLArray(randn(T, (m1, m2)))
16+
A3c = copy(A3)
17+
B3 = f1(StridedView(A3c))
18+
@. B1 = 2 * B1 - B3 / 3 # test copyto! of Broadcasted
19+
@. A1 = 2 * A1 - A3 / 3 # test copyto! of Broadcasted
20+
@test JLArrays.Adapt.adapt(Vector{T}, f1(A1)) == JLArrays.Adapt.adapt(Vector{T}, B1)
1521
x = rand(T)
1622
@test f1(StridedView(JLArrays.Adapt.adapt(Vector{T}, fill!(A1c, x)))) == JLArrays.Adapt.adapt(Vector{T}, fill!(B1, x))
1723
end

0 commit comments

Comments
 (0)