Skip to content

Commit

Permalink
inference: implement type-based alias analysis to refine constrained …
Browse files Browse the repository at this point in the history
…field (#41199)

This commit tries to propagate constraints imposed on object fields, e.g.:
```julia
struct SomeX{T}
    x::Union{Nothing,T}
end
mutable struct MutableSomeX{T}
    const x::Union{Nothing,T}
end

let # o1::SomeX{T}, o2::MutableSomeX{T}
    if !isnothing(o1.x)
        # now inference knows `o1.x::T` here
        ...
        if !isnothing(o2.x)
            # now inference knows `o2.x::T` here
            ...
        end
    end
end
```

The idea is that we can make `isa` and `===` propagate constraints
imposed on an object field if the _identity_ of that object is known.
We can have a lattice element that wraps the return type of an abstract
`getfield` call together with the object _identity_, and then we can
form a conditional constraint that propagates the refinement information
imposed on the object field when we see `isa`/`===` applied to the return
value of the preceding `getfield` call.

So this PR defines the new lattice element called `MustAlias` (and also
`InterMustAlias`, which just works in a similar way to `InterConditional`),
which may be formed upon `getfield` inference to hold the retrieved type
of the field and track the _identity_ of the object (in inference,
"object identity" can be represented as a `SlotNumber`).
This PR also implements the new logic in `abstract_call_builtin` so that
`isa` and `===` can form a conditional constraint (i.e. `Conditional`)
from `MustAlias`-argument that may later refine the wrapped object to
`PartialStruct` that holds the refined field type information.

One important note here is that `MustAlias` relies on the invariant that the
field of the wrapped slot object never changes. The biggest limitation of
this invariant is that it can't propagate constraints imposed on mutable
fields, because inference currently doesn't have precise (per-object)
knowledge of memory effects.
  • Loading branch information
aviatesk authored Nov 19, 2022
1 parent 6707077 commit 1d8f7e0
Show file tree
Hide file tree
Showing 17 changed files with 901 additions and 145 deletions.
2 changes: 1 addition & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ eval(Core, quote
end
Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))
# NOTE the main constructor is defined within `Core.Compiler`
_PartialStruct(typ::DataType, fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))
_PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))
PartialOpaque(@nospecialize(typ), @nospecialize(env), parent::MethodInstance, source) = $(Expr(:new, :PartialOpaque, :typ, :env, :parent, :source))
InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :thentype, :elsetype))
MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) = $(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))
Expand Down
293 changes: 215 additions & 78 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

29 changes: 28 additions & 1 deletion base/compiler/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,30 @@ end
widenlattice(L::InterConditionalsLattice) = L.parent
is_valid_lattice_norec(lattice::InterConditionalsLattice, @nospecialize(elem)) = isa(elem, InterConditional)

const AnyConditionalsLattice{L} = Union{ConditionalsLattice{L}, InterConditionalsLattice{L}}
"""
    struct MustAliasesLattice{𝕃}

A lattice extending lattice `𝕃` and adjoining `MustAlias`.

The lattice being extended is stored in the `parent` field and is what
`widenlattice` returns for this lattice.
"""
struct MustAliasesLattice{𝕃 <: AbstractLattice} <: AbstractLattice
# the lattice that this one extends
parent::𝕃
end
# widening this lattice yields the lattice it extends
widenlattice(𝕃::MustAliasesLattice) = 𝕃.parent
# only `MustAlias` is accepted by this check
# NOTE(review): presumably elements of the parent lattice are validated by a recursive caller — confirm
is_valid_lattice_norec(𝕃::MustAliasesLattice, @nospecialize(elem)) = isa(elem, MustAlias)

"""
    struct InterMustAliasesLattice{𝕃}

A lattice extending lattice `𝕃` and adjoining `InterMustAlias`.

The lattice being extended is stored in the `parent` field and is what
`widenlattice` returns for this lattice.
"""
struct InterMustAliasesLattice{𝕃 <: AbstractLattice} <: AbstractLattice
# the lattice that this one extends
parent::𝕃
end
# widening this lattice yields the lattice it extends
widenlattice(𝕃::InterMustAliasesLattice) = 𝕃.parent
# only `InterMustAlias` is accepted by this check
# NOTE(review): presumably elements of the parent lattice are validated by a recursive caller — confirm
is_valid_lattice_norec(𝕃::InterMustAliasesLattice, @nospecialize(elem)) = isa(elem, InterMustAlias)

