diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index ea77889ff..4e232d76d 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -104,7 +104,7 @@ end function _sync_with_context(x::Union{Dagger.Processor,Dagger.MemorySpace}) with_context(x) do - CUDA.synchronize() + CUDA.synchronize(stream()) end end function sync_with_context(x::Union{Dagger.Processor,Dagger.MemorySpace}) @@ -391,7 +391,7 @@ Dagger.gpu_with_device(f, proc::CuArrayDeviceProc) = CUDA.device!(f, proc.device) function Dagger.gpu_synchronize(proc::CuArrayDeviceProc) with_context(proc) do - CUDA.synchronize() + CUDA.synchronize(stream()) end end function Dagger.gpu_synchronize(::Val{:CUDA}) diff --git a/ext/ROCExt.jl b/ext/ROCExt.jl index 3ab6d0731..c2058b829 100644 --- a/ext/ROCExt.jl +++ b/ext/ROCExt.jl @@ -98,7 +98,7 @@ end function _sync_with_context(x::Union{Dagger.Processor,Dagger.MemorySpace}) with_context(x) do - AMDGPU.synchronize() + AMDGPU.synchronize(stream()) end end function sync_with_context(x::Union{Dagger.Processor,Dagger.MemorySpace}) @@ -364,7 +364,7 @@ Dagger.gpu_with_device(f, proc::ROCArrayDeviceProc) = AMDGPU.device!(f, AMDGPU.devices()[proc.device_id]) function Dagger.gpu_synchronize(proc::ROCArrayDeviceProc) with_context(proc) do - AMDGPU.synchronize() + AMDGPU.synchronize(stream()) end end function Dagger.gpu_synchronize(::Val{:ROC})