Skip to content

Commit

Permalink
Backport inference fixes (#216)
Browse files Browse the repository at this point in the history
* inference: avoid inferring unreachable code methods (JuliaLang#51317)

(cherry picked from commit 0a82b71)

* inference: ensure inferring reachable code methods (JuliaLang#57088)

PR JuliaLang#51317 was a bit over-eager about inferring inferring unreachable
code methods. Filter out the Vararg case, since that can be handled by
simply removing it instead of discarding the whole call.

Fixes JuliaLang#56628

(cherry picked from commit eb9f24c)

---------

Co-authored-by: Jameson Nash <vtjnash@gmail.com>
  • Loading branch information
2 people authored and nickrobinson251 committed Feb 26, 2025
1 parent 181d6fb commit bb2cd86
Show file tree
Hide file tree
Showing 10 changed files with 158 additions and 123 deletions.
59 changes: 37 additions & 22 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ function find_matching_methods(𝕃::AbstractLattice,
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
sig_n === Bottom && continue
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::MethodTable
Expand Down Expand Up @@ -506,7 +507,10 @@ function abstract_call_method(interp::AbstractInterpreter,
return MethodCallResult(Any, false, false, nothing, Effects())
end
sigtuple = unwrap_unionall(sig)
sigtuple isa DataType || return MethodCallResult(Any, false, false, nothing, Effects())
sigtuple isa DataType ||
return MethodCallResult(Any, false, false, nothing, Effects())
all(@nospecialize(x) -> isvarargtype(x) || valid_as_lattice(x, true), sigtuple.parameters) ||
return MethodCallResult(Union{}, false, false, nothing, EFFECTS_THROWS) # catch bad type intersections early

if is_nospecializeinfer(method)
sig = get_nospecializeinfer_sig(method, sig, sparams)
Expand Down Expand Up @@ -1385,25 +1389,35 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
end
if isa(tti, Union)
utis = uniontypes(tti)
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
end
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return AbstractIterationResult(Any[Vararg{Any}], nothing)
# refine the Union to remove elements that are not valid tags for objects
filter!(@nospecialize(x) -> valid_as_lattice(x, true), utis)
if length(utis) == 0
return AbstractIterationResult(Any[], nothing) # oops, this statement was actually unreachable
elseif length(utis) == 1
tti = utis[1]
tti0 = rewrap_unionall(tti, tti0)
else
if any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return AbstractIterationResult(Any[Vararg{Any}], nothing, Effects())
end
end
result = Any[ Union{} for _ in 1:ltp ]
for t in utis
tps = (t::DataType).parameters
_all(valid_as_lattice, tps) || continue
for j in 1:ltp
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return AbstractIterationResult(Any[Vararg{Any}], nothing)
end
end
result = Any[ Union{} for _ in 1:ltp ]
for t in utis
tps = (t::DataType).parameters
for j in 1:ltp
@assert valid_as_lattice(tps[j], true)
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
end
end
return AbstractIterationResult(result, nothing)
end
return AbstractIterationResult(result, nothing)
elseif tti0 <: Tuple
end
if tti0 <: Tuple
if isa(tti0, DataType)
return AbstractIterationResult(Any[ p for p in tti0.parameters ], nothing)
elseif !isa(tti, DataType)
Expand Down Expand Up @@ -1667,7 +1681,7 @@ end
return isa_condition(xt, ty, max_union_splitting)
end
@inline function isa_condition(@nospecialize(xt), @nospecialize(ty), max_union_splitting::Int)
tty_ub, isexact_tty = instanceof_tfunc(ty)
tty_ub, isexact_tty = instanceof_tfunc(ty, true)
tty = widenconst(xt)
if isexact_tty && !isa(tty_ub, TypeVar)
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
Expand All @@ -1677,7 +1691,7 @@ end
# `typeintersect` may be unable narrow down `Type`-type
thentype = tty_ub
end
valid_as_lattice(thentype) || (thentype = Bottom)
valid_as_lattice(thentype, true) || (thentype = Bottom)
elsetype = typesubtract(tty, tty_lb, max_union_splitting)
return ConditionalTypes(thentype, elsetype)
end
Expand Down Expand Up @@ -1923,7 +1937,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3), false)
isexact || return CallMeta(Any, Effects(), NoCallInfo())
unwrapped = unwrap_unionall(types)
if types === Bottom || !(unwrapped isa DataType) || unwrapped.name !== Tuple.name
Expand Down Expand Up @@ -2153,6 +2167,7 @@ function abstract_call_unknown(interp::AbstractInterpreter, @nospecialize(ft),
end
# non-constant function, but the number of arguments is known and the `f` is not a builtin or intrinsic
atype = argtypes_to_type(arginfo.argtypes)
atype === Bottom && return CallMeta(Union{}, Union{}, EFFECTS_THROWS, NoCallInfo()) # accidentally unreachable
return abstract_call_gf_by_type(interp, nothing, arginfo, si, atype, sv, max_methods)
end

Expand Down Expand Up @@ -2380,7 +2395,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
(; rt, effects) = abstract_eval_call(interp, e, vtypes, sv)
t = rt
elseif ehead === :new
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
ut = unwrap_unionall(t)
consistent = ALWAYS_FALSE
nothrow = false
Expand Down Expand Up @@ -2444,7 +2459,7 @@ function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtyp
end
effects = Effects(EFFECTS_TOTAL; consistent, nothrow)
elseif ehead === :splatnew
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))
t, isexact = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv), true)
nothrow = false # TODO: More precision
if length(e.args) == 2 && isconcretedispatch(t) && !ismutabletype(t)
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
Expand Down
8 changes: 5 additions & 3 deletions base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,18 @@ is_valid_lattice_norec(::InferenceLattice, @nospecialize(elem)) = isa(elem, Limi
"""
tmeet(𝕃::AbstractLattice, a, b::Type)
Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`.
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection.
Compute the lattice meet of lattice elements `a` and `b` over the lattice `𝕃`,
dropping any results that will not be inhabited at runtime.
If `𝕃` is `JLTypeLattice`, this is equivalent to type intersection plus the
elimination of results that have no concrete subtypes.
Note that currently `b` is restricted to being a type
(interpreted as a lattice element in the `JLTypeLattice` sub-lattice of `𝕃`).
"""
function tmeet end

function tmeet(::JLTypeLattice, @nospecialize(a::Type), @nospecialize(b::Type))
ti = typeintersect(a, b)
valid_as_lattice(ti) || return Bottom
valid_as_lattice(ti, true) || return Bottom
return ti
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
elseif head === :new_opaque_closure
length(args) < 4 && return (false, false, false)
typ = argextype(args[1], src)
typ, isexact = instanceof_tfunc(typ)
typ, isexact = instanceof_tfunc(typ, true)
isexact || return (false, false, false)
(𝕃ₒ, typ, Tuple) || return (false, false, false)
rt_lb = argextype(args[2], src)
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
end

function invoke_signature(argtypes::Vector{Any})
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]))[1]
ft, argtyps = widenconst(argtypes[2]), instanceof_tfunc(widenconst(argtypes[3]), false)[1]
return rewrap_unionall(Tuple{ft, unwrap_unionall(argtyps).parameters...}, argtyps)
end

Expand Down Expand Up @@ -1450,8 +1450,9 @@ function handle_call!(todo::Vector{Pair{Int,Any}},
cases = compute_inlining_cases(info, flag, sig, state)
cases === nothing && return nothing
cases, all_covered, joint_effects = cases
handle_cases!(todo, ir, idx, stmt, argtypes_to_type(sig.argtypes), cases,
all_covered, joint_effects)
atype = argtypes_to_type(sig.argtypes)
atype === Union{} && return nothing # accidentally actually unreachable
handle_cases!(todo, ir, idx, stmt, atype, cases, all_covered, joint_effects)
end

function handle_match!(cases::Vector{InliningCase},
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,7 @@ function adce_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
else
if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3
# nullify safe `typeassert` calls
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact))
ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact), true)
if isexact && (𝕃ₒ, argextype(stmt.args[2], compact), ty)
compact[idx] = nothing
continue
Expand Down
Loading

0 comments on commit bb2cd86

Please sign in to comment.