Skip to content

Commit

Permalink
Align module base between invalidation and edge tracking
Browse files Browse the repository at this point in the history
Our implicit edge tracking for bindings does not explicitly store
any edges for bindings in the *current* module. The idea behind
this is that this is a good time-space tradeoff for validation,
because substantially all binding references in a module will be
to its defining module, while the total number of methods within
a module is limited and substantially smaller than the total number
of methods in the entire system.

However, we have an issue where the code that stores these edges
and the invalidation code disagree on which module is the *current*
one. The edge storing code was using the module in which the method
was defined, while the invalidation code was using the one in which
the MethodTable is defined. With these being misaligned, we can
miss necessary invalidations.

Both options are in principle possible, but I think the former is
better, because the module in which the method is defined is also
the module that we are likely to have a lot of references to (since
they get referenced implicitly by just writing symbols in the code).

However, this presents a problem: We don't actually have a way to
iterate all the methods defined in a particular module, without just
doing the brute force thing of scanning all methods and filtering.

To address this, build on the deferred scanning code added in #57615
to also add any scanned modules to an explicit list in `Module`. This
costs some space, but only proportional to the number of defined methods,
(and thus proportional to the written source code).

Note that we don't actually observe any issues in the test suite on
master due to this bug. However, this is because we are grossly
over-invalidating, which hides the missing invalidations from this
issue (#57617).
  • Loading branch information
Keno committed Mar 3, 2025
1 parent e7efe42 commit a8a1c3b
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 35 deletions.
2 changes: 1 addition & 1 deletion Compiler/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function retrieve_code_info(mi::MethodInstance, world::UInt)
else
c = copy(src::CodeInfo)
end
if !def.did_scan_source
if (def.did_scan_source & 0x1) == 0x0
# This scan must happen:
# 1. After method definition
# 2. Before any code instances that may have relied on information
Expand Down
28 changes: 8 additions & 20 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,10 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core

if need_to_invalidate_code
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
return true
nmethods = ccall(:jl_module_scanned_methods_length, Csize_t, (Any,), gr.mod)
for i = 1:nmethods
method = ccall(:jl_module_scanned_methods_getindex, Any, (Any, Csize_t), gr.mod, i)::Method
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
end
if isdefined(b, :backedges)
Expand All @@ -166,7 +165,7 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
# have a binding that is affected by this change.
usings_backedges = ccall(:jl_get_module_usings_backedges, Any, (Any,), gr.mod)
if usings_backedges !== nothing
for user in usings_backedges::Vector{Any}
for user::Module in usings_backedges::Vector{Any}
user_binding = ccall(:jl_get_module_binding_or_nothing, Any, (Any, Any), user, gr.name)
user_binding === nothing && continue
isdefined(user_binding, :partitions) || continue
Expand All @@ -186,21 +185,10 @@ end
invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_bpart::Core.BindingPartition, new_max_world::UInt) =
invalidate_code_for_globalref!(convert(Core.Binding, gr), invalidated_bpart, new_bpart, new_max_world)

gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod

# N.B.: This needs to match jl_maybe_add_binding_backedge
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
method = isa(edge, Method) ? edge : edge.def.def::Method
methmod = method.module
if !gr_needs_backedge_in_module(b.globalref, methmod)
@atomic :acquire_release b.flags |= BINDING_FLAG_ANY_IMPLICIT_EDGES
return
end
if !isdefined(b, :backedges)
b.backedges = Any[]
end
!isempty(b.backedges) && b.backedges[end] === edge && return
push!(b.backedges, edge)
meth = isa(edge, Method) ? edge : Compiler.get_ci_mi(edge).def
ccall(:jl_maybe_add_binding_backedge, Cint, (Any, Any, Any), b, edge, meth)
return nothing
end

