Skip to content

Commit

Permalink
Merge ff24278 into 4aa9dfa
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk authored Aug 18, 2024
2 parents 4aa9dfa + ff24278 commit b400192
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 122 deletions.
86 changes: 58 additions & 28 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2006,33 +2006,64 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
return Conditional(aty.slot, thentype, elsetype)
end
elseif f === isdefined
uty = argtypes[2]
a = ssa_def_slot(fargs[2], sv)
if isa(uty, Union) && isa(a, SlotNumber)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(uty)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ty
if isa(a, SlotNumber)
argtype2 = argtypes[2]
if isa(argtype2, Union)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(argtype2)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ty
else
elsetype = elsetype ty
end
else
thentype = thentype ty
elsetype = elsetype ty
end
else
thentype = thentype ty
elsetype = elsetype ty
end
return Conditional(a, thentype, elsetype)
else
thentype = form_partially_defined_struct(argtype2, argtypes[3])
if thentype !== nothing
elsetype = argtype2
if rt === Const(false)
thentype = Bottom
elseif rt === Const(true)
elsetype = Bottom
end
return Conditional(a, thentype, elsetype)
end
end
return Conditional(a, thentype, elsetype)
end
end
end
@assert !isa(rt, TypeVar) "unhandled TypeVar"
return rt
end

function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
obj isa Const && return nothing # nothing to refine
name isa Const || return nothing
objt0 = widenconst(obj)
objt = unwrap_unionall(objt0)
objt isa DataType || return nothing
isabstracttype(objt) && return nothing
fldidx = try_compute_fieldidx(objt, name.val)
fldidx === nothing && return nothing
nminfld = datatype_min_ninitialized(objt)
if ismutabletype(objt)
fldidx == nminfld+1 || return nothing
else
fldidx > nminfld || return nothing
end
return PartialStruct(objt0, Any[fieldtype(objt0, i) for i = 1:fldidx])
end

function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
na = length(argtypes)
if isvarargtype(argtypes[end])
Expand Down Expand Up @@ -2573,20 +2604,18 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
end
ats[i] = at
end
# For now, don't allow:
# - Const/PartialStruct of mutables (but still allow PartialStruct of mutables
# with `const` fields if anything refined)
# - partially initialized Const/PartialStruct
if fcount == nargs
if consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine
rt = PartialStruct(rt, ats)
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine || nargs > datatype_min_ninitialized(rt)
# propagate partially initialized struct as `PartialStruct` when:
# - any refinement information is available (`anyrefine`), or when
# - `nargs` is greater than `n_initialized` derived from the struct type
# information alone
rt = PartialStruct(rt, ats)
end
else
rt = refine_partial_type(rt)
Expand Down Expand Up @@ -3094,7 +3123,8 @@ end
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
if isa(rt, PartialStruct)
fields = copy(rt.fields)
local anyrefine = false
anyrefine = !isvarargtype(rt.fields[end]) &&
length(rt.fields) > datatype_min_ninitialized(unwrap_unionall(rt.typ))
𝕃 = typeinf_lattice(info.interp)
= strictpartialorder(𝕃)
for i in 1:length(fields)
Expand Down
7 changes: 6 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,12 @@ struct IntermediaryCollector <: WalkerCallback
intermediaries::SPCSet
end
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id)
if !(def isa Expr)
push!(walker_callback.intermediaries, defssa.id)
if def isa PiNode
return LiftedValue(def.val)
end
end
return nothing
end

Expand Down
42 changes: 30 additions & 12 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,23 +419,29 @@ end
else
return Bottom
end
if 1 <= idx <= datatype_min_ninitialized(a1)
if 1 idx datatype_min_ninitialized(a1)
return Const(true)
elseif a1.name === _NAMEDTUPLE_NAME
if isconcretetype(a1)
return Const(false)
else
ns = a1.parameters[1]
if isa(ns, Tuple)
return Const(1 <= idx <= length(ns))
return Const(1 idx length(ns))
end
end
elseif idx <= 0 || (!isvatuple(a1) && idx > fieldcount(a1))
elseif idx 0 || (!isvatuple(a1) && idx > fieldcount(a1))
return Const(false)
elseif isa(arg1, Const)
if !ismutabletype(a1) || isconst(a1, idx)
return Const(isdefined(arg1.val, idx))
end
elseif isa(arg1, PartialStruct)
if !isvarargtype(arg1.fields[end])
if 1 idx length(arg1.fields)
return Const(true)
end
end
elseif !isvatuple(a1)
fieldT = fieldtype(a1, idx)
if isa(fieldT, DataType) && isbitstype(fieldT)
Expand Down Expand Up @@ -989,27 +995,39 @@ end
= partialorder(𝕃)

