Skip to content

Commit

Permalink
Extend PartialStruct to represent non-contiguously defined fields (#…
Browse files Browse the repository at this point in the history
…57304)

So far, `PartialStruct` has been unable to represent non-contiguously
defined fields, where e.g. a struct would have fields 1 and 3 defined
but not field 2. This PR extends it so that such information may be
represented with `PartialStruct`, extending the applicability of
optimizations e.g. introduced in #55297 by @aviatesk or #57222.

The semantics of `new` prevent the creation of a struct with
non-contiguously defined fields, so this change is mostly relevant for
modeling mutable structs whose fields may have been set previously or are
assumed to be defined after creation, or immutable structs whose
creation is opaque.

Notably, with this change we may now infer information about structs in
the following case:
```julia
mutable struct A; x; y; z; A() = new(); end

function f()
    mut = A()
   
    # some opaque call preventing optimizations
    # who knows, maybe `identity` will set fields from `mut` in a future world age!
    invokelatest(identity, mut)
   
    isdefined(mut, :z) && isdefined(mut, :x) || return
   
    isdefined(mut, :x) & isdefined(mut, :z) # this now infers as `true`
    isdefined(mut, :y) # this does not
end
```

whereas previously, only information gained successively with
`isdefined(mut, :x) && isdefined(mut, :y) && isdefined(mut, :z)` could
allow inference to model `mut` having its `z` field defined.

---------

Co-authored-by: Cédric Belmant <cedric.belmant@juliahub.com>
Co-authored-by: Shuhei Kadowaki <aviatesk@gmail.com>
  • Loading branch information
3 people authored Feb 25, 2025
1 parent f7b986d commit 58399e2
Show file tree
Hide file tree
Showing 13 changed files with 329 additions and 93 deletions.
3 changes: 2 additions & 1 deletion Compiler/src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali
partition_restriction, quoted, rename_unionall, rewrap_unionall, specialize_method,
structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout,
uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal,
_uncompressed_ir, maybe_add_binding_backedge!
_uncompressed_ir, maybe_add_binding_backedge!, datatype_min_ninitialized,
partialstruct_undef_length, partialstruct_init_undef
using Base.Order

import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!,
Expand Down
27 changes: 8 additions & 19 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2148,23 +2148,13 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
isabstracttype(objt) && return nothing
fldidx = try_compute_fieldidx(objt, name.val)
fldidx === nothing && return nothing
isa(obj, PartialStruct) && return define_field(obj, fldidx)
nminfld = datatype_min_ninitialized(objt)
if ismutabletype(objt)
# A mutable struct can have non-contiguous undefined fields, but `PartialStruct` cannot
# model such a state. So here `PartialStruct` can be used to represent only the
# objects where the field following the minimum initialized fields is also defined.
if fldidx nminfld+1
# if it is already represented as a `PartialStruct`, we can add one more
# `isdefined`-field information on top of those implied by its `fields`
if !(obj isa PartialStruct && fldidx == length(obj.fields)+1)
return nothing
end
end
else
fldidx > nminfld || return nothing
end
return PartialStruct(fallback_lattice, objt0, Any[obj isa PartialStruct && ilength(obj.fields) ?
obj.fields[i] : fieldtype(objt0,i) for i = 1:fldidx])
fldidx > nminfld || return nothing
undef = partialstruct_init_undef(objt, fldidx; all_defined = false)
undef[fldidx] = false
fields = Any[fieldtype(objt0, i) for i = 1:fldidx]
return PartialStruct(fallback_lattice, objt0, undef, fields)
end

