From c1e04f21af8c76038ed1203e1e4467eaa0256cf7 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 17 Feb 2025 16:02:48 +0100 Subject: [PATCH] Support disabling automatic sync on task switch --- src/array.jl | 17 +++++++++++++++++ src/memory.jl | 15 ++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/array.jl b/src/array.jl index 83b81ca862..41e7fd37ee 100644 --- a/src/array.jl +++ b/src/array.jl @@ -453,6 +453,23 @@ function Base.unsafe_convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::DenseCuArr a.maxsize - a.offset*Base.elsize(a)) end +## synchronization behavior + +""" + unsafe_disable_task_sync!(arr::CuArray) + +By default `CuArray`s are implicitly synchronized when they are used on different CUDA streams. +A `CuArray` that is used on multiple Julia tasks will be used by different streams +and thus will cause a synchronization between multiple Julia tasks. + +This `unsafe_disable_task_sync` disables synchronization when stream are being switched. +""" +function unsafe_disable_task_sync!(arr::CuArray) + return arr.data[].task_sync = false +end +function unsafe_enable_task_sync!(arr::CuArray) + return arr.data[].task_sync = true +end ## memory copying diff --git a/src/memory.jl b/src/memory.jl index 5b8d866cc8..fe7808c1f4 100644 --- a/src/memory.jl +++ b/src/memory.jl @@ -509,10 +509,13 @@ mutable struct Managed{M} # whether the memory has been captured in a way that would make the dirty bit unreliable captured::Bool - function Managed(mem::AbstractMemory; stream=CUDA.stream(), dirty=true, captured=false) + # whether memory accessed from another task causes implicit syncrhonization + task_sync::Bool + + function Managed(mem::AbstractMemory; stream = CUDA.stream(), dirty = true, captured = false, task_sync = true) # NOTE: memory starts as dirty, because stream-ordered allocations are only # guaranteed to be physically allocated at a synchronization event. - new{typeof(mem)}(mem, stream, dirty, captured) + return new{typeof(mem)}(mem, stream, dirty, captured, task_sync) end end @@ -566,7 +569,13 @@ function Base.convert(::Type{CuPtr{T}}, managed::Managed{M}) where {T,M} # accessing memory on another stream: ensure the data is ready and take ownership if managed.stream != state.stream - maybe_synchronize(managed) + # Synchronize the array if task_sync is enabled + managed.task_sync && maybe_synchronize(managed) + # We must still switch the stream, since we are about to set ".dirty=true" + # and otherwise subsequent operations will synchronize against the wrong stream. + # XXX: What to do when an array is used on multiple tasks concurrently? + # We could have the situation that we will synchronize against the "wrong" stream + # But that would mean that we need a mapping from task to stream per Managed object... managed.stream = state.stream end