Skip to content

Commit

Permalink
inference: prevent tmerge from picking a larger type
Browse files Browse the repository at this point in the history
We want tmerge to form a smaller supertype, so we need to make sure the
result is on the intersection of the supertype and simplicity lattice.
Previously, we first only checked the supertype lattice, then considered
the simplicity lattice only if that failed.

fix #34834
  • Loading branch information
vtjnash committed Mar 3, 2020
1 parent 4eb0943 commit 1d08d70
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 19 deletions.
26 changes: 16 additions & 10 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
else
ut = unwrap_unionall(t)
if isa(ut, DataType) && ut.name !== _va_typename && isa(c, Type) && c !== Union{} && c <: t
# TODO: need to check that the UnionAll bounds on t are limited enough too
return t # t is already wider than the comparison in the type lattice
elseif is_derived_type_from_any(ut, sources, depth)
return t # t isn't something new
Expand Down Expand Up @@ -187,6 +188,7 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
elseif isa(t, DataType) && isempty(t.parameters)
return false # fastpath: unparameterized types are always finite
elseif tupledepth > 0 && isa(unwrap_unionall(t), DataType) && isa(c, Type) && c !== Union{} && c <: t
# TODO: need to check that the UnionAll bounds on t are limited enough too
return false # t is already wider than the comparison in the type lattice
elseif tupledepth > 0 && is_derived_type_from_any(unwrap_unionall(t), sources, depth)
return false # t isn't something new
Expand Down Expand Up @@ -266,13 +268,21 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
return true
end

function issimpleenoughtype(@nospecialize t)
return unionlen(t) <= MAX_TYPEUNION_LENGTH && unioncomplexity(t) <= MAX_TYPEUNION_COMPLEXITY
end

# pick a wider type that contains both typea and typeb,
# with some limits on how "large" it can get,
# but without losing too much precision in common cases
# and also trying to be mostly associative and commutative
function tmerge(@nospecialize(typea), @nospecialize(typeb))
typea typeb && return typeb
typeb typea && return typea
suba = typea typeb
suba && issimpleenoughtype(typeb) && return typeb
subb = typeb typea
suba && subb && return typea
subb && issimpleenoughtype(typea) && return typea