function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
Expand Down Expand Up @@ -3725,8 +3715,7 @@ end
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
if isa(rt, PartialStruct)
fields = copy(rt.fields)
anyrefine = !isvarargtype(rt.fields[end]) &&
length(rt.fields) > datatype_min_ninitialized(rt.typ)
anyrefine = refines_definedness_information(rt)
𝕃 = typeinf_lattice(info.interp)
= strictpartialorder(𝕃)
for i in 1:length(fields)
Expand All @@ -3738,7 +3727,7 @@ end
end
fields[i] = a
end
anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, fields)
anyrefine && return PartialStruct(𝕃ᵢ, rt.typ, rt.undef, fields)
end
if isa(rt, PartialOpaque)
return rt # XXX: this case was missed in #39512
Expand Down
6 changes: 3 additions & 3 deletions Compiler/src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ end
end
elseif isa(arg1, PartialStruct)
if !isvarargtype(arg1.fields[end])
if 1 idx length(arg1.fields)
if !is_field_maybe_undef(arg1, idx)
return Const(true)
end
end
Expand Down Expand Up @@ -1141,8 +1141,8 @@ end
sty = unwrap_unionall(s)::DataType
if isa(name, Const)
nv = _getfield_fieldindex(sty, name)
if isa(nv, Int) && 1 <= nv <= length(s00.fields)
return unwrapva(s00.fields[nv])
if isa(nv, Int) && !is_field_maybe_undef(s00, nv)
return unwrapva(partialstruct_getfield(s00, nv))
end
end
s00 = s
Expand Down
7 changes: 4 additions & 3 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
rettype_const = result_type.parameters[1]
const_flags = 0x2
elseif isa(result_type, PartialStruct)
rettype_const = result_type.fields
rettype_const = (result_type.undef, result_type.fields)
const_flags = 0x2
elseif isa(result_type, InterConditional)
rettype_const = result_type
Expand Down Expand Up @@ -959,8 +959,9 @@ function cached_return_type(code::CodeInstance)
rettype_const = code.rettype_const
# the second subtyping/egal conditions are necessary to distinguish usual cases
# from rare cases when `Const` wrapped those extended lattice type objects
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
return PartialStruct(fallback_lattice, rettype, rettype_const)
if isa(rettype_const, Tuple{BitVector, Vector{Any}}) && !(Tuple{BitVector, Vector{Any}} <: rettype)
undef, fields = rettype_const
return PartialStruct(fallback_lattice, rettype, undef, fields)
elseif isa(rettype_const, PartialOpaque) && rettype <: Core.OpaqueClosure
return rettype_const
elseif isa(rettype_const, InterConditional) && rettype !== InterConditional
Expand Down
46 changes: 30 additions & 16 deletions Compiler/src/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,15 @@ end
fields = vartyp.fields
thenfields = thentype === Bottom ? nothing : copy(fields)
elsefields = elsetype === Bottom ? nothing : copy(fields)
for i in 1:length(fields)
if i == fldidx
thenfields === nothing || (thenfields[i] = thentype)
elsefields === nothing || (elsefields[i] = elsetype)
end
undef = copy(vartyp.undef)
if 1 fldidx length(fields)
thenfields === nothing || (thenfields[fldidx] = thentype)
elsefields === nothing || (elsefields[fldidx] = elsetype)
undef[fldidx] = false
end
return Conditional(slot,
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, thenfields),
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, elsefields))
thenfields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, thenfields),
elsefields === nothing ? Bottom : PartialStruct(fallback_lattice, vartyp.typ, undef, elsefields))
else
vartyp_widened = widenconst(vartyp)
thenfields = thentype === Bottom ? nothing : Any[]
Expand Down Expand Up @@ -431,10 +431,14 @@ end
return false
end
end
for i in 1:length(b.fields)
af = a.fields[i]
bf = b.fields[i]
if i == length(b.fields)
na = length(a.fields)
nb = length(b.fields)
nmax = max(na, nb)
for i in 1:nmax
is_field_maybe_undef(a, i) is_field_maybe_undef(b, i) || return false
af = partialstruct_getfield(a, i)
bf = partialstruct_getfield(b, i)
if i == na || i == nb
if isvarargtype(af)
# If `af` is vararg, so must bf by the <: above
@assert isvarargtype(bf)
Expand Down Expand Up @@ -464,12 +468,15 @@ end
nfields(a.val) == length(b.fields) || return false
else
widea <: wideb || return false
# for structs we need to check that `a` has more information than `b` that may be partially initialized
n_initialized(a) length(b.fields) || return false
# for structs we need to check that `a` does not have less information than `b` that may be partially initialized
n_initialized(a) n_initialized(b) || return false
end
nf = nfields(a.val)
for i in 1:nf
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
if !isdefined(a.val, i)
is_field_maybe_undef(b, i) || return false # conflicting defined-ness information
continue # since ∀ T Union{} ⊑ T
end
i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct
bfᵢ = b.fields[i]
if i == nf
Expand Down Expand Up @@ -541,6 +548,7 @@ end
if isa(a, PartialStruct)
isa(b, PartialStruct) || return false
length(a.fields) == length(b.fields) || return false
a.undef == b.undef || return false
widenconst(a) == widenconst(b) || return false
a.fields === b.fields && return true # fast path
for i in 1:length(a.fields)
Expand Down Expand Up @@ -747,9 +755,15 @@ end
# The ::AbstractLattice argument is unused and simply serves to disambiguate
# different instances of the compiler that may share the `Core.PartialStruct`
# type.
function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), fields::Vector{Any})

# Convenience constructor: derives the `undef` mask from `typ` and `fields` through
# `partialstruct_init_undef` (forwarding `all_defined`), then defers to the
# constructor that takes the mask explicitly.
function Core.PartialStruct(𝕃::AbstractLattice, @nospecialize(typ), fields::Vector{Any}; all_defined::Bool = true)
    return PartialStruct(𝕃, typ, partialstruct_init_undef(typ, fields; all_defined = all_defined), fields)
