Skip to content

Commit

Permalink
bpart: Turn on invalidation for guard->defined transitions
Browse files Browse the repository at this point in the history
This addresses one of the last remaining TODOs of the binding partition
work by performing invalidations when bindings transition from being
undefined to being defined. This in particular finally addresses the
performance issue that #54733 was intended to address (the issue was
closed when we merged the mechanism, but it had so far been turned
off). Turning on the invalidations themselves were always easy (a
one line deletion). What is harder is making sure that the additional
invalidations don't take extra time.

To this end, we add two additional flags, one on Bindings, and one on
methods. The flag on bindings tells us whether any method scan has so
far found an implicit (not tracked in ->backedges) reference to this
binding in any method body. The insight here is that most undefined
bindings will not have been referenced previously (because they did
not exist), so with a simple one bit saturating counter of the number
of edges that would exist (if we did store them), we can fast-path
the invalidation.

However, this is not quite sufficient, as people often do things like:
```
foo() = bar()
bar() = ...
...
```
which, without further improvements would incur an invalidation upon
the definition of `bar`.

The second insight (and what the flag on `Method` is for) is that we
don't actually need to scan the method body until there is something
to invalidate (i.e. until some `CodeInstance` has been created for
the method). By defering the scanning until the first time that inference
accesses the lowered code (with a flag to only do it once), we can
easily avoid invalidation in the above scenario (while still invalidating
if `foo()` was called before the definition of `bar`).

As a further bonus, this also speeds up bootstrap by about 20% (putting
us about back to where we used to be before the full bpart change) by
skipping unnecessary invalidations even for non-guard transitions.

Finally, this does not yet turn on inference's ability to infer guard
partitions as `Union{}`. The reason for this is that such partitions
can be replaced by backdated constants without invalidation. However,
as soon as we remove the backdated const mechanism, this PR will allow
us to turn on that change, further speeding up inference (by cutting off
inference on branches known to error due to missing bindings).
  • Loading branch information
Keno committed Mar 3, 2025
1 parent ae2914a commit 4a34be6
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 28 deletions.
19 changes: 19 additions & 0 deletions Compiler/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,25 @@ function retrieve_code_info(mi::MethodInstance, world::UInt)
else
c = copy(src::CodeInfo)
end
if !def.did_scan_source
# This scan must happen:
# 1. After method definition
# 2. Before any code instances that may have relied on information
# from implicit GlobalRefs for this method are added to the cache
# 3. Preferably while the IR is already uncompressed
# 4. As late as possible, as early adding of the backedges may cause
# spurious invalidations.
#
# At the moment we do so here, because
# 1. It's reasonably late
# 2. It has easy access to the uncompressed IR
# 3. We necessarily pass through here before relying on any
# information obtained from implicit GlobalRefs.
#
# However, the exact placement of this scan is not as important as
# long as the above conditions are met.
ccall(:jl_scan_method_source_now, Cvoid, (Any, Any), def, c)
end
end
if c isa CodeInfo
c.parent = mi
Expand Down
4 changes: 4 additions & 0 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,10 @@ function make_atomic(order, ex)
op = :+
elseif ex.head === :(-=)
op = :-
elseif ex.head === :(|=)
op = :|
elseif ex.head === :(&=)
op = :&
elseif @isdefined string
shead = string(ex.head)
if endswith(shead, '=')
Expand Down
37 changes: 20 additions & 17 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,28 +115,27 @@ end