# If we have s00 being a const, we can potentially refine our type-based analysis above
if isa(s00, Const) || isconstType(s00)
if !isa(s00, Const)
sv = (s00::DataType).parameters[1]
else
if isa(s00, Const) || isconstType(s00) || isa(s00, PartialStruct)
if isa(s00, Const)
sv = s00.val
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
elseif isa(s00, PartialStruct)
sty = unwrap_unionall(s00.typ)
nflds = fieldcount_noerror(sty)
ismod = false
else
sv = (s00::DataType).parameters[1]
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
end
if isa(name, Const)
nval = name.val
if !isa(nval, Symbol)
isa(sv, Module) && return false
ismod && return false
isa(nval, Int) || return false
end
return isdefined_tfunc(𝕃, s00, name) === Const(true)
end
boundscheck && return false

# If bounds checking is disabled and all fields are assigned,
# we may assume that we don't throw
isa(sv, Module) && return false
@assert !boundscheck
ismod && return false
name Int || name Symbol || return false
typeof(sv).name.n_uninitialized == 0 && return true
for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv)
sty.name.n_uninitialized == 0 && return true
nflds === nothing && return false
for i = (datatype_min_ninitialized(sty)+1):nflds
isdefined_tfunc(𝕃, s00, Const(i)) === Const(true) || return false
end
return true
Expand Down
74 changes: 51 additions & 23 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,42 @@

# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
# inside the global code cache.
#
# # The type of a value might be constant
# struct Const
# val
# end
#
# struct PartialStruct
# typ
# fields::Vector{Any} # elements are other type lattice members
# end

import Core: Const, PartialStruct

"""
struct Const
val
end
The type representing a constant value.
"""
:(Const)

"""
struct PartialStruct
typ
fields::Vector{Any} # elements are other type lattice members
end
This extended lattice element is introduced when we have information about an object's
fields beyond what can be obtained from the object type. E.g. it represents a tuple where
some elements are known to be constants or a struct whose `Any`-typed field is initialized
with `Int` values.
- `typ` indicates the type of the object
- `fields` holds the lattice elements corresponding to each field of the object
If `typ` is a struct, `fields` represents the fields of the struct that are guaranteed to be
initialized. For instance, if the length of `fields` of `PartialStruct` representing a
struct with 4 fields is 3, the 4th field may not be initialized. If the length is 4, all
fields are guaranteed to be initialized.
If `typ` is a tuple, the last element of `fields` may be `Vararg`. In this case, it is
guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the
exact number of elements is unknown.
"""
:(PartialStruct)
function PartialStruct(@nospecialize(typ), fields::Vector{Any})
for i = 1:length(fields)
assert_nested_slotwrapper(fields[i])
Expand Down Expand Up @@ -57,23 +82,20 @@ end
Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
Conditional(slot_id(var), thentype, elsetype)

import Core: InterConditional
"""
cnd::InterConditional
struct InterConditional
slot::Int
thentype
elsetype
end
Similar to `Conditional`, but conveys inter-procedural constraints imposed on call arguments.
This is separate from `Conditional` to catch logic errors: the lattice element name is `InterConditional`
while processing a call, then `Conditional` everywhere else. Thus `InterConditional` does not appear in
`CompilerTypes`—these type's usages are disjoint—though we define the lattice for `InterConditional`.
"""
:(InterConditional)
import Core: InterConditional
# struct InterConditional
# slot::Int
# thentype
# elsetype
# InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) =
# new(slot, thentype, elsetype)
# end
InterConditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
InterConditional(slot_id(var), thentype, elsetype)

Expand Down Expand Up @@ -447,8 +469,13 @@ end
@nospecializeinfer function (lattice::PartialsLattice, @nospecialize(a), @nospecialize(b))
if isa(a, PartialStruct)
if isa(b, PartialStruct)
if !(length(a.fields) == length(b.fields) && a.typ <: b.typ)
return false
a.typ <: b.typ || return false
if length(a.fields) length(b.fields)
if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end]))
length(a.fields) length(b.fields) || return false
else
return false
end
end
for i in 1:length(b.fields)
af = a.fields[i]
Expand All @@ -471,8 +498,7 @@ end
return isa(b, Type) && a.typ <: b
elseif isa(b, PartialStruct)
if isa(a, Const)
nf = nfields(a.val)
nf == length(b.fields) || return false
n_initialized(a) length(b.fields) || return false
widea = widenconst(a)::DataType
wideb = widenconst(b)
wideb′ = unwrap_unionall(wideb)::DataType
Expand All @@ -482,8 +508,10 @@ end
if wideb′.name !== Tuple.name && !(widea <: wideb)
return false
end
nf = nfields(a.val)
for i in 1:nf
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
i > length(b.fields) && break
bfᵢ = b.fields[i]
if i == nf
bfᵢ = unwrapva(bfᵢ)
Expand Down
Loading

0 comments on commit b400192

Please sign in to comment.