# Union aliases so that a single method can dispatch on either lattice of a pair:
# `ConditionalsLattice` or `InterConditionalsLattice` ...
const AnyConditionalsLattice{𝕃} = Union{ConditionalsLattice{𝕃}, InterConditionalsLattice{𝕃}}
# ... and `MustAliasesLattice` or `InterMustAliasesLattice`
const AnyMustAliasesLattice{𝕃} = Union{MustAliasesLattice{𝕃}, InterMustAliasesLattice{𝕃}}

const SimpleInferenceLattice = typeof(PartialsLattice(ConstsLattice()))
const BaseInferenceLattice = typeof(ConditionalsLattice(SimpleInferenceLattice.instance))
Expand Down Expand Up @@ -159,6 +182,10 @@ has_conditional(𝕃::AbstractLattice) = has_conditional(widenlattice(𝕃))
has_conditional(::AnyConditionalsLattice) = true
has_conditional(::JLTypeLattice) = false

# `has_mustalias(𝕃)`: does the lattice `𝕃`, or any lattice reached from it by repeated
# `widenlattice`, belong to `AnyMustAliasesLattice`?
# The generic method recurses into `widenlattice(𝕃)` ...
has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃))
# ... and stops with `true` once a `MustAliasesLattice`/`InterMustAliasesLattice` is found,
has_mustalias(::AnyMustAliasesLattice) = true
# or with `false` when `JLTypeLattice` is reached.
has_mustalias(::JLTypeLattice) = false

