Align module base between invalidation and edge tracking (#57625)
Our implicit edge tracking for bindings does not explicitly store any
edges for bindings in the *current* module. The idea behind this is that
it is a good time-space tradeoff for invalidation, because nearly all
binding references in a method are to the method's own defining module,
while the total number of methods within a module is bounded and
substantially smaller than the total number of methods in the entire
system.
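
As a purely illustrative sketch (module and binding names are invented
here, not taken from the patch), this is the asymmetry the tradeoff
relies on:

    module A
        const SCALE = 2            # binding defined in A
        scaled(x) = SCALE * x      # reference into the *current* module A;
                                   # no explicit backedge is stored for it
        shifted(x) = Base.abs(x)   # cross-module reference; an explicit
                                   # edge is tracked for this one
    end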

However, we have an issue where the code that stores these edges and the
invalidation code disagree on which module is the *current* one. The
edge-storing code was using the module in which the method is defined,
while the invalidation code was using the module in which the MethodTable
is defined. With these two misaligned, we can miss necessary
invalidations.

Both options are in principle possible, but I think the former is
better, because the module in which the method is defined is also the
module we are likely to have many references to (since its bindings get
referenced implicitly just by writing bare symbols in the code).
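
A minimal sketch of the mismatch (module and binding names are
illustrative, not from the patch): extending a function owned by another
module is exactly the case where the two notions of "current module"
diverge.

    module B
        foo(x) = x                  # `foo`'s MethodTable is rooted in B
    end

    module A
        using ..B
        const OFFSET = 1
        # Method *defined in* A, but attached to B's MethodTable. The implicit
        # edge for `OFFSET` was recorded under A (the defining module), while
        # invalidation scanned B's methods, so a redefinition of `OFFSET`
        # could fail to invalidate this method.
        B.foo(x::Int) = x + OFFSET
    end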

However, this presents a problem: we don't actually have a way to
iterate all the methods defined in a particular module without doing the
brute-force thing of scanning every method in the system and filtering.

To address this, build on the deferred scanning code added in #57615 to
also add any scanned methods to an explicit list in `Module`. This costs
some space, but only proportional to the number of defined methods (and
thus proportional to the written source code).
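
A minimal sketch of how that list is meant to be consumed, mirroring the
loop this commit adds to `base/invalidation.jl` (`mod` here stands for
whichever module's binding is being invalidated):

    nmethods = ccall(:jl_module_scanned_methods_length, Csize_t, (Any,), mod)
    for i = 1:nmethods
        method = ccall(:jl_module_scanned_methods_getindex, Any, (Any, Csize_t), mod, i)::Method
        # Every scanned method defined in `mod` is visited here, so invalidation
        # reaches exactly the methods that may hold implicit edges to `mod`'s bindings.
    end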

Note that we don't actually observe any issues in the test suite on
master due to this bug. However, this is because we are grossly
over-invalidating, which hides the missing invalidations from this issue
(#57617).

(cherry picked from commit 274d80e)
Keno authored and KristofferC committed Mar 4, 2025
1 parent 7cf1db3 commit 686c917
Showing 14 changed files with 264 additions and 64 deletions.
18 changes: 10 additions & 8 deletions Compiler/src/abstractinterpretation.jl
@@ -3594,18 +3594,20 @@ scan_partitions(query::Function, interp, g::GlobalRef, wwr::WorldWithRange) =
abstract_load_all_consistent_leaf_partitions(interp, g::GlobalRef, wwr::WorldWithRange) =
scan_leaf_partitions(abstract_eval_partition_load, interp, g, wwr)

function abstract_eval_globalref_partition(interp, binding::Core.Binding, partition::Core.BindingPartition)
# For inference purposes, we don't particularly care which global binding we end up loading, we only
# care about its type. However, we would still like to terminate the world range for the particular
# binding we end up reaching such that codegen can emit a simpler pointer load.
Pair{RTEffects, Union{Nothing, Core.Binding}}(
abstract_eval_partition_load(interp, partition),
binding_kind(partition) in (PARTITION_KIND_GLOBAL, PARTITION_KIND_DECLARED) ? binding : nothing)
end

function abstract_eval_globalref(interp, g::GlobalRef, saw_latestworld::Bool, sv::AbsIntState)
if saw_latestworld
return RTEffects(Any, Any, generic_getglobal_effects)
end
(valid_worlds, (ret, binding_if_global)) = scan_leaf_partitions(interp, g, sv.world) do interp, binding, partition
# For inference purposes, we don't particularly care which global binding we end up loading, we only
# care about its type. However, we would still like to terminate the world range for the particular
# binding we end up reaching such that codegen can emit a simpler pointer load.
Pair{RTEffects, Union{Nothing, Core.Binding}}(
abstract_eval_partition_load(interp, partition),
binding_kind(partition) in (PARTITION_KIND_GLOBAL, PARTITION_KIND_DECLARED) ? binding : nothing)
end
(valid_worlds, (ret, binding_if_global)) = scan_leaf_partitions(abstract_eval_globalref_partition, interp, g, sv.world)
update_valid_age!(sv, valid_worlds)
if ret.rt !== Union{} && ret.exct === UndefVarError && binding_if_global !== nothing && InferenceParams(interp).assume_bindings_static
if isdefined(binding_if_global, :value)
19 changes: 19 additions & 0 deletions Compiler/src/utilities.jl
@@ -129,6 +129,25 @@ function retrieve_code_info(mi::MethodInstance, world::UInt)
else
c = copy(src::CodeInfo)
end
if (def.did_scan_source & 0x1) == 0x0
# This scan must happen:
# 1. After method definition
# 2. Before any code instances that may have relied on information
# from implicit GlobalRefs for this method are added to the cache
# 3. Preferably while the IR is already uncompressed
# 4. As late as possible, as early adding of the backedges may cause
# spurious invalidations.
#
# At the moment we do so here, because
# 1. It's reasonably late
# 2. It has easy access to the uncompressed IR
# 3. We necessarily pass through here before relying on any
# information obtained from implicit GlobalRefs.
#
# However, the exact placement of this scan is not as important as
# long as the above conditions are met.
ccall(:jl_scan_method_source_now, Cvoid, (Any, Any), def, c)
end
end
if c isa CodeInfo
c.parent = mi
4 changes: 4 additions & 0 deletions base/expr.jl
@@ -1344,6 +1344,10 @@ function make_atomic(order, ex)
op = :+
elseif ex.head === :(-=)
op = :-
elseif ex.head === :(|=)
op = :|
elseif ex.head === :(&=)
op = :&
elseif @isdefined string
shead = string(ex.head)
if endswith(shead, '=')
59 changes: 38 additions & 21 deletions base/invalidation.jl
@@ -113,15 +113,34 @@ function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalid
end
end

function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core.BindingPartition, new_bpart::Union{Core.BindingPartition, Nothing}, new_max_world::UInt)
export_affecting_partition_flags(bpart::Core.BindingPartition) =
((bpart.kind & PARTITION_MASK_KIND) == PARTITION_KIND_GUARD,
(bpart.kind & PARTITION_FLAG_EXPORTED) != 0,
(bpart.kind & PARTITION_FLAG_DEPRECATED) != 0)

function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core.BindingPartition, new_bpart::Core.BindingPartition, new_max_world::UInt)
gr = b.globalref
if !is_some_guard(binding_kind(invalidated_bpart))
# TODO: We may want to invalidate for these anyway, since they have performance implications
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)

(_, (ib, ibpart)) = Compiler.walk_binding_partition(b, invalidated_bpart, new_max_world)
(_, (nb, nbpart)) = Compiler.walk_binding_partition(b, new_bpart, new_max_world+1)

# abstract_eval_globalref_partition is the maximum amount of information that inference
# reads from a binding partition. If this information does not change - we do not need to
# invalidate any code that inference created, because we know that the result will not change.
need_to_invalidate_code =
Compiler.abstract_eval_globalref_partition(nothing, ib, ibpart) !==
Compiler.abstract_eval_globalref_partition(nothing, nb, nbpart)

need_to_invalidate_export = export_affecting_partition_flags(invalidated_bpart) !==
export_affecting_partition_flags(new_bpart)

if need_to_invalidate_code
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
nmethods = ccall(:jl_module_scanned_methods_length, Csize_t, (Any,), gr.mod)
for i = 1:nmethods
method = ccall(:jl_module_scanned_methods_getindex, Any, (Any, Csize_t), gr.mod, i)::Method
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
return true
end
if isdefined(b, :backedges)
for edge in b.backedges
@@ -133,45 +152,43 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
latest_bpart.max_world == typemax(UInt) || continue
is_some_imported(binding_kind(latest_bpart)) || continue
partition_restriction(latest_bpart) === b || continue
invalidate_code_for_globalref!(edge, latest_bpart, nothing, new_max_world)
invalidate_code_for_globalref!(edge, latest_bpart, latest_bpart, new_max_world)
else
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
end
end
end
end
if (invalidated_bpart.kind & PARTITION_FLAG_EXPORTED != 0) || (new_bpart !== nothing && (new_bpart.kind & PARTITION_FLAG_EXPORTED != 0))

if need_to_invalidate_code || need_to_invalidate_export
# This binding was exported - we need to check all modules that `using` us to see if they
# have a binding that is affected by this change.
usings_backedges = ccall(:jl_get_module_usings_backedges, Any, (Any,), gr.mod)
if usings_backedges !== nothing
for user in usings_backedges::Vector{Any}
for user::Module in usings_backedges::Vector{Any}
user_binding = ccall(:jl_get_module_binding_or_nothing, Any, (Any, Any), user, gr.name)
user_binding === nothing && continue
isdefined(user_binding, :partitions) || continue
latest_bpart = user_binding.partitions
latest_bpart.max_world == typemax(UInt) || continue
binding_kind(latest_bpart) in (PARTITION_KIND_IMPLICIT, PARTITION_KIND_FAILED, PARTITION_KIND_GUARD) || continue
@atomic :release latest_bpart.max_world = new_max_world
invalidate_code_for_globalref!(convert(Core.Binding, user_binding), latest_bpart, nothing, new_max_world)
new_bpart = need_to_invalidate_export ?
ccall(:jl_maybe_reresolve_implicit, Any, (Any, Any, Csize_t), user_binding, latest_bpart, new_max_world) :
latest_bpart
if need_to_invalidate_code || new_bpart !== latest_bpart
invalidate_code_for_globalref!(convert(Core.Binding, user_binding), latest_bpart, new_bpart, new_max_world)
end
end
end
end
end
invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_bpart::Core.BindingPartition, new_max_world::UInt) =
invalidate_code_for_globalref!(convert(Core.Binding, gr), invalidated_bpart, new_bpart, new_max_world)

gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod

# N.B.: This needs to match jl_maybe_add_binding_backedge
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
method = isa(edge, Method) ? edge : edge.def.def::Method
gr_needs_backedge_in_module(b.globalref, method.module) || return
if !isdefined(b, :backedges)
b.backedges = Any[]
end
!isempty(b.backedges) && b.backedges[end] === edge && return
push!(b.backedges, edge)
meth = isa(edge, Method) ? edge : Compiler.get_ci_mi(edge).def
ccall(:jl_maybe_add_binding_backedge, Cint, (Any, Any, Any), b, edge, meth)
return nothing
end

function binding_was_invalidated(b::Core.Binding)
2 changes: 2 additions & 0 deletions base/runtime_internals.jl
@@ -216,6 +216,8 @@ const PARTITION_FLAG_DEPWARN = 0x40
const PARTITION_MASK_KIND = 0x0f
const PARTITION_MASK_FLAG = 0xf0

const BINDING_FLAG_ANY_IMPLICIT_EDGES = 0x8

is_defined_const_binding(kind::UInt8) = (kind == PARTITION_KIND_CONST || kind == PARTITION_KIND_CONST_IMPORT || kind == PARTITION_KIND_BACKDATED_CONST)
is_some_const_binding(kind::UInt8) = (is_defined_const_binding(kind) || kind == PARTITION_KIND_UNDEF_CONST)
is_some_imported(kind::UInt8) = (kind == PARTITION_KIND_IMPLICIT || kind == PARTITION_KIND_EXPLICIT || kind == PARTITION_KIND_IMPORTED)
3 changes: 3 additions & 0 deletions src/gc-stock.c
@@ -2147,6 +2147,9 @@ STATIC_INLINE void gc_mark_module_binding(jl_ptls_t ptls, jl_module_t *mb_parent
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->usings_backedges);
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->usings_backedges, &nptr);
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->usings_backedges);
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->scanned_methods);
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->scanned_methods, &nptr);
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->scanned_methods);
size_t nusings = module_usings_length(mb_parent);
if (nusings > 0) {
// this is only necessary because bindings for "using" modules
49 changes: 37 additions & 12 deletions src/gf.c
@@ -1839,7 +1839,7 @@ JL_DLLEXPORT jl_value_t *jl_debug_method_invalidation(int state)
return jl_nothing;
}

static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth);
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, jl_code_instance_t *replaced_ci, size_t max_world, int depth);

// recursively invalidate cached methods that had an edge to a replaced method
static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world, int depth)
Expand All @@ -1858,13 +1858,15 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo
if (!jl_is_method(replaced_mi->def.method))
return; // shouldn't happen, but better to be safe
JL_LOCK(&replaced_mi->def.method->writelock);
if (jl_atomic_load_relaxed(&replaced->max_world) == ~(size_t)0) {
size_t replacedmaxworld = jl_atomic_load_relaxed(&replaced->max_world);
if (replacedmaxworld == ~(size_t)0) {
assert(jl_atomic_load_relaxed(&replaced->min_world) - 1 <= max_world && "attempting to set illogical world constraints (probable race condition)");
jl_atomic_store_release(&replaced->max_world, max_world);
// recurse to all backedges to update their valid range also
_invalidate_backedges(replaced_mi, replaced, max_world, depth + 1);
} else {
assert(jl_atomic_load_relaxed(&replaced->max_world) <= max_world);
}
assert(jl_atomic_load_relaxed(&replaced->max_world) <= max_world);
// recurse to all backedges to update their valid range also
_invalidate_backedges(replaced_mi, max_world, depth + 1);
JL_UNLOCK(&replaced_mi->def.method->writelock);
}

Expand All @@ -1873,19 +1875,42 @@ JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size
invalidate_code_instance(replaced, max_world, 1);
}

static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) {
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, jl_code_instance_t *replaced_ci, size_t max_world, int depth) {
jl_array_t *backedges = replaced_mi->backedges;
if (backedges) {
// invalidate callers (if any)
replaced_mi->backedges = NULL;
JL_GC_PUSH1(&backedges);
size_t i = 0, l = jl_array_nrows(backedges);
size_t ins = 0;
jl_code_instance_t *replaced;
while (i < l) {
i = get_next_edge(backedges, i, NULL, &replaced);
jl_value_t *invokesig = NULL;
i = get_next_edge(backedges, i, &invokesig, &replaced);
JL_GC_PROMISE_ROOTED(replaced); // propagated by get_next_edge from backedges
if (replaced_ci) {
// If we're invalidating a particular codeinstance, only invalidate
// this backedge if it actually has an edge for our codeinstance.
jl_svec_t *edges = jl_atomic_load_relaxed(&replaced->edges);
for (size_t j = 0; j < jl_svec_len(edges); ++j) {
jl_value_t *edge = jl_svecref(edges, j);
if (edge == (jl_value_t*)replaced_mi || edge == (jl_value_t*)replaced_ci)
goto found;
}
// Keep this entry in the backedge list, but compact it
ins = set_next_edge(backedges, ins, invokesig, replaced);
continue;
found:;
}
invalidate_code_instance(replaced, max_world, depth);
}
if (replaced_ci && ins != 0) {
jl_array_del_end(backedges, l - ins);
// If we're only invalidating one ci, we don't know which ci any particular
// backedge was for, so we can't delete them. Put them back.
replaced_mi->backedges = backedges;
jl_gc_wb(replaced_mi, backedges);
}
JL_GC_POP();
}
}
@@ -1894,7 +1919,7 @@ static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_
static void invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, const char *why)
{
JL_LOCK(&replaced_mi->def.method->writelock);
_invalidate_backedges(replaced_mi, max_world, 1);
_invalidate_backedges(replaced_mi, NULL, max_world, 1);
JL_UNLOCK(&replaced_mi->def.method->writelock);
if (why && _jl_debug_method_invalidation) {
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)replaced_mi);
@@ -1928,8 +1953,8 @@ JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee,
size_t i = 0, l = jl_array_nrows(callee->backedges);
for (i = 0; i < l; i++) {
// optimized version of while (i < l) i = get_next_edge(callee->backedges, i, &invokeTypes, &mi);
jl_value_t *mi = jl_array_ptr_ref(callee->backedges, i);
if (mi != (jl_value_t*)caller)
jl_value_t *ciedge = jl_array_ptr_ref(callee->backedges, i);
if (ciedge != (jl_value_t*)caller)
continue;
jl_value_t *invokeTypes = i > 0 ? jl_array_ptr_ref(callee->backedges, i - 1) : NULL;
if (invokeTypes && jl_is_method_instance(invokeTypes))
@@ -2372,7 +2397,7 @@ void jl_method_table_activate(jl_methtable_t *mt, jl_typemap_entry_t *newentry)
continue;
loctag = jl_atomic_load_relaxed(&m->specializations); // use loctag for a gcroot
_Atomic(jl_method_instance_t*) *data;
size_t i, l;
size_t l;
if (jl_is_svec(loctag)) {
data = (_Atomic(jl_method_instance_t*)*)jl_svec_data(loctag);
l = jl_svec_len(loctag);
Expand All @@ -2382,7 +2407,7 @@ void jl_method_table_activate(jl_methtable_t *mt, jl_typemap_entry_t *newentry)
l = 1;
}
enum morespec_options ambig = morespec_unknown;
for (i = 0; i < l; i++) {
for (size_t i = 0; i < l; i++) {
jl_method_instance_t *mi = jl_atomic_load_relaxed(&data[i]);
if ((jl_value_t*)mi == jl_nothing)
continue;
8 changes: 5 additions & 3 deletions src/jltypes.c
@@ -3275,7 +3275,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_svec(5, jl_any_type/*jl_globalref_type*/, jl_any_type, jl_binding_partition_type,
jl_any_type, jl_uint8_type),
jl_emptysvec, 0, 1, 0);
const static uint32_t binding_atomicfields[] = { 0x0005 }; // Set fields 2, 3 as atomic
const static uint32_t binding_atomicfields[] = { 0x0016 }; // Set fields 2, 3, 5 as atomic
jl_binding_type->name->atomicfields = binding_atomicfields;
const static uint32_t binding_constfields[] = { 0x0001 }; // Set fields 1 as constant
jl_binding_type->name->constfields = binding_constfields;
@@ -3539,7 +3539,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(31,
jl_perm_symsvec(32,
"name",
"module",
"file",
@@ -3568,10 +3568,11 @@ void jl_init_types(void) JL_GC_DISABLED
"isva",
"is_for_opaque_closure",
"nospecializeinfer",
"did_scan_source",
"constprop",
"max_varargs",
"purity"),
jl_svec(31,
jl_svec(32,
jl_symbol_type,
jl_module_type,
jl_symbol_type,
@@ -3602,6 +3603,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_uint8_type,
jl_uint8_type,
jl_uint8_type,
jl_uint16_type),
jl_emptysvec,
0, 1, 10);