diff --git a/src/pmap.jl b/src/pmap.jl index 603dfa7e031ceb..f884d47fff98eb 100644 --- a/src/pmap.jl +++ b/src/pmap.jl @@ -6,7 +6,7 @@ struct BatchProcessingError <: Exception end """ - pgenerate([::WorkerPool], f, c...) -> iterator + pgenerate([::AbstractWorkerPool], f, c...) -> iterator Apply `f` to each element of `c` in parallel using available workers and tasks. @@ -18,14 +18,14 @@ Note that `f` must be made available to all worker processes; see [Code Availability and Loading Packages](@ref code-availability) for details. """ -function pgenerate(p::WorkerPool, f, c) +function pgenerate(p::AbstractWorkerPool, f, c) if length(p) == 0 return AsyncGenerator(f, c; ntasks=()->nworkers(p)) end batches = batchsplit(c, min_batch_count = length(p) * 3) return Iterators.flatten(AsyncGenerator(remote(p, b -> asyncmap(f, b)), batches)) end -pgenerate(p::WorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...)) +pgenerate(p::AbstractWorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...)) pgenerate(f, c) = pgenerate(default_worker_pool(), f, c) pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...)) diff --git a/src/workerpool.jl b/src/workerpool.jl index 89e52667c82c91..5dd1c07044e098 100644 --- a/src/workerpool.jl +++ b/src/workerpool.jl @@ -239,12 +239,14 @@ perform a `remote_do` on it. """ remote_do(f, pool::AbstractWorkerPool, args...; kwargs...) = remotecall_pool(remote_do, f, pool, args...; kwargs...) -const _default_worker_pool = Ref{Union{WorkerPool, Nothing}}(nothing) +const _default_worker_pool = Ref{Union{AbstractWorkerPool, Nothing}}(nothing) """ default_worker_pool() -[`WorkerPool`](@ref) containing idle [`workers`](@ref) - used by `remote(f)` and [`pmap`](@ref) (by default). +[`AbstractWorkerPool`](@ref) containing idle [`workers`](@ref) - used by `remote(f)` and [`pmap`](@ref) +(by default). Unless one is explicitly set via `default_worker_pool!(pool)`, the default worker pool is +initialized to a [`WorkerPool`](@ref). # Examples ```julia-repl @@ -267,6 +269,15 @@ function default_worker_pool() return _default_worker_pool[] end +""" + default_worker_pool!(pool::AbstractWorkerPool) + +Set a [`AbstractWorkerPool`](@ref) to be used by `remote(f)` and [`pmap`](@ref) (by default). +""" +function default_worker_pool!(pool::AbstractWorkerPool) + _default_worker_pool[] = pool +end + """ remote([p::AbstractWorkerPool], f) -> Function diff --git a/test/distributed_exec.jl b/test/distributed_exec.jl index 548ac73d2fb4c7..16d1e4b100bf3d 100644 --- a/test/distributed_exec.jl +++ b/test/distributed_exec.jl @@ -701,6 +701,19 @@ wp = CachingPool(workers()) clear!(wp) @test length(wp.map_obj2ref) == 0 +# default_worker_pool! tests +wp_default = Distributed.default_worker_pool() +try + wp = CachingPool(workers()) + Distributed.default_worker_pool!(wp) + @test [1:100...] == pmap(x->x, wp, 1:100) + @test !isempty(wp.map_obj2ref) + clear!(wp) + @test isempty(wp.map_obj2ref) +finally + Distributed.default_worker_pool!(wp_default) +end + # The below block of tests are usually run only on local development systems, since: # - tests which print errors # - addprocs tests are memory intensive