Skip to content

Commit

Permalink
Align module base between invalidation and edge tracking
Browse files Browse the repository at this point in the history
Our implicit edge tracking for bindings does not explicitly store
any edges for bindings in the *current* module. The idea behind
this is that this is a good time-space tradeoff for validation,
because substantially all binding references in a module will be
to its defining module, while the total number of methods within
a module is limited and substantially smaller than the total number
of methods in the entire system.

However, we have an issue where the code that stores these edges
and the invalidation code disagree on which module is the *current*
one. The edge storing code was using the module in which the method
was defined, while the invalidation code was using the one in which
the MethodTable is defined. With these being misaligned, we can
miss necessary invalidations.

Both options are in principle possible, but I think the former is
better, because the module in which the method is defined is also
the module that we are likely to have a lot of references to (since
they get referenced implicitly by just writing symbols in the code).

However, this presents a problem: We don't actually have a way to
iterate all the methods defined in a particular module, without just
doing the brute force thing of scanning all methods and filtering.

To address this, build on the deferred scanning code added in #57615
to also add any scanned modules to an explicit list in `Module`. This
costs some space, but only proportional to the number of defined methods,
(and thus proportional to the written source code).

Note that we don't actually observe any issues in the test suite on
master due to this bug. However, this is because we are grossly
over-invalidating, which hides the missing invalidations from this
issue (#57617).
  • Loading branch information
Keno committed Mar 3, 2025
1 parent 44975e1 commit a12e1cc
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 13 deletions.
11 changes: 5 additions & 6 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,10 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core

if need_to_invalidate_code
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
return true
nmethods = ccall(:jl_module_scanned_methods_length, Csize_t, (Any,), gr.mod)
for i = 1:nmethods
method = ccall(:jl_module_scanned_methods_getindex, Any, (Any, Csize_t), gr.mod, i)::Method
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
end
if isdefined(b, :backedges)
Expand All @@ -166,7 +165,7 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
# have a binding that is affected by this change.
usings_backedges = ccall(:jl_get_module_usings_backedges, Any, (Any,), gr.mod)
if usings_backedges !== nothing
for user in usings_backedges::Vector{Any}
for user::Module in usings_backedges::Vector{Any}
user_binding = ccall(:jl_get_module_binding_or_nothing, Any, (Any, Any), user, gr.name)
user_binding === nothing && continue
isdefined(user_binding, :partitions) || continue
Expand Down
3 changes: 3 additions & 0 deletions src/gc-stock.c
Original file line number Diff line number Diff line change
Expand Up @@ -2147,6 +2147,9 @@ STATIC_INLINE void gc_mark_module_binding(jl_ptls_t ptls, jl_module_t *mb_parent
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->usings_backedges);
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->usings_backedges, &nptr);
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->usings_backedges);
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->scanned_methods);
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->scanned_methods, &nptr);
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->scanned_methods);
size_t nusings = module_usings_length(mb_parent);
if (nusings > 0) {
// this is only necessary because bindings for "using" modules
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ typedef struct _jl_module_t {
jl_sym_t *file;
int32_t line;
jl_value_t *usings_backedges;
jl_value_t *scanned_methods;
// hidden fields:
arraylist_t usings; /* arraylist of struct jl_module_using */ // modules with all bindings potentially imported
jl_uuid_t build_id;
Expand Down Expand Up @@ -2059,6 +2060,7 @@ JL_DLLEXPORT int jl_get_module_infer(jl_module_t *m);
JL_DLLEXPORT void jl_set_module_max_methods(jl_module_t *self, int value);
JL_DLLEXPORT int jl_get_module_max_methods(jl_module_t *m);
JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m);
JL_DLLEXPORT jl_value_t *jl_get_module_scanned_methods(jl_module_t *m);
JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s);

// get binding for reading
Expand Down
2 changes: 1 addition & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ast);
JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void);
JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
int binding_effects);
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);

int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;
Expand Down
19 changes: 16 additions & 3 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,20 @@ static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
}
}