function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core.BindingPartition, new_bpart::Union{Core.BindingPartition, Nothing}, new_max_world::UInt)
gr = b.globalref
if !is_some_guard(binding_kind(invalidated_bpart))
# TODO: We may want to invalidate for these anyway, since they have performance implications
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
end
return true
end
if isdefined(b, :backedges)
for edge in b.backedges
if isa(edge, CodeInstance)
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), edge, new_max_world)
elseif isa(edge, Core.Binding)
isdefined(edge, :partitions) || continue
latest_bpart = edge.partitions
latest_bpart.max_world == typemax(UInt) || continue
is_some_imported(binding_kind(latest_bpart)) || continue
partition_restriction(latest_bpart) === b || continue
invalidate_code_for_globalref!(edge, latest_bpart, nothing, new_max_world)
else
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
end
end
if isdefined(b, :backedges)
for edge in b.backedges
if isa(edge, CodeInstance)
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), edge, new_max_world)
elseif isa(edge, Core.Binding)
isdefined(edge, :partitions) || continue
latest_bpart = edge.partitions
latest_bpart.max_world == typemax(UInt) || continue
is_some_imported(binding_kind(latest_bpart)) || continue
partition_restriction(latest_bpart) === b || continue
invalidate_code_for_globalref!(edge, latest_bpart, nothing, new_max_world)
else
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
end
end
end
Expand Down Expand Up @@ -166,7 +165,11 @@ gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod
# N.B.: This needs to match jl_maybe_add_binding_backedge
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
method = isa(edge, Method) ? edge : edge.def.def::Method
gr_needs_backedge_in_module(b.globalref, method.module) || return
methmod = method.module
if !gr_needs_backedge_in_module(b.globalref, methmod)
@atomic :acquire_release b.flags |= BINDING_FLAG_ANY_IMPLICIT_EDGES
return
end
if !isdefined(b, :backedges)
b.backedges = Any[]
end
Expand Down
2 changes: 2 additions & 0 deletions base/runtime_internals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ const PARTITION_FLAG_DEPWARN = 0x40
const PARTITION_MASK_KIND = 0x0f
const PARTITION_MASK_FLAG = 0xf0

const BINDING_FLAG_ANY_IMPLICIT_EDGES = 0x8

