Skip to content

Commit

Permalink
inference: improve management of non-type parameters (JuliaLang#42693)
Browse files Browse the repository at this point in the history
Prevent occurrence of v or Type{v} in the type-lattice, where v is not a
Type (or TypeVar).

Fixes JuliaLang#42646, and similar problems from code-reading.
  • Loading branch information
vtjnash authored and LilithHafner committed Mar 8, 2022
1 parent c7f4f61 commit 1222af4
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 58 deletions.
24 changes: 13 additions & 11 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -806,26 +806,27 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
end
if isa(tti, Union)
utis = uniontypes(tti)
if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
if _any(@nospecialize(t) -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return Any[Vararg{Any}], nothing
end
result = Any[rewrap_unionall(p, tti0) for p in (utis[1]::DataType).parameters]
for t::DataType in utis[2:end]
if length(t.parameters) != length(result)
ltp = length((utis[1]::DataType).parameters)
for t in utis
if length((t::DataType).parameters) != ltp
return Any[Vararg{Any}], nothing
end
for j in 1:length(t.parameters)
result[j] = tmerge(result[j], rewrap_unionall(t.parameters[j], tti0))
end
result = Any[ Union{} for _ in 1:ltp ]
for t in utis
tps = (t::DataType).parameters
_all(valid_as_lattice, tps) || continue
for j in 1:ltp
result[j] = tmerge(result[j], rewrap_unionall(tps[j], tti0))
end
end
return result, nothing
elseif tti0 <: Tuple
if isa(tti0, DataType)
if isvatuple(tti0) && length(tti0.parameters) == 1
return Any[Vararg{unwrapva(tti0.parameters[1])}], nothing
else
return Any[ p for p in tti0.parameters ], nothing
end
return Any[ p for p in tti0.parameters ], nothing
elseif !isa(tti, DataType)
return Any[Vararg{Any}], nothing
else
Expand Down Expand Up @@ -1121,6 +1122,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
if !has_free_typevars(tty_lb) && !has_free_typevars(tty_ub)
ifty = typeintersect(aty, tty_ub)
valid_as_lattice(ifty) || (ifty = Union{})
elty = typesubtract(aty, tty_lb, InferenceParams(interp).MAX_UNION_SPLITTING)
return Conditional(a, ifty, elty)
end
Expand Down
10 changes: 6 additions & 4 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
if !toplevel && isva
if specTypes == Tuple
if nargs > 1
linfo_argtypes = svec(Any[Any for i = 1:(nargs - 1)]..., Tuple.parameters[1])
linfo_argtypes = Any[Any for i = 1:nargs]
linfo_argtypes[end] = Vararg{Any}
end
vargtype = Tuple
else
Expand All @@ -121,9 +122,10 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
end
else
vargtype_elements = Any[]
for p in linfo_argtypes[nargs:linfo_argtypes_length]
for i in nargs:linfo_argtypes_length
p = linfo_argtypes[i]
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
push!(vargtype_elements, elim_free_typevars(rewrap(p, specTypes)))
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
end
for i in 1:length(vargtype_elements)
atyp = vargtype_elements[i]
Expand Down Expand Up @@ -162,7 +164,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = elim_free_typevars(rewrap(atyp, specTypes))
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
end
i == n && (lastatype = atyp)
cache_argtypes[i] = atyp
Expand Down
67 changes: 40 additions & 27 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ add_tfunc(throw, 1, 1, (@nospecialize(x)) -> Bottom, 0)
# if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int})
function instanceof_tfunc(@nospecialize(t))
if isa(t, Const)
if isa(t.val, Type)
if isa(t.val, Type) && valid_as_lattice(t.val)
return t.val, true, isconcretetype(t.val), true
end
return Bottom, true, false, false # runtime throws on non-Type
Expand All @@ -79,6 +79,7 @@ function instanceof_tfunc(@nospecialize(t))
return Bottom, true, false, false # literal Bottom or non-Type
elseif isType(t)
tp = t.parameters[1]
valid_as_lattice(tp) || return Bottom, true, false, false # runtime unreachable / throws on non-Type
return tp, !has_free_typevars(tp), isconcretetype(tp), true
elseif isa(t, UnionAll)
t′ = unwrap_unionall(t)
Expand Down Expand Up @@ -473,7 +474,8 @@ function pointer_eltype(@nospecialize(ptr))
unw = unwrap_unionall(a)
if isa(unw, DataType) && unw.name === Ptr.body.name
T = unw.parameters[1]
T isa Type && return rewrap_unionall(T, a)
valid_as_lattice(T) || return Bottom
return rewrap_unionall(T, a)
end
end
return Any
Expand All @@ -486,7 +488,8 @@ function atomic_pointermodify_tfunc(ptr, op, v, order)
if isa(unw, DataType) && unw.name === Ptr.body.name
T = unw.parameters[1]
# note: we could sometimes refine this to a PartialStruct if we analyzed `op(T, T)::T`
T isa Type && return rewrap_unionall(Pair{T, T}, a)
valid_as_lattice(T) || return Bottom
return rewrap_unionall(Pair{T, T}, a)
end
end
return Pair
Expand All @@ -498,7 +501,8 @@ function atomic_pointerreplace_tfunc(ptr, x, v, success_order, failure_order)
unw = unwrap_unionall(a)
if isa(unw, DataType) && unw.name === Ptr.body.name
T = unw.parameters[1]
T isa Type && return rewrap_unionall(ccall(:jl_apply_cmpswap_type, Any, (Any,), T), a)
valid_as_lattice(T) || return Bottom
return rewrap_unionall(ccall(:jl_apply_cmpswap_type, Any, (Any,), T), a)
end
end
return ccall(:jl_apply_cmpswap_type, Any, (Any,), T) where T
Expand Down Expand Up @@ -754,8 +758,8 @@ function getfield_nothrow(@nospecialize(s00), @nospecialize(name), boundscheck::
s0 = widenconst(s00)
s = unwrap_unionall(s0)
if isa(s, Union)
return getfield_nothrow(rewrap(s.a, s00), name, boundscheck) &&
getfield_nothrow(rewrap(s.b, s00), name, boundscheck)
return getfield_nothrow(rewrap_unionall(s.a, s00), name, boundscheck) &&
getfield_nothrow(rewrap_unionall(s.b, s00), name, boundscheck)
elseif isa(s, DataType)
# Can't say anything about abstract types
isabstracttype(s) && return false
Expand All @@ -782,8 +786,8 @@ getfield_tfunc(s00, name, order, boundscheck) = (@nospecialize; getfield_tfunc(s
function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
s = unwrap_unionall(s00)
if isa(s, Union)
return tmerge(getfield_tfunc(rewrap(s.a,s00), name),
getfield_tfunc(rewrap(s.b,s00), name))
return tmerge(getfield_tfunc(rewrap_unionall(s.a, s00), name),
getfield_tfunc(rewrap_unionall(s.b, s00), name))
elseif isa(s, Conditional)
return Bottom # Bool has no fields
elseif isa(s, Const) || isconstType(s)
Expand Down Expand Up @@ -857,9 +861,6 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
end
return Any
end
# If no value has this type, then this statement should be unreachable.
# Bail quickly now.
has_concrete_subtype(s) || return Union{}
if s.name === _NAMEDTUPLE_NAME && !isconcretetype(s)
if isa(name, Const) && isa(name.val, Symbol)
if isa(s.parameters[1], Tuple)
Expand All @@ -878,7 +879,9 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
return getfield_tfunc(_ts, name)
end
ftypes = datatype_fieldtypes(s)
if isempty(ftypes)
# If no value has this type, then this statement should be unreachable.
# Bail quickly now.
if !has_concrete_subtype(s) || isempty(ftypes)
return Bottom
end
if isa(name, Conditional)
Expand Down Expand Up @@ -1072,8 +1075,8 @@ function fieldtype_tfunc(@nospecialize(s0), @nospecialize(name))

su = unwrap_unionall(s0)
if isa(su, Union)
return tmerge(fieldtype_tfunc(rewrap(su.a, s0), name),
fieldtype_tfunc(rewrap(su.b, s0), name))
return tmerge(fieldtype_tfunc(rewrap_unionall(su.a, s0), name),
fieldtype_tfunc(rewrap_unionall(su.b, s0), name))
end

s, exact = instanceof_tfunc(s0)
Expand All @@ -1085,8 +1088,8 @@ function _fieldtype_tfunc(@nospecialize(s), exact::Bool, @nospecialize(name))
exact = exact && !has_free_typevars(s)
u = unwrap_unionall(s)
if isa(u, Union)
ta0 = _fieldtype_tfunc(rewrap(u.a, s), exact, name)
tb0 = _fieldtype_tfunc(rewrap(u.b, s), exact, name)
ta0 = _fieldtype_tfunc(rewrap_unionall(u.a, s), exact, name)
tb0 = _fieldtype_tfunc(rewrap_unionall(u.b, s), exact, name)
ta0 tb0 && return tb0
tb0 ta0 && return ta0
ta, exacta, _, istypea = instanceof_tfunc(ta0)
Expand Down Expand Up @@ -1296,7 +1299,11 @@ function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
end
end
end
largs == 1 && return isa(args[1], Type) ? typeintersect(args[1], Type) : Type
if largs == 1 # Union{T} --> T
u1 = typeintersect(widenconst(args[1]), Type)
valid_as_lattice(u1) || return Bottom
return u1
end
hasnonType && return Type
ty = Union{}
allconst = true
Expand Down Expand Up @@ -1471,21 +1478,26 @@ end

function arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecialize i...)
a = widenconst(a)
if a <: Array
if isa(a, DataType) && isa(a.parameters[1], Type)
return a.parameters[1]
elseif isa(a, UnionAll) && !has_free_typevars(a)
unw = unwrap_unionall(a)
if isa(unw, DataType)
return rewrap_unionall(unw.parameters[1], a)
end
if !has_free_typevars(a) && a <: Array
a0 = a
if isa(a, UnionAll)
a = unwrap_unionall(a0)
end
if isa(a, DataType)
T = a.parameters[1]
valid_as_lattice(T) || return Bottom
return rewrap_unionall(T, a0)
end
end
return Any
end
add_tfunc(arrayref, 3, INT_INF, arrayref_tfunc, 20)
add_tfunc(const_arrayref, 3, INT_INF, arrayref_tfunc, 20)
add_tfunc(arrayset, 4, INT_INF, (@nospecialize(boundscheck), @nospecialize(a), @nospecialize(v), @nospecialize i...)->a, 20)
function arrayset_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecialize(v), @nospecialize i...)
# TODO: we could check that the type-intersect of arrayref_tfunc and v is non-empty or always throws
return a
end
add_tfunc(arrayset, 4, INT_INF, arrayset_tfunc, 20)

function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
@nospecialize(lb), @nospecialize(ub), @nospecialize(source), env::Vector{Any},
Expand All @@ -1508,6 +1520,7 @@ function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
return PartialOpaque(t, tuple_tfunc(env), isva.val, linfo, source.val)
end

# whether getindex for the elements can potentially throw UndefRef
function array_type_undefable(@nospecialize(a))
if isa(a, Union)
return array_type_undefable(a.a) || array_type_undefable(a.b)
Expand Down Expand Up @@ -1550,7 +1563,7 @@ function _builtin_nothrow(@nospecialize(f), argtypes::Array{Any,1}, @nospecializ
# Check that we can determine the element type
(isa(a, DataType) && isa(a.parameters[1], Type)) || return false
# Check that the element type is compatible with the element we're assigning
(argtypes[3] a.parameters[1]::Type) || return false
(argtypes[3] a.parameters[1]) || return false
return true
elseif f === arrayref || f === const_arrayref
return array_builtin_common_nothrow(argtypes, 3)
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,7 @@ function tmeet(@nospecialize(v), @nospecialize(t))
return v
end
ti = typeintersect(widev, t)
if ti === Bottom
return Bottom
end
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
Expand All @@ -628,5 +626,7 @@ function tmeet(@nospecialize(v), @nospecialize(t))
end
return v
end
return typeintersect(widenconst(v), t)
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end
39 changes: 28 additions & 11 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@
# lattice utilities #
#####################

function rewrap(@nospecialize(t), @nospecialize(u))
if isa(t, TypeVar) || isa(t, Type) || isvarargtype(t)
return rewrap_unionall(t, u)
end
return t
end

isType(@nospecialize t) = isa(t, DataType) && t.name === _TYPE_NAME

# true if Type{T} is inlineable as constant T
Expand Down Expand Up @@ -42,8 +35,6 @@ end

has_const_info(@nospecialize x) = (!isa(x, Type) && !isvarargtype(x)) || isType(x)

has_concrete_subtype(d::DataType) = d.flags & 0x20 == 0x20

# Subtyping currently intentionally answers certain queries incorrectly for kind types. For
# some of these queries, this check can be used to somewhat protect against making incorrect
# decisions based on incorrect subtyping. Note that this check, itself, is broken for
Expand Down Expand Up @@ -89,6 +80,30 @@ function datatype_min_ninitialized(t::DataType)
return length(t.name.names) - t.name.n_uninitialized
end

has_concrete_subtype(d::DataType) = d.flags & 0x20 == 0x20 # n.b. often computed only after setting the type and layout fields

# determine whether x is a valid lattice element tag
# For example, Type{v} is not valid if v is a value
# Accepts TypeVars also, since it assumes the user will rewrap it correctly
function valid_as_lattice(@nospecialize(x))
x === Bottom && false
x isa TypeVar && return valid_as_lattice(x.ub)
x isa UnionAll && (x = unwrap_unionall(x))
if x isa Union
# the Union constructor ensures this (and we'll recheck after
# operations that might remove the Union itself)
return true
end
if x isa DataType
if isType(x)
p = x.parameters[1]
p isa Type || p isa TypeVar || return false
end
return true
end
return false
end

# test if non-Type, non-TypeVar `x` can be used to parameterize a type
function valid_tparam(@nospecialize(x))
if isa(x, Tuple)
Expand Down Expand Up @@ -119,8 +134,10 @@ function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::I
end
ua = unwrap_unionall(a)
if isa(ua, Union)
return Union{typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING),
typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)}
uua = typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING)
uub = typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)
return Union{valid_as_lattice(uua) ? uua : Union{},
valid_as_lattice(uub) ? uub : Union{}}
elseif a isa DataType
ub = unwrap_unionall(b)
if ub isa DataType
Expand Down
6 changes: 5 additions & 1 deletion src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1198,8 +1198,12 @@ void jl_precompute_memoized_dt(jl_datatype_t *dt, int cacheable)
dt->has_concrete_subtype = 0;
}
}
if (dt->name == jl_type_typename)
if (dt->name == jl_type_typename) {
cacheable = 0; // the cache for Type ignores parameter normalization, so it can't be used as a regular hash
jl_value_t *p = jl_tparam(dt, 0);
if (!jl_is_type(p) && !jl_is_typevar(p)) // Type{v} has no subtypes, if v is not a Type
dt->has_concrete_subtype = 0;
}
dt->hash = typekey_hash(dt->name, jl_svec_data(dt->parameters), l, cacheable);
dt->cached_by_hash = cacheable ? (typekey_hash(dt->name, jl_svec_data(dt->parameters), l, 0) != 0) : (dt->hash != 0);
}
Expand Down
3 changes: 3 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3642,3 +3642,6 @@ let
@test argtypes[10] == Any
@test argtypes[11] == Tuple{Integer,Integer}
end

# issue #42646
@test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw

0 comments on commit 1222af4

Please sign in to comment.