static void jl_add_scanned_method(jl_module_t *m, jl_method_t *meth)
{
JL_LOCK(&m->lock);
if (m->scanned_methods == jl_nothing) {
m->scanned_methods = (jl_value_t*)jl_alloc_vec_any(0);
jl_gc_wb(m, m->scanned_methods);
}
jl_array_ptr_1d_push((jl_array_t*)m->scanned_methods, (jl_value_t*)meth);
JL_UNLOCK(&m->lock);
}

JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
{
if (!jl_atomic_load_relaxed(&m->did_scan_source)) {
if (!jl_atomic_fetch_or(&m->did_scan_source, 1)) {
jl_code_info_t *code = NULL;
JL_GC_PUSH1(&code);
if (!jl_is_code_info(src))
Expand All @@ -50,13 +61,15 @@ JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
code = (jl_code_info_t*)src;
jl_array_t *stmts = code->code;
size_t i, l = jl_array_nrows(stmts);
int any_implicit = 0;
for (i = 0; i < l; i++) {
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
if (jl_is_globalref(stmt)) {
jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
any_implicit |= jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
}
}
jl_atomic_store_relaxed(&m->did_scan_source, 1);
if (any_implicit)
jl_add_scanned_method(m->module, m);
JL_GC_POP();
}
}
Expand Down
24 changes: 21 additions & 3 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ JL_DLLEXPORT jl_module_t *jl_new_module__(jl_sym_t *name, jl_module_t *parent)
m->build_id.hi = ~(uint64_t)0;
jl_atomic_store_relaxed(&m->counter, 1);
m->usings_backedges = jl_nothing;
m->scanned_methods = jl_nothing;
m->nospecialize = 0;
m->optlevel = -1;
m->compile = -1;
Expand Down Expand Up @@ -1163,6 +1164,22 @@ JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m)
return m->usings_backedges;
}

JL_DLLEXPORT size_t jl_module_scanned_methods_length(jl_module_t *m)
{
JL_LOCK(&m->lock);
size_t len = jl_array_len(m->scanned_methods);
JL_UNLOCK(&m->lock);
return len;
}

JL_DLLEXPORT jl_value_t *jl_module_scanned_methods_getindex(jl_module_t *m, size_t i)
{
JL_LOCK(&m->lock);
jl_value_t *ret = jl_array_ptr_ref(m->scanned_methods, i-1);
JL_UNLOCK(&m->lock);
return ret;
}

JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s)
{
jl_binding_t *b = jl_get_module_binding(m, s, 0);
Expand Down Expand Up @@ -1369,10 +1386,10 @@ JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge)

// Called for all GlobalRefs found in lowered code. Adds backedges for cross-module
// GlobalRefs.
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
{
if (!edge)
return;
return 0;
jl_binding_t *b = gr->binding;
if (!b)
b = jl_get_module_binding(gr->mod, gr->name, 1);
Expand All @@ -1381,9 +1398,10 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t
if (gr->mod == defining_module) {
// No backedge required - invalidation will forward scan
jl_atomic_fetch_or(&b->flags, BINDING_FLAG_ANY_IMPLICIT_EDGES);
return;
return 1;
}
jl_add_binding_backedge(b, edge);
return 0;
}

JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked(jl_binding_t *b,
Expand Down
4 changes: 4 additions & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ static void jl_queue_module_for_serialization(jl_serializer_state *s, jl_module_
}

jl_queue_for_serialization(s, m->usings_backedges);
jl_queue_for_serialization(s, m->scanned_methods);
}

// Anything that requires uniquing or fixing during deserialization needs to be "toplevel"
Expand Down Expand Up @@ -1324,6 +1325,9 @@ static void jl_write_module(jl_serializer_state *s, uintptr_t item, jl_module_t
newm->usings_backedges = NULL;
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, usings_backedges)));
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->usings_backedges, s->link_ids_relocs));
newm->scanned_methods = NULL;
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, scanned_methods)));
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->scanned_methods, s->link_ids_relocs));

// After reload, everything that has happened in this process happened semantically at
// (for .incremental) or before jl_require_world, so reset this flag.
Expand Down

0 comments on commit a12e1cc

Please sign in to comment.