Skip to content

Commit

Permalink
Merge pull request #39606 from JuliaLang/jn/inference-widenreturn
Browse files Browse the repository at this point in the history
Fixes some inference issues
  • Loading branch information
vtjnash authored Feb 16, 2021
2 parents 6425845 + 55eca20 commit bbdf1cf
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 17 deletions.
47 changes: 34 additions & 13 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1202,8 +1202,8 @@ function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::
argtypes = Vector{Any}(undef, n)
@inbounds for i = 1:n
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
if bail_out_statement(interp, ai, sv)
return Bottom
if ai === Bottom
return nothing
end
argtypes[i] = ai
end
Expand All @@ -1218,7 +1218,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if e.head === :call
ea = e.args
argtypes = collect_argtypes(interp, ea, vtypes, sv)
if argtypes === Bottom
if argtypes === nothing
t = Bottom
else
callinfo = abstract_call(interp, ea, argtypes, sv)
Expand Down Expand Up @@ -1280,7 +1280,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if length(e.args) >= 5
ea = e.args
argtypes = collect_argtypes(interp, ea, vtypes, sv)
if argtypes === Bottom
if argtypes === nothing
t = Bottom
else
t = _opaque_closure_tfunc(argtypes[1], argtypes[2], argtypes[3],
Expand Down Expand Up @@ -1376,6 +1376,31 @@ function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
return typ
end

function widenreturn(@nospecialize rt)
# only propagate information we know we can store
# and is valid and good inter-procedurally
rt = widenconditional(rt)
isa(rt, Const) && return rt
isa(rt, Type) && return rt
if isa(rt, PartialStruct)
fields = copy(rt.fields)
haveconst = false
for i in 1:length(fields)
a = widenreturn(fields[i])
if !haveconst && has_const_info(a)
# TODO: consider adding && const_prop_profitable(a) here?
haveconst = true
end
fields[i] = a
end
haveconst && return PartialStruct(rt.typ, fields)
end
if isa(rt, PartialOpaque)
return rt # XXX: this case was missed in #39512
end
return widenconst(rt)
end

# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@assert !frame.inferred
Expand All @@ -1399,6 +1424,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.cur_hand = frame.handler_at[pc]
edges = frame.stmt_edges[pc]
edges === nothing || empty!(edges)
frame.stmt_info[pc] = nothing
stmt = frame.src.code[pc]
changes = s[pc]::VarTable
t = nothing
Expand All @@ -1415,7 +1441,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if condt === Bottom
empty!(frame.pclimitations)
end
if bail_out_local(interp, condt, frame)
if condt === Bottom
break
end
condval = maybe_extract_const_bool(condt)
Expand Down Expand Up @@ -1455,12 +1481,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
elseif isa(stmt, ReturnNode)
pc´ = n + 1
rt = widenconditional(abstract_eval_value(interp, stmt.val, changes, frame))
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialOpaque)
# only propagate information we know we can store
# and is valid inter-procedurally
rt = widenconst(rt)
end
rt = widenreturn(abstract_eval_value(interp, stmt.val, changes, frame))
# copy limitations to return value
if !isempty(frame.pclimitations)
union!(frame.limitations, frame.pclimitations)
Expand Down Expand Up @@ -1506,7 +1527,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
else
if hd === :(=)
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
if bail_out_local(interp, t, frame)
if t === Bottom
break
end
frame.src.ssavaluetypes[pc] = t
Expand All @@ -1523,7 +1544,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
if bail_out_local(interp, t, frame)
if t === Bottom
break
end
if !isempty(frame.ssavalue_uses[pc])
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ function tuple_tfunc(atypes::Vector{Any})
x = atypes[i]
# TODO ignore singleton Const (don't forget to update cache logic if you implement this)
if !anyinfo
anyinfo = (!isa(x, Type) && !isvarargtype(x)) || isType(x)
anyinfo = has_const_info(x)
end
if isa(x, Const)
params[i] = typeof(x.val)
Expand Down
2 changes: 0 additions & 2 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ method_table(ai::AbstractInterpreter) = InternalMethodTable(get_world_counter(ai
# - inferring non-concrete toplevel call sites
bail_out_call(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
bail_out_apply(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
bail_out_statement(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
bail_out_local(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
function bail_out_toplevel_call(interp::AbstractInterpreter, @nospecialize(sig), sv)
return isa(sv.linfo.def, Module) && !isdispatchtuple(sig)
end
2 changes: 2 additions & 0 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ function has_nontrivial_const_info(@nospecialize t)
return !isdefined(typeof(val), :instance) && !(isa(val, Type) && hasuniquerep(val))
end

has_const_info(@nospecialize x) = (!isa(x, Type) && !isvarargtype(x)) || isType(x)

# Subtyping currently intentionally answers certain queries incorrectly for kind types. For
# some of these queries, this check can be used to somewhat protect against making incorrect
# decisions based on incorrect subtyping. Note that this check, itself, is broken for
Expand Down
2 changes: 1 addition & 1 deletion src/jlapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ static int exec_program(char *program)
int shown_err = 0;
jl_printf(JL_STDERR, "error during bootstrap:\n");
jl_value_t *exc = jl_current_exception();
jl_value_t *showf = jl_get_function(jl_base_module, "show");
jl_value_t *showf = jl_base_module ? jl_get_function(jl_base_module, "show") : NULL;
if (showf) {
jl_value_t *errs = jl_stderr_obj();
if (errs) {
Expand Down

0 comments on commit bbdf1cf

Please sign in to comment.