Skip to content

Commit

Permalink
Add PartialOpaque lattice element for OpaqueClosure
Browse files Browse the repository at this point in the history
This adds a lattice element for tracking OpaqueClosures in inference,
but does not yet do anything with it. The reason I'm separating
this out is that just the introduction of the lattice element
raises some tricky issues. In particular, the lattice element
refers back to the OpaqueClosure method, which we currently
don't support in the serializer. I played with several ways
of adding support for that, but in the end it all ended up
super complicated for questionable benefit, so in this PR,
CodeInstances that get inferred to `PartialOpaque` get
omitted during serialization (i.e. they will be reinfered
upon loading the .ji).
  • Loading branch information
Keno committed Feb 4, 2021
1 parent 7b19e09 commit d2ef8d0
Show file tree
Hide file tree
Showing 13 changed files with 190 additions and 38 deletions.
58 changes: 55 additions & 3 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ function const_prop_profitable(@nospecialize(arg))
const_prop_profitable(b) && return true
end
end
isa(arg, PartialOpaque) && return true
isa(arg, Const) || return true
val = arg.val
# don't consider mutable values or Strings useful constants
Expand Down Expand Up @@ -255,7 +256,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
# see if any or all of the arguments are constant and propagating constants may be worthwhile
for a in argtypes
a = widenconditional(a)
if allconst && !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct)
if allconst && !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque)
allconst = false
end
if !haveconst && has_nontrivial_const_info(a) && const_prop_profitable(a)
Expand Down Expand Up @@ -1044,6 +1045,31 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_call_gf_by_type(interp, f, argtypes, atype, sv, max_methods)
end

function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
return CallMeta(Any, nothing)
end

function most_general_argtypes(closure::PartialOpaque)
ret = Any[]
cc = widenconst(closure)
argt = unwrap_unionall(cc).parameters[1]
@assert isa(argt, DataType) && argt.name === typename(Tuple)
params = argt.parameters
for i = 2:closure.source.nargs
rt = unwrapva(params[max(i-1, length(params))])
if closure.isva
if length(params) > i-1
for j = (i):length(params)
rt = tmerge(rt, unwrapva(params[j]))
end
end
rt = Vararg{rt}
end
push!(ret, rt)
end
ret
end

# call where the function is any lattice element
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
Expand All @@ -1055,10 +1081,14 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
f = ft.parameters[1]
elseif isa(ft, DataType) && isdefined(ft, :instance)
f = ft.instance
elseif isa(ft, PartialOpaque)
return abstract_call_opaque_closure(interp, ft, argtypes, sv)
elseif isa(ft, DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
else
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
if typeintersect(widenconst(ft), Builtin) != Union{}
if typeintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure}) != Union{}
add_remark!(interp, sv, "Could not identify method table for call")
return CallMeta(Any, false)
end
Expand Down Expand Up @@ -1229,6 +1259,28 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
t = PartialStruct(t, at.fields)
end
end
elseif e.head === :new_opaque_closure
t = Union{}
if length(e.args) >= 5
ea = e.args
n = length(ea)
argtypes = Vector{Any}(undef, n)
@inbounds for i = 1:n
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
if ai === Bottom
return Bottom
end
argtypes[i] = ai
end
t = _opaque_closure_tfunc(argtypes[1], argtypes[2], argtypes[3],
argtypes[4], argtypes[5], argtypes[6:end], sv.linfo)
if isa(t, PartialOpaque)
# Infer this now so that the specialization is available to
# optimization.
abstract_call_opaque_closure(interp, t,
most_general_argtypes(t), sv)
end
end
elseif e.head === :foreigncall
abstract_eval_value(interp, e.args[1], vtypes, sv)
t = sp_type_rewrap(e.args[2], sv.linfo, true)
Expand Down Expand Up @@ -1389,7 +1441,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif isa(stmt, ReturnNode)
pc´ = n + 1
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialOpaque)
# only propagate information we know we can store
# and is valid inter-procedurally
rt = widenconst(rt)
Expand Down
21 changes: 21 additions & 0 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,27 @@ add_tfunc(arrayref, 3, INT_INF, arrayref_tfunc, 20)
add_tfunc(const_arrayref, 3, INT_INF, arrayref_tfunc, 20)
add_tfunc(arrayset, 4, INT_INF, (@nospecialize(boundscheck), @nospecialize(a), @nospecialize(v), @nospecialize i...)->a, 20)

function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
@nospecialize(lb), @nospecialize(ub), @nospecialize(source), env::Vector{Any},
linfo::MethodInstance)

argt, argt_exact = instanceof_tfunc(arg)
lbt, lb_exact = instanceof_tfunc(lb)
if !lb_exact
lbt = Union{}
end

ubt, ub_exact = instanceof_tfunc(ub)

t = argt_exact ? Core.OpaqueClosure{argt} : Core.OpaqueClosure{<:argt}
t = lbt == ubt ? t{ubt} : (t{T} where lbt <: T <: ubt)

isa(source, Const) || return t
(isa(isva, Const) && isa(isva.val, Bool)) || return t