end

# Build a `PartialStruct` from an explicit `undef` mask and per-field lattice elements.
# Each entry of `fields` is first checked with `assert_nested_slotwrapper`.
# The stale `Core._PartialStruct(typ, fields)` return was removed: it ignored `undef`
# and made the `undef`-aware return below unreachable.
function Core.PartialStruct(::AbstractLattice, @nospecialize(typ), undef::BitVector, fields::Vector{Any})
    for i = 1:length(fields)
        assert_nested_slotwrapper(fields[i])
    end
    return PartialStruct(typ, undef, fields)
end
80 changes: 74 additions & 6 deletions Compiler/src/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,17 +326,74 @@ function n_initialized(t::Const)
return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1
end

# For a constant, field `i` may be undefined exactly when `isdefined` reports it as
# unset on the wrapped value.
function is_field_maybe_undef(t::Const, i)
    return !isdefined(t.val, i)
end

# Number of leading fields of `pstruct` that are not flagged in its `undef` mask.
# When no entry of the mask is set, the result is the larger of the mask length and
# `datatype_min_ninitialized(pstruct.typ)`. Otherwise it is the index of the first
# flagged entry minus one.
function n_initialized(pstruct::PartialStruct)
    i = findfirst(pstruct.undef)
    nmin = datatype_min_ninitialized(pstruct.typ)
    i === nothing && return max(length(pstruct.undef), nmin)
    n = i::Int - 1
    # The comparison operator had been lost here; without it the macro received `nmin`
    # as a message argument and asserted on the non-boolean `n`.
    @assert n >= nmin
    return n
end

# Whether field `fi` of `pstruct` may be undefined.
# - indices below 1 are reported as possibly undefined;
# - indices covered by the `undef` mask use the mask entry;
# - indices past the mask are possibly undefined when they exceed
#   `datatype_min_ninitialized(pstruct.typ)`.
# The two bounds comparisons were missing their operators and have been restored.
function is_field_maybe_undef(pstruct::PartialStruct, fi)
    fi >= 1 || return true
    fi <= length(pstruct.undef) && return pstruct.undef[fi]
    return fi > datatype_min_ninitialized(pstruct.typ)
end

# Lattice element recorded for field `fi` of `pstruct`. Falls back to the declared
# `fieldtype` of `pstruct.typ` when `fi` lies beyond the stored `fields` vector.
# `fi` must be positive. The bounds comparison was missing its operator and has been
# restored.
function partialstruct_getfield(pstruct::PartialStruct, fi::Integer)
    @assert fi > 0
    fi <= length(pstruct.fields) && return pstruct.fields[fi]
    return fieldtype(pstruct.typ, fi)
end

# True when the run of leading fields not flagged in `pstruct.undef` is longer than
# `datatype_min_ninitialized(pstruct.typ)`.
function refines_definedness_information(pstruct::PartialStruct)
    first_undef = findfirst(pstruct.undef)
    # no flagged entry: every field covered by the mask counts as defined
    ndefined = first_undef === nothing ? length(pstruct.undef) : first_undef - 1
    return ndefined > datatype_min_ninitialized(pstruct.typ)
end

# Return a new `PartialStruct` in which field `fi` is marked as defined, or `nothing`
# when `is_field_maybe_undef` already reports the field as defined.
function define_field(pstruct::PartialStruct, fi::Int)
    # no new information to be gained
    is_field_maybe_undef(pstruct, fi) || return nothing

    refined = expand_partialstruct(pstruct, fi)
    if refined === nothing
        # no expanded struct was produced: copy the mask and fields so that the
        # update below does not mutate `pstruct` itself
        refined = PartialStruct(fallback_lattice, pstruct.typ, copy(pstruct.undef), copy(pstruct.fields))
    end
    refined.undef[fi] = false
    return refined
end

# Grow `pstruct` so that it covers fields `1:until`.
# Returns `nothing` when the `undef` mask already has at least `until` entries.
# Otherwise a fresh mask is created with `partialstruct_init_undef`, the existing
# entries are and-ed into it, and fields missing from `pstruct.fields` are filled
# with the declared `fieldtype`.
# Both bounds comparisons were missing their operators and have been restored.
function expand_partialstruct(pstruct::PartialStruct, until::Int)
    n = length(pstruct.undef)
    until <= n && return nothing

    undef = partialstruct_init_undef(pstruct.typ, until; all_defined = false)
    for i in 1:n
        # keep a field flagged only if both the fresh mask and the old mask flag it
        undef[i] &= pstruct.undef[i]
    end
    nf = length(pstruct.fields)
    typ = pstruct.typ
    fields = Any[i <= nf ? pstruct.fields[i] : fieldtype(typ, i) for i in 1:until]
    return PartialStruct(fallback_lattice, typ, undef, fields)