function binding_was_invalidated(b::Core.Binding)
Expand Down
3 changes: 3 additions & 0 deletions src/gc-stock.c
Original file line number Diff line number Diff line change
Expand Up @@ -2147,6 +2147,9 @@ STATIC_INLINE void gc_mark_module_binding(jl_ptls_t ptls, jl_module_t *mb_parent
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->usings_backedges);
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->usings_backedges, &nptr);
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->usings_backedges);
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->scanned_methods);
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->scanned_methods, &nptr);
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->scanned_methods);
size_t nusings = module_usings_length(mb_parent);
if (nusings > 0) {
// this is only necessary because bindings for "using" modules
Expand Down
2 changes: 1 addition & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3601,7 +3601,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type,
jl_uint8_type,
jl_uint8_type,
jl_uint16_type),
Expand Down
4 changes: 4 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ typedef struct _jl_method_t {
uint8_t isva;
uint8_t is_for_opaque_closure;
uint8_t nospecializeinfer;
// bit flags, 0x01 = scanned
// 0x02 = added to module scanned list (either from scanning or inference edge)
_Atomic(uint8_t) did_scan_source;

// uint8 settings
Expand Down Expand Up @@ -782,6 +784,7 @@ typedef struct _jl_module_t {
jl_sym_t *file;
int32_t line;
jl_value_t *usings_backedges;
jl_value_t *scanned_methods;
// hidden fields:
arraylist_t usings; /* arraylist of struct jl_module_using */ // modules with all bindings potentially imported
jl_uuid_t build_id;
Expand Down Expand Up @@ -2059,6 +2062,7 @@ JL_DLLEXPORT int jl_get_module_infer(jl_module_t *m);
JL_DLLEXPORT void jl_set_module_max_methods(jl_module_t *self, int value);
JL_DLLEXPORT int jl_get_module_max_methods(jl_module_t *m);
JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m);
JL_DLLEXPORT jl_value_t *jl_get_module_scanned_methods(jl_module_t *m);
JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s);

// get binding for reading
Expand Down
3 changes: 2 additions & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ast);
JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void);
JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
int binding_effects);
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_binding_t *b, jl_value_t *edge, jl_method_t *in_method);
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);

int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;
Expand Down Expand Up @@ -878,6 +878,7 @@ STATIC_INLINE size_t module_usings_max(jl_module_t *m) JL_NOTSAFEPOINT {
}

JL_DLLEXPORT jl_sym_t *jl_module_name(jl_module_t *m) JL_NOTSAFEPOINT;
void jl_add_scanned_method(jl_module_t *m, jl_method_t *meth);
jl_value_t *jl_eval_global_var(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *e);
jl_value_t *jl_interpret_opaque_closure(jl_opaque_closure_t *clos, jl_value_t **args, size_t nargs);
jl_value_t *jl_interpret_toplevel_thunk(jl_module_t *m, jl_code_info_t *src);
Expand Down
23 changes: 20 additions & 3 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,20 @@ static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
}
}

void jl_add_scanned_method(jl_module_t *m, jl_method_t *meth)
{
JL_LOCK(&m->lock);
if (m->scanned_methods == jl_nothing) {
m->scanned_methods = (jl_value_t*)jl_alloc_vec_any(0);
jl_gc_wb(m, m->scanned_methods);
}
jl_array_ptr_1d_push((jl_array_t*)m->scanned_methods, (jl_value_t*)meth);
JL_UNLOCK(&m->lock);
}

JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
{
if (!jl_atomic_load_relaxed(&m->did_scan_source)) {
if (!jl_atomic_fetch_or(&m->did_scan_source, 1)) {
jl_code_info_t *code = NULL;
JL_GC_PUSH1(&code);
if (!jl_is_code_info(src))
Expand All @@ -50,13 +61,19 @@ JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
code = (jl_code_info_t*)src;
jl_array_t *stmts = code->code;
size_t i, l = jl_array_nrows(stmts);
int any_implicit = 0;
for (i = 0; i < l; i++) {
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
if (jl_is_globalref(stmt)) {
jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
jl_globalref_t *gr = (jl_globalref_t*)stmt;
jl_binding_t *b = gr->binding;
if (!b)
b = jl_get_module_binding(gr->mod, gr->name, 1);
any_implicit |= jl_maybe_add_binding_backedge(b, (jl_value_t*)m, m);
}
}
jl_atomic_store_relaxed(&m->did_scan_source, 1);
if (any_implicit && !(jl_atomic_fetch_or(&m->did_scan_source, 0x2) & 0x2))
jl_add_scanned_method(m->module, m);
JL_GC_POP();
}
}
Expand Down
40 changes: 31 additions & 9 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ STATIC_INLINE jl_binding_partition_t *jl_get_binding_partition_(jl_binding_t *b
if (!new_bpart)
new_bpart = new_binding_partition();
jl_atomic_store_relaxed(&new_bpart->next, bpart);
jl_gc_wb(new_bpart, bpart); // Not fresh the second time around the loop
if (bpart)
jl_gc_wb(new_bpart, bpart); // Not fresh the second time around the loop
new_bpart->min_world = bpart ? jl_atomic_load_relaxed(&bpart->max_world) + 1 : 0;
jl_atomic_store_relaxed(&new_bpart->max_world, max_world);
JL_GC_PROMISE_ROOTED(new_bpart); // TODO: Analyzer doesn't understand MAYBE_UNROOTED properly
Expand Down Expand Up @@ -319,6 +320,7 @@ JL_DLLEXPORT jl_module_t *jl_new_module__(jl_sym_t *name, jl_module_t *parent)
m->build_id.hi = ~(uint64_t)0;
jl_atomic_store_relaxed(&m->counter, 1);
m->usings_backedges = jl_nothing;
m->scanned_methods = jl_nothing;
m->nospecialize = 0;
m->optlevel = -1;
m->compile = -1;
Expand Down Expand Up @@ -1163,6 +1165,25 @@ JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m)
return m->usings_backedges;
}

JL_DLLEXPORT size_t jl_module_scanned_methods_length(jl_module_t *m)
{
JL_LOCK(&m->lock);
size_t len = 0;
if (m->scanned_methods != jl_nothing)
len = jl_array_len(m->scanned_methods);
JL_UNLOCK(&m->lock);
return len;
}

JL_DLLEXPORT jl_value_t *jl_module_scanned_methods_getindex(jl_module_t *m, size_t i)
{
JL_LOCK(&m->lock);
assert(m->scanned_methods != jl_nothing);
jl_value_t *ret = jl_array_ptr_ref(m->scanned_methods, i-1);
JL_UNLOCK(&m->lock);
return ret;
}

JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s)
{
jl_binding_t *b = jl_get_module_binding(m, s, 0);
Expand Down Expand Up @@ -1369,21 +1390,22 @@ JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge)

// Called for all GlobalRefs found in lowered code. Adds backedges for cross-module
// GlobalRefs.
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_binding_t *b, jl_value_t *edge, jl_method_t *for_method)
{
if (!edge)
return;
jl_binding_t *b = gr->binding;
if (!b)
b = jl_get_module_binding(gr->mod, gr->name, 1);
return 0;
jl_module_t *defining_module = for_method->module;
// N.B.: The logic for evaluating whether a backedge is required must
// match the invalidation logic.
if (gr->mod == defining_module) {
if (b->globalref->mod == defining_module) {
// No backedge required - invalidation will forward scan
jl_atomic_fetch_or(&b->flags, BINDING_FLAG_ANY_IMPLICIT_EDGES);
return;
if (!(jl_atomic_fetch_or(&for_method->did_scan_source, 0x2) & 0x2))
jl_add_scanned_method(for_method->module, for_method);
return 1;
}
jl_add_binding_backedge(b, edge);
jl_add_binding_backedge(b, (jl_value_t*)edge);
return 0;
}

JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked(jl_binding_t *b,
Expand Down
4 changes: 4 additions & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ static void jl_queue_module_for_serialization(jl_serializer_state *s, jl_module_
}

jl_queue_for_serialization(s, m->usings_backedges);
jl_queue_for_serialization(s, m->scanned_methods);
}

// Anything that requires uniquing or fixing during deserialization needs to be "toplevel"
Expand Down Expand Up @@ -1324,6 +1325,9 @@ static void jl_write_module(jl_serializer_state *s, uintptr_t item, jl_module_t
newm->usings_backedges = NULL;
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, usings_backedges)));
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->usings_backedges, s->link_ids_relocs));
newm->scanned_methods = NULL;
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, scanned_methods)));
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->scanned_methods, s->link_ids_relocs));

// After reload, everything that has happened in this process happened semantically at
// (for .incremental) or before jl_require_world, so reset this flag.
Expand Down

0 comments on commit a8a1c3b

Please sign in to comment.