return PartialOpaque(t, env, linfo, isva.val, source.val)
end

function array_type_undefable(@nospecialize(a))
if isa(a, Union)
return array_type_undefable(a.a) || array_type_undefable(a.b)
Expand Down
5 changes: 5 additions & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::An
if isa(result_type, Const)
rettype_const = result_type.val
const_flags = 0x2
elseif isa(result_type, PartialOpaque)
rettype_const = result_type
const_flags = 0x2
elseif isconstType(result_type)
rettype_const = result_type.parameters[1]
const_flags = 0x2
Expand Down Expand Up @@ -773,6 +776,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
if isdefined(code, :rettype_const)
if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype)
return PartialStruct(code.rettype, code.rettype_const), mi
elseif code.rettype <: Core.OpaqueClosure && isa(code.rettype_const, PartialOpaque)
return code.rettype_const, mi
else
return Const(code.rettype_const), mi
end
Expand Down
17 changes: 17 additions & 0 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ struct PartialTypeVar
PartialTypeVar(tv::TypeVar, lb_certain::Bool, ub_certain::Bool) = new(tv, lb_certain, ub_certain)
end

mutable struct PartialOpaque
t::Type
env_ts::Vector{Any}
parent::MethodInstance
isva::Bool
source::Method
end

# Wraps a type and represents that the value may also be undef at this point.
# (only used in optimize, not abstractinterpret)
# N.B. in the lattice, this is epsilon bigger than `typ` (even Any)
Expand Down Expand Up @@ -185,6 +193,14 @@ function ⊑(@nospecialize(a), @nospecialize(b))
end
return false
end
if isa(a, PartialOpaque)
if isa(b, PartialOpaque)
(a.parent === b.parent && a.source === b.source) || return false
return (widenconst(a) <: widenconst(b)) &&
(a.env, b.env)
end
return widenconst(a) <: widenconst(b)
end
if isa(a, Const)
if isa(b, Const)
return a.val === b.val
Expand Down Expand Up @@ -240,6 +256,7 @@ end
widenconst(m::MaybeUndef) = widenconst(m.typ)
widenconst(c::PartialTypeVar) = TypeVar
widenconst(t::PartialStruct) = t.typ
widenconst(t::PartialOpaque) = t.t
widenconst(t::Type) = t
widenconst(t::TypeVar) = t
widenconst(t::Core.TypeofVararg) = t
Expand Down
1 change: 1 addition & 0 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ end

function has_nontrivial_const_info(@nospecialize t)
isa(t, PartialStruct) && return true
isa(t, PartialOpaque) && return true
isa(t, Const) || return false
val = t.val
return !isdefined(typeof(val), :instance) && !(isa(val, Type) && hasuniquerep(val))
Expand Down
21 changes: 19 additions & 2 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,29 @@ function get_staged(li::MethodInstance)
end
end

function has_opaque_closure(c::CodeInfo)
for i = 1:length(c.code)
stmt = c.code[i]
(isa(stmt, Expr) && stmt.head === :new_opaque_closure) && return true
end
return false
end

function retrieve_code_info(linfo::MethodInstance)
m = linfo.def::Method
c = nothing
if isdefined(m, :generator)
# user code might throw errors – ignore them
c = get_staged(linfo)
if isdefined(linfo, :uninferred)
c = copy(linfo.uninferred::CodeInfo)
else
# user code might throw errors – ignore them
c = get_staged(linfo)
# For opaque closures, cache the generated code info to make sure
# that Opaque Closure method identity remains stable.
if c !== nothing && has_opaque_closure(c)
linfo.uninferred = copy(c)
end
end
end
if c === nothing && isdefined(m, :source)
src = m.source
Expand Down
1 change: 1 addition & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("Argument", (jl_value_t*)jl_argument_type);
add_builtin("Const", (jl_value_t*)jl_const_type);
add_builtin("PartialStruct", (jl_value_t*)jl_partial_struct_type);
add_builtin("PartialOpaque", (jl_value_t*)jl_partial_opaque_type);
add_builtin("MethodMatch", (jl_value_t*)jl_method_match_type);
add_builtin("IntrinsicFunction", (jl_value_t*)jl_intrinsic_type);
add_builtin("Function", (jl_value_t*)jl_function_type);
Expand Down
96 changes: 63 additions & 33 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,11 @@ static void jl_serialize_module(jl_serializer_state *s, jl_module_t *m)
write_uint8(s->s, m->infer);
}

static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_literal) JL_GC_DISABLED
static inline int jl_serialize_generic(jl_serializer_state *s, jl_value_t *v)
{
if (v == NULL) {
write_uint8(s->s, TAG_NULL);
return;
return 1;
}

void *tag = ptrhash_get(&ser_tag, v);
Expand All @@ -393,28 +393,29 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
if (t8 <= LAST_TAG)
write_uint8(s->s, 0);
write_uint8(s->s, t8);
return;
return 1;
}