# Curried versions: given only a lattice, return a two-argument closure that
# forwards to the corresponding three-argument, lattice-aware comparison.
# NOTE(review): the operator glyphs were lost in this copy of the file (both lines
# had become the identical, nameless `(lattice::AbstractLattice) = ...`); they are
# restored here as `⊑` and `⊏` — confirm against the upstream definition.
⊑(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊑(lattice, a, b)
⊏(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊏(lattice, a, b)
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function matching_cache_argtypes(linfo::MethodInstance, simple_argtypes::SimpleA
(; argtypes) = simple_argtypes
given_argtypes = Vector{Any}(undef, length(argtypes))
for i = 1:length(argtypes)
given_argtypes[i] = widenconditional(argtypes[i])
given_argtypes[i] = widenslotwrapper(argtypes[i])
end
given_argtypes = va_process_argtypes(given_argtypes, linfo)
return pick_const_args(linfo, given_argtypes)
Expand Down Expand Up @@ -78,6 +78,7 @@ function is_argtype_match(lattice::AbstractLattice,
return !overridden_by_const
end

# TODO MustAlias forwarding
function is_forwardable_argtype(@nospecialize x)
return isa(x, Const) ||
isa(x, Conditional) ||
Expand Down Expand Up @@ -223,7 +224,7 @@ function cache_lookup(lattice::AbstractLattice, linfo::MethodInstance, given_arg
cache_argtypes = cached_result.argtypes
cache_overridden_by_const = cached_result.overridden_by_const
for i in 1:nargs
if !is_argtype_match(lattice, given_argtypes[i],
if !is_argtype_match(lattice, widenmustalias(given_argtypes[i]),
cache_argtypes[i],
cache_overridden_by_const[i])
cache_match = false
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
# compute inlining and other related optimizations
result = caller.result
@assert !(result isa LimitedAccuracy)
result = isa(result, InterConditional) ? widenconditional(result) : result
result = widenslotwrapper(result)
if (isa(result, Const) || isconstType(result))
proven_pure = false
# must be proven pure to use constant calling convention;
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ struct IRInterpretationState
function IRInterpretationState(interp::AbstractInterpreter,
ir::IRCode, mi::MethodInstance, world::UInt, argtypes::Vector{Any})
argtypes = va_process_argtypes(argtypes, mi)
for i = 1:length(argtypes)
argtypes[i] = widenslotwrapper(argtypes[i])
end
argtypes_refined = Bool[!(typeinf_lattice(interp), ir.argtypes[i], argtypes[i]) for i = 1:length(argtypes)]
empty!(ir.argtypes)
append!(ir.argtypes, argtypes)
Expand Down
54 changes: 35 additions & 19 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ add_tfunc(Core.Intrinsics.cglobal, 1, 2, cglobal_tfunc, 5)
add_tfunc(Core.Intrinsics.have_fma, 1, 1, @nospecialize(x)->Bool, 1)

function ifelse_tfunc(@nospecialize(cnd), @nospecialize(x), @nospecialize(y))
cnd = widenslotwrapper(cnd)
if isa(cnd, Const)
if cnd.val === true
return x
Expand All @@ -212,9 +213,7 @@ function ifelse_tfunc(@nospecialize(cnd), @nospecialize(x), @nospecialize(y))
else
return Bottom
end
elseif isa(cnd, Conditional)
# optimized (if applicable) in abstract_call
elseif !(Bool cnd)
elseif !hasintersect(widenconst(cnd), Bool)
return Bottom
end
return tmerge(x, y)
Expand All @@ -228,6 +227,9 @@ end

egal_tfunc(@specialize(𝕃::AbstractLattice), @nospecialize(x), @nospecialize(y)) =
egal_tfunc(widenlattice(𝕃), x, y)
# `===` tfunc method for `MustAliasesLattice`: both arguments are passed through
# `widenmustalias` and the query is then delegated to the lattice that `𝕃` extends
# (`widenlattice(𝕃)`), so the methods for the remaining lattice layers never
# receive the un-widened arguments.
# NOTE(review): presumably `widenmustalias` replaces a `MustAlias` with the field type
# it wraps and leaves other elements untouched — confirm against its definition.
function egal_tfunc(@specialize(𝕃::MustAliasesLattice), @nospecialize(x), @nospecialize(y))
return egal_tfunc(widenlattice(𝕃), widenmustalias(x), widenmustalias(y))
end
function egal_tfunc(@specialize(𝕃::ConditionalsLattice), @nospecialize(x), @nospecialize(y))
if isa(x, Conditional)
y = widenconditional(y)
Expand Down Expand Up @@ -337,8 +339,6 @@ function sizeof_nothrow(@nospecialize(x))
if !isa(x.val, Type) || x.val === DataType
return true
end
elseif isa(x, Conditional)
return true
end
xu = unwrap_unionall(x)
if isa(xu, Union)
Expand Down Expand Up @@ -385,7 +385,8 @@ function _const_sizeof(@nospecialize(x))
end
return Const(size)
end
function sizeof_tfunc(@nospecialize(x),)
function sizeof_tfunc(@nospecialize(x))
x = widenmustalias(x)
isa(x, Const) && return _const_sizeof(x.val)
isa(x, Conditional) && return _const_sizeof(Bool)
isconstType(x) && return _const_sizeof(x.parameters[1])
Expand Down Expand Up @@ -453,19 +454,25 @@ function typevar_tfunc(@nospecialize(n), @nospecialize(lb_arg), @nospecialize(ub
isa(nval, Symbol) || return Union{}
if isa(lb_arg, Const)
lb = lb_arg.val
elseif isType(lb_arg)
lb = lb_arg.parameters[1]
lb_certain = false
else
return TypeVar
lb_arg = widenslotwrapper(lb_arg)
if isType(lb_arg)
lb = lb_arg.parameters[1]
lb_certain = false
else
return TypeVar
end
end
if isa(ub_arg, Const)
ub = ub_arg.val
elseif isType(ub_arg)
ub = ub_arg.parameters[1]
ub_certain = false
else
return TypeVar
ub_arg = widenslotwrapper(ub_arg)
if isType(ub_arg)
ub = ub_arg.parameters[1]
ub_certain = false
else
return TypeVar
end
end
tv = TypeVar(nval, lb, ub)
return PartialTypeVar(tv, lb_certain, ub_certain)
Expand Down Expand Up @@ -966,6 +973,11 @@ function _getfield_tfunc(@specialize(lattice::AnyConditionalsLattice), @nospecia
return _getfield_tfunc(widenlattice(lattice), s00, name, setfield)
end

# `getfield` tfunc method for `MustAliasesLattice`/`InterMustAliasesLattice`:
# the object argument `s00` is passed through `widenmustalias`, then the query is
# delegated to the lattice that `𝕃` extends. `name` and `setfield` are forwarded unchanged.
# NOTE(review): presumably `widenmustalias` replaces a `MustAlias` with the field type
# it wraps and leaves other elements untouched — confirm against its definition.
function _getfield_tfunc(@specialize(𝕃::AnyMustAliasesLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
s00 = widenmustalias(s00)
return _getfield_tfunc(widenlattice(𝕃), s00, name, setfield)
end

function _getfield_tfunc(@specialize(lattice::PartialsLattice), @nospecialize(s00), @nospecialize(name), setfield::Bool)
if isa(s00, PartialStruct)
s = widenconst(s00)
Expand Down Expand Up @@ -1328,6 +1340,7 @@ end

fieldtype_tfunc(s0, name, boundscheck) = (@nospecialize; fieldtype_tfunc(s0, name))
function fieldtype_tfunc(@nospecialize(s0), @nospecialize(name))
s0 = widenmustalias(s0)
if s0 === Bottom
return Bottom
end
Expand Down Expand Up @@ -1525,6 +1538,7 @@ const _tvarnames = Symbol[:_A, :_B, :_C, :_D, :_E, :_F, :_G, :_H, :_I, :_J, :_K,

# TODO: handle e.g. apply_type(T, R::Union{Type{Int32},Type{Float64}})
function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
headtypetype = widenslotwrapper(headtypetype)
if isa(headtypetype, Const)
headtype = headtypetype.val
elseif isconstType(headtypetype)
Expand Down Expand Up @@ -1591,7 +1605,7 @@ function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
varnamectr = 1
ua = headtype
for i = 1:largs
ai = widenconditional(args[i])
ai = widenslotwrapper(args[i])
if isType(ai)
aip1 = ai.parameters[1]
canconst &= !has_free_typevars(aip1)
Expand Down Expand Up @@ -1689,7 +1703,7 @@ add_tfunc(apply_type, 1, INT_INF, apply_type_tfunc, 10)
# convert the dispatch tuple type argtype to the real (concrete) type of
# the tuple of those values
function tuple_tfunc(@specialize(lattice::AbstractLattice), argtypes::Vector{Any})
argtypes = anymap(widenconditional, argtypes)
argtypes = anymap(widenslotwrapper, argtypes)
all_are_const = true
for i in 1:length(argtypes)
if !isa(argtypes[i], Const)
Expand Down Expand Up @@ -2203,6 +2217,8 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
return getfield_tfunc(𝕃ᵢ, argtypes...)
elseif f === (===)
return egal_tfunc(𝕃ᵢ, argtypes...)
elseif f === isa
return isa_tfunc(𝕃ᵢ, argtypes...)
end
return tf[3](argtypes...)
end
Expand Down Expand Up @@ -2324,9 +2340,9 @@ end
# while this assumes that it is an absolutely precise and accurate and exact model of both
function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::Union{InferenceState, IRCode})
if length(argtypes) == 3
tt = argtypes[3]
tt = widenslotwrapper(argtypes[3])
if isa(tt, Const) || (isType(tt) && !has_free_typevars(tt))
aft = argtypes[2]
aft = widenslotwrapper(argtypes[2])
if isa(aft, Const) || (isType(aft) && !has_free_typevars(aft)) ||
(isconcretetype(aft) && !(aft <: Builtin))
af_argtype = isa(tt, Const) ? tt.val : (tt::DataType).parameters[1]
Expand All @@ -2348,7 +2364,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), si, sv, -1)
end
info = verbose_stmt_info(interp) ? MethodResultPure(ReturnTypeCallInfo(call.info)) : MethodResultPure()
rt = widenconditional(call.rt)
rt = widenslotwrapper(call.rt)
if isa(rt, Const)
# output was computed to be constant
return CallMeta(Const(typeof(rt.val)), EFFECTS_TOTAL, info)
Expand Down
21 changes: 13 additions & 8 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ function CodeInstance(
elseif isa(result_type, InterConditional)
rettype_const = result_type
const_flags = 0x2
elseif isa(result_type, InterMustAlias)
rettype_const = result_type
const_flags = 0x2
else
rettype_const = nothing
const_flags = 0x00
Expand Down Expand Up @@ -526,8 +529,8 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
end
# inspect whether our inference had a limited result accuracy,
# else it may be suitable to cache
me.bestguess = cycle_fix_limited(me.bestguess, me)
limited_ret = me.bestguess isa LimitedAccuracy
bestguess = me.bestguess = cycle_fix_limited(me.bestguess, me)
limited_ret = bestguess isa LimitedAccuracy
limited_src = false
if !limited_ret
gt = me.ssavaluetypes
Expand Down Expand Up @@ -564,7 +567,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
end
end
me.result.valid_worlds = me.valid_worlds
me.result.result = me.bestguess
me.result.result = bestguess
me.ipo_effects = me.result.ipo_effects = adjust_effects(me)
validate_code_in_debug_mode(me.linfo, me.src, "inferred")
nothing
Expand Down Expand Up @@ -640,7 +643,7 @@ function annotate_slot_load!(undefs::Vector{Bool}, idx::Int, sv::InferenceState,
state = sv.bb_vartables[block]::VarTable
vt = state[id]
undefs[id] |= vt.undef
typ = widenconditional(ignorelimited(vt.typ))
typ = widenslotwrapper(ignorelimited(vt.typ))
else
typ = sv.ssavaluetypes[pc]
@assert typ !== NOT_FOUND "active slot in unreached region"
Expand Down Expand Up @@ -719,7 +722,7 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState, run_opt
# 1. introduce temporary `TypedSlot`s that are supposed to be replaced with π-nodes later
# 2. mark used-undef slots (required by the `slot2reg` conversion)
# 3. mark unreached statements for a bulk code deletion (see issue #7836)
# 4. widen `Conditional`s and remove `NOT_FOUND` from `ssavaluetypes`
# 4. widen slot wrappers (`Conditional` and `MustAlias`) and remove `NOT_FOUND` from `ssavaluetypes`
# NOTE because of this, `was_reached` will no longer be available after this point
# 5. eliminate GotoIfNot if either branch target is unreachable
changemap = nothing # initialized if there is any dead region
Expand All @@ -739,7 +742,7 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState, run_opt
end
end
body[i] = annotate_slot_load!(undefs, i, sv, expr) # 1&2
ssavaluetypes[i] = widenconditional(ssavaluetypes[i]) # 4
ssavaluetypes[i] = widenslotwrapper(ssavaluetypes[i]) # 4
else # i.e. any runtime execution will never reach this statement
if is_meta_expr(expr) # keep any lexically scoped expressions
ssavaluetypes[i] = Any # 4
Expand Down Expand Up @@ -893,13 +896,15 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
rettype = code.rettype
if isdefined(code, :rettype_const)
rettype_const = code.rettype_const
# the second subtyping conditions are necessary to distinguish usual cases
# the second subtyping/egal conditions are necessary to distinguish usual cases
# from rare cases when `Const` wrapped those extended lattice type objects
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
rettype = PartialStruct(rettype, rettype_const)
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
rettype = rettype_const
elseif isa(rettype_const, InterConditional) && !(InterConditional <: rettype)
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
rettype = rettype_const
elseif isa(rettype_const, InterMustAlias) && rettype !== InterMustAlias
rettype = rettype_const
else
rettype = Const(rettype_const)
Expand Down
Loading

0 comments on commit 1d8f7e0

Please sign in to comment.