end

# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
@nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
@assert !isa(typea, LimitedAccuracy) && !isa(typeb, LimitedAccuracy) "LimitedAccuracy not supported by simplertype lattice" # n.b. the caller was supposed to handle these
typea === typeb && return true
if typea isa PartialStruct
aty = widenconst(typea)
if typeb isa Const
@assert length(typea.fields) n_initialized(typeb) "typeb ⊑ typea is assumed"
if typeb isa Const || typeb isa PartialStruct
@assert n_initialized(typea) n_initialized(typeb) "typeb ⊑ typea is assumed"
elseif typeb isa PartialStruct
@assert length(typea.fields) length(typeb.fields) "typeb ⊑ typea is assumed"
@assert n_initialized(typea) n_initialized(typeb) &&
all(b < a for (a, b) in zip(typea.undef, typeb.undef)) "typeb ⊑ typea is assumed"
else
return false
end
Expand Down Expand Up @@ -591,17 +648,24 @@ end
if typea isa PartialStruct
if typeb isa PartialStruct
nflds = min(length(typea.fields), length(typeb.fields))
nundef = nflds - (isvarargtype(typea.fields[end]) && isvarargtype(typeb.fields[end]))
else
nflds = min(length(typea.fields), n_initialized(typeb::Const))
nundef = nflds
end
elseif typeb isa PartialStruct
nflds = min(n_initialized(typea::Const), length(typeb.fields))
nundef = nflds
else
nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const))
nundef = nflds
end
nflds == 0 && return nothing
_undef = partialstruct_init_undef(aty, nundef; all_defined = false)
fields = Vector{Any}(undef, nflds)
anyrefine = nflds > datatype_min_ninitialized(aty)
fldmin = datatype_min_ninitialized(aty)
n_initialized_merged = min(n_initialized(typea::Union{Const, PartialStruct}), n_initialized(typeb::Union{Const, PartialStruct}))
anyrefine = n_initialized_merged > fldmin
for i = 1:nflds
ai = getfield_tfunc(𝕃, typea, Const(i))
bi = getfield_tfunc(𝕃, typeb, Const(i))
Expand Down Expand Up @@ -633,12 +697,16 @@ end
end
end
fields[i] = tyi
if i nundef
_undef[i] = is_field_maybe_undef(typea, i) || is_field_maybe_undef(typeb, i)
end
if !anyrefine
anyrefine = has_nontrivial_extended_info(𝕃, tyi) || # extended information
(𝕃, tyi, ft) # just a type-level information, but more precise than the declared type
(𝕃, tyi, ft) || # just a type-level information, but more precise than the declared type
!get(_undef, i, true) && i > fldmin # possibly uninitialized field is known to be initialized
end
end
anyrefine && return PartialStruct(𝕃, aty, fields)
anyrefine && return PartialStruct(𝕃, aty, _undef, fields)
end
return nothing
end
Expand Down
33 changes: 0 additions & 33 deletions Compiler/src/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,39 +61,6 @@ function isknownlength(t::DataType)
return isdefined(va, :N) && va.N isa Int
end

# Compute the minimum number of initialized fields for a particular datatype
# (therefore also a lower bound on the number of fields)
function datatype_min_ninitialized(@nospecialize t0)
    dt = unwrap_unionall(t0)
    # anything that is not a non-abstract `DataType` gives no guarantee
    (dt isa DataType && !isabstracttype(dt)) || return 0
    if dt.name === _NAMEDTUPLE_NAME
        names_param = dt.parameters[1]
        types_param = dt.parameters[2]
        # a known tuple of names fixes the field count
        names_param isa Tuple && return length(names_param)
        # otherwise continue with the `Tuple` datatype behind the types parameter
        dt = argument_datatype(types_param)
        (dt isa DataType && dt.name === Tuple.name) || return 0
    end
    if dt.name !== Tuple.name
        # ordinary struct: declared fields minus those that may be left uninitialized
        return length(dt.name.names) - dt.name.n_uninitialized
    end
    nparams = length(dt.parameters)
    nparams == 0 && return 0
    tail = dt.parameters[nparams]
    isvarargtype(tail) || return nparams
    # the trailing vararg contributes only when its length is a known `Int`
    ninit = nparams - 1
    if isdefined(tail, :N)
        len = tail.N
        if len isa Int
            ninit += len
        end
    end
    return ninit
end

# Tests the 0x0020 bit of `d.flags`.
# n.b. often computed only after setting the type and layout fields
function has_concrete_subtype(d::DataType)
    return (d.flags & 0x0020) != 0x0000
end

# determine whether x is a valid lattice element
Expand Down
Loading

0 comments on commit 58399e2

Please sign in to comment.