Skip to content

Commit

Permalink
fix #41546, make using thread-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
JeffBezanson authored and vtjnash committed Oct 19, 2021
1 parent dc74954 commit 71c14e3
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 8 deletions.
20 changes: 13 additions & 7 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ function parsed_toml(project_file::AbstractString, toml_cache::TOMLCache, toml_l
end

## package identification: determine unique identity of package to be loaded ##
const require_lock = ReentrantLock()

# Used by Pkg but not used in loading itself
function find_package(arg)
Expand Down Expand Up @@ -873,7 +874,7 @@ function _require_search_from_serialized(pkg::PkgId, sourcepath::String, depth::
end

# to synchronize multiple tasks trying to import/using something
const package_locks = Dict{PkgId,Condition}()
const package_locks = Dict{PkgId,Threads.Condition}()

# to notify downstream consumers that a module was successfully loaded
# Callbacks take the form (mod::Base.PkgId) -> nothing.
Expand Down Expand Up @@ -968,6 +969,7 @@ For more details regarding code loading, see the manual sections on [modules](@r
[parallel computing](@ref code-availability).
"""
function require(into::Module, mod::Symbol)
@lock require_lock begin
LOADING_CACHE[] = LoadingCache()
try
uuidkey = identify_package(into, String(mod))
Expand Down Expand Up @@ -1019,6 +1021,7 @@ function require(into::Module, mod::Symbol)
finally
LOADING_CACHE[] = nothing
end
end
end

mutable struct PkgOrigin
Expand All @@ -1030,6 +1033,7 @@ PkgOrigin() = PkgOrigin(nothing, nothing)
const pkgorigins = Dict{PkgId,PkgOrigin}()

function require(uuidkey::PkgId)
@lock require_lock begin
if !root_module_exists(uuidkey)
cachefile = _require(uuidkey)
if cachefile !== nothing
Expand All @@ -1041,13 +1045,14 @@ function require(uuidkey::PkgId)
end
end
return root_module(uuidkey)
end
end

const loaded_modules = Dict{PkgId,Module}()
const module_keys = IdDict{Module,PkgId}() # the reverse

is_root_module(m::Module) = haskey(module_keys, m)
root_module_key(m::Module) = module_keys[m]
is_root_module(m::Module) = @lock require_lock haskey(module_keys, m)
root_module_key(m::Module) = @lock require_lock module_keys[m]

function register_root_module(m::Module)
key = PkgId(m, String(nameof(m)))
Expand All @@ -1074,12 +1079,13 @@ using Base
end

# get a top-level Module from the given key
root_module(key::PkgId) = loaded_modules[key]
root_module(key::PkgId) = @lock require_lock loaded_modules[key]
root_module(where::Module, name::Symbol) =
root_module(identify_package(where, String(name)))
maybe_root_module(key::PkgId) = @lock require_lock get(loaded_modules, key, nothing)

root_module_exists(key::PkgId) = haskey(loaded_modules, key)
loaded_modules_array() = collect(values(loaded_modules))
root_module_exists(key::PkgId) = @lock require_lock haskey(loaded_modules, key)
loaded_modules_array() = @lock require_lock collect(values(loaded_modules))

function unreference_module(key::PkgId)
if haskey(loaded_modules, key)
Expand All @@ -1098,7 +1104,7 @@ function _require(pkg::PkgId)
wait(loading)
return
end
package_locks[pkg] = Condition()
package_locks[pkg] = Threads.Condition(require_lock)

last = toplevel_load[]
try
Expand Down
2 changes: 1 addition & 1 deletion base/toml_parser.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function Parser(str::String; filepath=nothing)
IdSet{TOMLDict}(), # defined_tables
root,
filepath,
isdefined(Base, :loaded_modules) ? get(Base.loaded_modules, DATES_PKGID, nothing) : nothing,
isdefined(Base, :loaded_modules) ? Base.maybe_root_module(DATES_PKGID) : nothing,
)
startup(l)
return l
Expand Down
28 changes: 28 additions & 0 deletions test/threads_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -912,3 +912,31 @@ end
@test reproducible_rand(r, 10) == val
end
end

# issue #41546, thread-safe package loading
@testset "package loading" begin
ch = Channel{Bool}(nthreads())
barrier = Base.Event()
old_act_proj = Base.ACTIVE_PROJECT[]
try
pushfirst!(LOAD_PATH, "@")
Base.ACTIVE_PROJECT[] = joinpath(@__DIR__, "TestPkg")
@sync begin
for _ in 1:nthreads()
Threads.@spawn begin
put!(ch, true)
wait(barrier)
@eval using TestPkg
end
end
for _ in 1:nthreads()
take!(ch)
end
notify(barrier)
end
@test Base.root_module(@__MODULE__, :TestPkg) isa Module
finally
Base.ACTIVE_PROJECT[] = old_act_proj
popfirst!(LOAD_PATH)
end
end

0 comments on commit 71c14e3

Please sign in to comment.