is_defined_const_binding(kind::UInt8) = (kind == PARTITION_KIND_CONST || kind == PARTITION_KIND_CONST_IMPORT || kind == PARTITION_KIND_BACKDATED_CONST)
is_some_const_binding(kind::UInt8) = (is_defined_const_binding(kind) || kind == PARTITION_KIND_UNDEF_CONST)
is_some_imported(kind::UInt8) = (kind == PARTITION_KIND_IMPLICIT || kind == PARTITION_KIND_EXPLICIT || kind == PARTITION_KIND_IMPORTED)
Expand Down
8 changes: 5 additions & 3 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3275,7 +3275,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_svec(5, jl_any_type/*jl_globalref_type*/, jl_any_type, jl_binding_partition_type,
jl_any_type, jl_uint8_type),
jl_emptysvec, 0, 1, 0);
const static uint32_t binding_atomicfields[] = { 0x0005 }; // Set fields 2, 3 as atomic
const static uint32_t binding_atomicfields[] = { 0x0016 }; // Set fields 2, 3, 5 as atomic
jl_binding_type->name->atomicfields = binding_atomicfields;
const static uint32_t binding_constfields[] = { 0x0001 }; // Set fields 1 as constant
jl_binding_type->name->constfields = binding_constfields;
Expand Down Expand Up @@ -3539,7 +3539,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(31,
jl_perm_symsvec(32,
"name",
"module",
"file",
Expand Down Expand Up @@ -3568,10 +3568,11 @@ void jl_init_types(void) JL_GC_DISABLED
"isva",
"is_for_opaque_closure",
"nospecializeinfer",
"did_scan_source",
"constprop",
"max_varargs",
"purity"),
jl_svec(31,
jl_svec(32,
jl_symbol_type,
jl_module_type,
jl_symbol_type,
Expand Down Expand Up @@ -3600,6 +3601,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type,
jl_uint8_type,
jl_uint16_type),
Expand Down
7 changes: 6 additions & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ typedef struct _jl_method_t {
uint8_t isva;
uint8_t is_for_opaque_closure;
uint8_t nospecializeinfer;
_Atomic(uint8_t) did_scan_source;

// uint8 settings
uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none
uint8_t max_varargs; // 0xFF = use heuristic; otherwise, max # of args to expand
Expand Down Expand Up @@ -751,7 +753,10 @@ enum jl_binding_flags {
BINDING_FLAG_DID_PRINT_BACKDATE_ADMONITION = 0x1,
BINDING_FLAG_DID_PRINT_IMPLICIT_IMPORT_ADMONITION = 0x2,
// `export` is tracked in partitions, but sets this as well
BINDING_FLAG_PUBLICP = 0x4
BINDING_FLAG_PUBLICP = 0x4,
// Set if any methods defined in this module implicitly reference
// this binding. If not, invalidation is optimized.
BINDING_FLAG_ANY_IMPLICIT_EDGES = 0x8
};

typedef struct _jl_binding_t {
Expand Down
28 changes: 24 additions & 4 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
}
}

JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
{
if (!jl_atomic_load_relaxed(&m->did_scan_source)) {
jl_code_info_t *code = NULL;
JL_GC_PUSH1(&code);
if (!jl_is_code_info(src))
code = jl_uncompress_ir(m, NULL, src);
else
code = (jl_code_info_t*)src;
jl_array_t *stmts = code->code;
size_t i, l = jl_array_nrows(stmts);
for (i = 0; i < l; i++) {
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
if (jl_is_globalref(stmt)) {
jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
}
}
jl_atomic_store_relaxed(&m->did_scan_source, 1);
JL_GC_POP();
}
}

// Resolve references to non-locally-defined variables to become references to global
// variables in `module` (unless the rvalue is one of the type parameters in `sparam_vals`).
static jl_value_t *resolve_definition_effects(jl_value_t *expr, jl_module_t *module, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
Expand All @@ -47,10 +69,7 @@ static jl_value_t *resolve_definition_effects(jl_value_t *expr, jl_module_t *mod
if (jl_is_symbol(expr)) {
jl_error("Found raw symbol in code returned from lowering. Expected all symbols to have been resolved to GlobalRef or slots.");
}
if (jl_is_globalref(expr)) {
jl_maybe_add_binding_backedge((jl_globalref_t*)expr, module, binding_edge);
return expr;
}

if (!jl_is_expr(expr)) {
return expr;
}
Expand Down Expand Up @@ -973,6 +992,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
jl_atomic_store_relaxed(&m->deleted_world, 1);
m->is_for_opaque_closure = 0;
m->nospecializeinfer = 0;
jl_atomic_store_relaxed(&m->did_scan_source, 0);
m->constprop = 0;
m->purity.bits = 0;
m->max_varargs = UINT8_MAX;
Expand Down
7 changes: 4 additions & 3 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1345,15 +1345,16 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t
{
if (!edge)
return;
jl_binding_t *b = gr->binding;
if (!b)
b = jl_get_module_binding(gr->mod, gr->name, 1);
// N.B.: The logic for evaluating whether a backedge is required must
// match the invalidation logic.
if (gr->mod == defining_module) {
// No backedge required - invalidation will forward scan
jl_atomic_fetch_or(&b->flags, BINDING_FLAG_ANY_IMPLICIT_EDGES);
return;
}
jl_binding_t *b = gr->binding;
if (!b)
b = jl_get_module_binding(gr->mod, gr->name, 1);
jl_add_binding_backedge(b, edge);
}

Expand Down
20 changes: 20 additions & 0 deletions test/rebinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,23 @@ module RangeMerge

@test !contains(get_llvm(f, Tuple{}), "jl_get_binding_value")
end

# Test that we invalidate for undefined -> defined transitions (#54733)
module UndefinedTransitions
using Test
function foo54733()
for i = 1:1_000_000_000
bar54733(i)
end
return 1
end
@test_throws UndefVarError foo54733()
let ci = first(methods(foo54733)).specializations.cache
@test !Base.Compiler.is_nothrow(Base.Compiler.decode_effects(ci.ipo_purity_bits))
end
bar54733(x) = 3x
@test foo54733() === 1
let ci = first(methods(foo54733)).specializations.cache
@test Base.Compiler.is_nothrow(Base.Compiler.decode_effects(ci.ipo_purity_bits))
end
end

0 comments on commit 4a34be6

Please sign in to comment.