# type-lattice for MaybeUndef wrapper
if isa(typea, MaybeUndef) || isa(typeb, MaybeUndef)
return MaybeUndef(tmerge(
Expand Down Expand Up @@ -331,7 +341,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
end
# no special type-inference lattice, join the types
typea, typeb = widenconst(typea), widenconst(typeb)
typea === typeb && return typea
typea == typeb && return typea
if !(isa(typea, Type) || isa(typea, TypeVar)) ||
!(isa(typeb, Type) || isa(typeb, TypeVar))
# XXX: this should never happen
Expand Down Expand Up @@ -387,8 +397,8 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
end
end
u = Union{types...}
if unionlen(u) <= MAX_TYPEUNION_LENGTH && unioncomplexity(u) <= MAX_TYPEUNION_COMPLEXITY
# don't let type unions get too big, if the above didn't reduce it enough
# don't let type unions get too big, if the above didn't reduce it enough
if issimpleenoughtype(u)
return u
end
# finally, just return the widest possible type
Expand All @@ -414,11 +424,7 @@ function tuplemerge(a::DataType, b::DataType)
p = Vector{Any}(undef, lt + vt)
for i = 1:lt
ui = Union{ap[i], bp[i]}
if unionlen(ui) <= MAX_TYPEUNION_LENGTH && unioncomplexity(ui) <= MAX_TYPEUNION_COMPLEXITY
p[i] = ui
else
p[i] = Any
end
p[i] = issimpleenoughtype(ui) ? ui : Any
end
# merge the remaining tail into a single, simple Tuple{Vararg{T}} (#22120)
if vt
Expand Down
12 changes: 4 additions & 8 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,13 @@ end
# unioncomplexity estimates the number of calls to `tmerge` to obtain the given type by
# counting the Union instances, taking also into account those hidden in a Tuple or UnionAll
function unioncomplexity(u::Union)
inner = max(unioncomplexity(u.a), unioncomplexity(u.b))
return inner == 0 ? 0 : 1 + inner
return unioncomplexity(u.a) + unioncomplexity(u.b) + 1
end
function unioncomplexity(t::DataType)
t.name === Tuple.name || return 0
c = 1
t.name === Tuple.name || isvarargtype(t) || return 0
c = 0
for ti in t.parameters
ci = unioncomplexity(ti)
if ci > c
c = ci
end
c = max(c, unioncomplexity(ti))
end
return c
end
Expand Down
38 changes: 37 additions & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ let ref = Tuple{T, Val{T}} where T<:(Val{T} where T<:(Val{T} where T<:(Val{T} wh
end


@test Core.Compiler.unionlen(Union{}) == 1
@test Core.Compiler.unionlen(Int8) == 1
@test Core.Compiler.unionlen(Union{Int8, Int16}) == 2
@test Core.Compiler.unionlen(Union{Int8, Int16, Int32, Int64}) == 4
@test Core.Compiler.unionlen(Tuple{Union{Int8, Int16, Int32, Int64}}) == 1
@test Core.Compiler.unionlen(Union{Int8, Int16, Int32, T} where T) == 1

@test Core.Compiler.unioncomplexity(Union{}) == 0
@test Core.Compiler.unioncomplexity(Int8) == 0
@test Core.Compiler.unioncomplexity(Val{Union{Int8, Int16, Int32, Int64}}) == 0
@test Core.Compiler.unioncomplexity(Union{Int8, Int16}) == 1
@test Core.Compiler.unioncomplexity(Union{Int8, Int16, Int32, Int64}) == 3
@test Core.Compiler.unioncomplexity(Tuple{Union{Int8, Int16, Int32, Int64}}) == 3
@test Core.Compiler.unioncomplexity(Union{Int8, Int16, Int32, T} where T) == 3
@test Core.Compiler.unioncomplexity(Tuple{Val{T}, Union{Int8, Int16}, Int8} where T<:Union{Int8, Int16, Int32, Int64}) == 3
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Tuple{Union{Int8, Int16}}}}) == 1
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Symbol}}) == 0
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}) == 1
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}}}}) == 2
@test Core.Compiler.unioncomplexity(Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Union{Symbol, Tuple{Vararg{Symbol}}}}}}}}}}}) == 3


# PR 22120
function tmerge_test(a, b, r, commutative=true)
@test r == Core.Compiler.tuplemerge(a, b)
Expand Down Expand Up @@ -1931,7 +1953,7 @@ function h27316()
end
return x
end
@test Tuple{Tuple{Vector{Int}}} <: Base.return_types(h27316, Tuple{})[1] == Union{Vector{Int}, Tuple{Any}} # we may be able to improve this bound in the future
@test Tuple{Tuple{Vector{Int}}} <: only(Base.return_types(h27316, Tuple{})) == Any # we may be able to improve this bound in the future

# PR 27434, inference when splatting iterators with type-based state
splat27434(x) = (x...,)
Expand Down Expand Up @@ -2503,3 +2525,17 @@ function h34752()
a34752(g...)
end
@test h34752() === true

# issue 34834
pickvarnames(x::Symbol) = x
function pickvarnames(x::Vector{Any})
varnames = ()
for a in x
varnames = (varnames..., pickvarnames(a) )
end
return varnames
end
@test pickvarnames(:a) === :a
@test pickvarnames(Any[:a, :b]) === (:a, :b)
@test only(Base.return_types(pickvarnames, (Vector{Any},))) == Tuple
@test only(Base.code_typed(pickvarnames, (Vector{Any},), optimize=false))[2] == Tuple{Vararg{Union{Symbol, Tuple}}}

0 comments on commit 1d08d70

Please sign in to comment.