if (jl_is_symbol(v)) {
void *idx = ptrhash_get(&common_symbol_tag, v);
if (idx != HT_NOTFOUND) {
write_uint8(s->s, TAG_COMMONSYM);
write_uint8(s->s, (uint8_t)(size_t)idx);
return;
return 1;
}
}
else if (v == (jl_value_t*)jl_core_module) {
write_uint8(s->s, TAG_CORE);
return;
return 1;
}
else if (v == (jl_value_t*)jl_base_module) {
write_uint8(s->s, TAG_BASE);
return;
return 1;
}

if (jl_typeis(v, jl_string_type) && jl_string_len(v) == 0) {
jl_serialize_value(s, jl_an_empty_string);
return;
return 1;
}
else if (!jl_is_uint8(v)) {
void **bp = ptrhash_bp(&backref_table, v);
Expand All @@ -428,7 +429,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_uint8(s->s, TAG_BACKREF);
write_int32(s->s, pos);
}
return;
return 1;
}
intptr_t pos = backref_table_numel++;
if (((jl_datatype_t*)(jl_typeof(v)))->name == jl_idtable_typename) {
Expand All @@ -453,6 +454,57 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
ptrhash_put(&backref_table, v, (char*)HT_NOTFOUND + pos + 1);
}

return 0;
}

static void jl_serialize_code_instance(jl_serializer_state *s, jl_code_instance_t *codeinst, int skip_partial_opaque)
{
if (jl_serialize_generic(s, (jl_value_t*)codeinst)) {
return;
}

int validate = 0;
if (codeinst->max_world == ~(size_t)0)
validate = 1; // can check on deserialize if this cache entry is still valid
int flags = validate << 0;
if (codeinst->invoke == jl_fptr_const_return)
flags |= 1 << 2;
if (codeinst->precompile)
flags |= 1 << 3;

int write_ret_type = validate || codeinst->min_world == 0;
if (write_ret_type && codeinst->rettype_const &&
jl_typeis(codeinst->rettype_const, jl_partial_opaque_type)) {
if (skip_partial_opaque) {
jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque);
} else {
jl_error("Cannot serialize CodeInstance with PartialOpaque rettype");
}
}

write_uint8(s->s, TAG_CODE_INSTANCE);
write_uint8(s->s, flags);
jl_serialize_value(s, (jl_value_t*)codeinst->def);
if (write_ret_type) {
jl_serialize_value(s, codeinst->inferred);
jl_serialize_value(s, codeinst->rettype_const);
jl_serialize_value(s, codeinst->rettype);
}
else {
// skip storing useless data
jl_serialize_value(s, NULL);
jl_serialize_value(s, NULL);
jl_serialize_value(s, jl_any_type);
}
jl_serialize_code_instance(s, codeinst->next, skip_partial_opaque);
}

static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_literal) JL_GC_DISABLED
{
if (jl_serialize_generic(s, v)) {
return;
}

size_t i;
if (jl_is_svec(v)) {
size_t l = jl_svec_len(v);
Expand Down Expand Up @@ -645,33 +697,10 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
}
jl_serialize_value(s, (jl_value_t*)backedges);
jl_serialize_value(s, (jl_value_t*)NULL); //callbacks
jl_serialize_value(s, (jl_value_t*)mi->cache);
jl_serialize_code_instance(s, mi->cache, 1);
}
else if (jl_is_code_instance(v)) {
write_uint8(s->s, TAG_CODE_INSTANCE);
jl_code_instance_t *codeinst = (jl_code_instance_t*)v;
int validate = 0;
if (codeinst->max_world == ~(size_t)0)
validate = 1; // can check on deserialize if this cache entry is still valid
int flags = validate << 0;
if (codeinst->invoke == jl_fptr_const_return)
flags |= 1 << 2;
if (codeinst->precompile)
flags |= 1 << 3;
write_uint8(s->s, flags);
jl_serialize_value(s, (jl_value_t*)codeinst->def);
if (validate || codeinst->min_world == 0) {
jl_serialize_value(s, codeinst->inferred);
jl_serialize_value(s, codeinst->rettype_const);
jl_serialize_value(s, codeinst->rettype);
}
else {
// skip storing useless data
jl_serialize_value(s, NULL);
jl_serialize_value(s, NULL);
jl_serialize_value(s, jl_any_type);
}
jl_serialize_value(s, codeinst->next);
jl_serialize_code_instance(s, (jl_code_instance_t*)v, 0);
}
else if (jl_typeis(v, jl_module_type)) {
jl_serialize_module(s, (jl_module_t*)v);
Expand Down Expand Up @@ -2422,6 +2451,7 @@ static jl_method_t *jl_lookup_method(jl_methtable_t *mt, jl_datatype_t *sig, siz

static jl_method_t *jl_recache_method(jl_method_t *m)
{
assert(!m->is_for_opaque_closure);
jl_datatype_t *sig = (jl_datatype_t*)m->sig;
jl_methtable_t *mt = jl_method_table_for((jl_value_t*)m->sig);
assert((jl_value_t*)mt != jl_nothing);
Expand Down
Loading

0 comments on commit d2ef8d0

Please sign in to comment.