From c1637af77d2fe251ef7e67f560d5bb604c1bece1 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 29 Jul 2024 17:26:46 +0900 Subject: [PATCH 1/4] inference: model partially initialized structs with `PartialStruct` There is still room for improvement in the accuracy of `getfield` and `isdefined` for structs with uninitialized fields. This commit aims to enhance the accuracy of struct field defined-ness by propagating such structs as `PartialStruct` in cases where fields that might be uninitialized are confirmed to be defined. Specifically, the improvements are made in the following situations: 1. when a `:new` expression receives more arguments than the minimum number of initialized fields. 2. when new information about the initialized fields of `x` can be obtained in the `then` branch of `if isdefined(x, :f)`. Combined with the existing optimizations, these improvements enable DCE in scenarios such as: ```julia julia> @noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing); julia> @allocated broadcast_noescape1(Ref("x")) 16 # master 0 # this PR ``` One important point to note is that, as revealed in JuliaLang/julia#48999, fields and globals can revert to `undef` during precompilation. This commit does not affect globals. Furthermore, even for fields, the refinements made by 1. and 2. are propagated along the data flow, and field defined-ness information is only used when fields are confirmed to be initialized. Therefore, this commit will not introduce the same issues as JuliaLang/julia#48999. 
--- base/compiler/abstractinterpretation.jl | 73 +++++++++------ base/compiler/ssair/passes.jl | 7 +- base/compiler/tfuncs.jl | 39 ++++++-- .../compiler/EscapeAnalysis/EscapeAnalysis.jl | 10 +-- test/compiler/inference.jl | 89 +++++++++++++++++++ 5 files changed, 171 insertions(+), 47 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 90d395600bbdea..9a474a4cbd878f 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2006,26 +2006,33 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs return Conditional(aty.slot, thentype, elsetype) end elseif f === isdefined - uty = argtypes[2] a = ssa_def_slot(fargs[2], sv) - if isa(uty, Union) && isa(a, SlotNumber) - fld = argtypes[3] - thentype = Bottom - elsetype = Bottom - for ty in uniontypes(uty) - cnd = isdefined_tfunc(𝕃ᡒ, ty, fld) - if isa(cnd, Const) - if cnd.val::Bool - thentype = thentype βŠ” ty + if isa(a, SlotNumber) + argtype2 = argtypes[2] + if isa(argtype2, Union) + fld = argtypes[3] + thentype = Bottom + elsetype = Bottom + for ty in uniontypes(argtype2) + cnd = isdefined_tfunc(𝕃ᡒ, ty, fld) + if isa(cnd, Const) + if cnd.val::Bool + thentype = thentype βŠ” ty + else + elsetype = elsetype βŠ” ty + end else + thentype = thentype βŠ” ty elsetype = elsetype βŠ” ty end - else - thentype = thentype βŠ” ty - elsetype = elsetype βŠ” ty + end + return Conditional(a, thentype, elsetype) + else + thentype = form_partially_defined_struct(argtype2, argtypes[3]) + if thentype !== nothing + return Conditional(a, thentype, argtype2) end end - return Conditional(a, thentype, elsetype) end end end @@ -2033,6 +2040,18 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs return rt end +function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) + obj isa Const && return nothing # nothing to refine + name isa Const || return nothing + objt0 = 
widenconst(obj) + objt = unwrap_unionall(objt0) + isabstracttype(objt) && return nothing + fldidx = try_compute_fieldidx(objt, name.val) + fldidx === nothing && return nothing + fldidx ≀ datatype_min_ninitialized(objt) && return nothing + return PartialStruct(objt0, Any[fieldtype(objt0, i) for i = 1:fldidx]) +end + function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta) na = length(argtypes) if isvarargtype(argtypes[end]) @@ -2573,20 +2592,18 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V end ats[i] = at end - # For now, don't allow: - # - Const/PartialStruct of mutables (but still allow PartialStruct of mutables - # with `const` fields if anything refined) - # - partially initialized Const/PartialStruct - if fcount == nargs - if consistent === ALWAYS_TRUE && allconst - argvals = Vector{Any}(undef, nargs) - for j in 1:nargs - argvals[j] = (ats[j]::Const).val - end - rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs)) - elseif anyrefine - rt = PartialStruct(rt, ats) + if fcount == nargs && consistent === ALWAYS_TRUE && allconst + argvals = Vector{Any}(undef, nargs) + for j in 1:nargs + argvals[j] = (ats[j]::Const).val end + rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs)) + elseif anyrefine || nargs > datatype_min_ninitialized(rt) + # propagate partially initialized struct as `PartialStruct` when: + # - any refinement information is available (`anyrefine`), or when + # - `nargs` is greater than `n_initialized` derived from the struct type + # information alone + rt = PartialStruct(rt, ats) end else rt = refine_partial_type(rt) @@ -3094,7 +3111,7 @@ end @nospecializeinfer function widenreturn_partials(𝕃ᡒ::PartialsLattice, @nospecialize(rt), info::BestguessInfo) if isa(rt, PartialStruct) fields = copy(rt.fields) - local anyrefine = false + anyrefine = !isvarargtype(rt.fields[end]) && length(rt.fields) > 
datatype_min_ninitialized(rt.typ) 𝕃 = typeinf_lattice(info.interp) ⊏ = strictpartialorder(𝕃) for i in 1:length(fields) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 33cda9bf27d202..37d79e2bd7b0cc 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1166,7 +1166,12 @@ struct IntermediaryCollector <: WalkerCallback intermediaries::SPCSet end function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue)) - isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id) + if !(def isa Expr) + push!(walker_callback.intermediaries, defssa.id) + if def isa PiNode + return LiftedValue(def.val) + end + end return nothing end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 9a4c761b4209bc..e3cf031770c361 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -436,6 +436,15 @@ end if !ismutabletype(a1) || isconst(a1, idx) return Const(isdefined(arg1.val, idx)) end + elseif isa(arg1, PartialStruct) + nflds = length(arg1.fields) + if !isvarargtype(arg1.fields[end]) + if 1 ≀ idx ≀ nflds + return Const(true) + elseif !ismutabletype(a1) || isconst(a1, idx) + return Const(false) + end + end elseif !isvatuple(a1) fieldT = fieldtype(a1, idx) if isa(fieldT, DataType) && isbitstype(fieldT) @@ -989,27 +998,39 @@ end βŠ‘ = partialorder(𝕃) # If we have s00 being a const, we can potentially refine our type-based analysis above - if isa(s00, Const) || isconstType(s00) - if !isa(s00, Const) - sv = (s00::DataType).parameters[1] - else + if isa(s00, Const) || isconstType(s00) || isa(s00, PartialStruct) + if isa(s00, Const) sv = s00.val + sty = typeof(sv) + nflds = nfields(sv) + ismod = sv isa Module + elseif isa(s00, PartialStruct) + sty = s00.typ + nflds = fieldcount_noerror(sty) + ismod = false + else + sv = (s00::DataType).parameters[1] + sty = typeof(sv) + nflds = nfields(sv) + ismod = sv isa Module end if isa(name, Const) nval = name.val if 
!isa(nval, Symbol) - isa(sv, Module) && return false + ismod && return false isa(nval, Int) || return false end return isdefined_tfunc(𝕃, s00, name) === Const(true) end - boundscheck && return false + # If bounds checking is disabled and all fields are assigned, # we may assume that we don't throw - isa(sv, Module) && return false + @assert !boundscheck + ismod && return false name βŠ‘ Int || name βŠ‘ Symbol || return false - typeof(sv).name.n_uninitialized == 0 && return true - for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv) + sty.name.n_uninitialized == 0 && return true + nflds === nothing && return false + for i = (datatype_min_ninitialized(sty)+1):nflds isdefined_tfunc(𝕃, s00, Const(i)) === Const(true) || return false end return true diff --git a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl index d8ea8be21fe07b..31c21f72280140 100644 --- a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl +++ b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl @@ -2139,21 +2139,13 @@ end # ======================== # propagate escapes imposed on call arguments -@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing) -let result = code_escapes() do - broadcast_noescape1(Ref("Hi")) - end - i = only(findall(isnew, result.ir.stmts.stmt)) - @test !has_return_escape(result.state[SSAValue(i)]) - @test_broken !has_thrown_escape(result.state[SSAValue(i)]) # TODO `getfield(RefValue{String}, :x)` isn't safe -end @noinline broadcast_noescape2(b) = broadcast(identity, b) let result = code_escapes() do broadcast_noescape2(Ref("Hi")) end i = only(findall(isnew, result.ir.stmts.stmt)) @test_broken !has_return_escape(result.state[SSAValue(i)]) # TODO interprocedural alias analysis - @test_broken !has_thrown_escape(result.state[SSAValue(i)]) # TODO `getfield(RefValue{String}, :x)` isn't safe + @test !has_thrown_escape(result.state[SSAValue(i)]) end @noinline allescape_argument(a) = (global GV = a) # obvious escape let 
result = code_escapes() do diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 9ae98b884bef4c..cfe6712075b2f5 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5867,6 +5867,95 @@ bar54341(args...) = foo54341(4, args...) @test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int +# `PartialStruct` for partially initialized structs: +struct PartiallyInitialized1 + a; b; c + PartiallyInitialized1(a) = (@nospecialize; new(a)) + PartiallyInitialized1(a, b) = (@nospecialize; new(a, b)) + PartiallyInitialized1(a, b, c) = (@nospecialize; new(a, b, c)) +end +mutable struct PartiallyInitialized2 + a; b; c + PartiallyInitialized2(a) = (@nospecialize; new(a)) + PartiallyInitialized2(a, b) = (@nospecialize; new(a, b)) + PartiallyInitialized2(a, b, c) = (@nospecialize; new(a, b, c)) +end + +# 1. isdefined modeling for partial struct +@test Base.infer_return_type((Any,Any)) do a, b + Val(isdefined(PartiallyInitialized1(a, b), :b)) +end == Val{true} +@test Base.infer_return_type((Any,Any,)) do a, b + Val(isdefined(PartiallyInitialized1(a, b), :c)) +end == Val{false} +@test Base.infer_return_type((Any,Any,Any)) do a, b, c + Val(isdefined(PartiallyInitialized1(a, b, c), :c)) +end == Val{true} +@test Base.infer_return_type((Any,Any)) do a, b + Val(isdefined(PartiallyInitialized2(a, b), :b)) +end == Val{true} +@test Base.infer_return_type((Any,Any,)) do a, b + Val(isdefined(PartiallyInitialized2(a, b), :c)) +end >: Val{false} +@test Base.infer_return_type((Any,Any,Any)) do a, b, c + s = PartiallyInitialized2(a, b) + s.c = c + Val(isdefined(s, :c)) +end >: Val{true} +@test Base.infer_return_type((Any,Any,Any)) do a, b, c + Val(isdefined(PartiallyInitialized2(a, b, c), :c)) +end == Val{true} +@test Base.infer_return_type((Vector{Int},)) do xs + Val(isdefined(tuple(1, xs...), 1)) +end == Val{true} +@test Base.infer_return_type((Vector{Int},)) do xs + Val(isdefined(tuple(1, xs...), 2)) +end == Val + +# 2. 
getfield modeling for partial struct +@test Base.infer_effects((Any,Any); optimize=false) do a, b + getfield(PartiallyInitialized1(a, b), :b) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Symbol,); optimize=false) do a, b, f + getfield(PartiallyInitialized1(a, b), f, #=boundscheck=#false) +end |> !Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any); optimize=false) do a, b, c + getfield(PartiallyInitialized1(a, b, c), :c) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any,Symbol); optimize=false) do a, b, c, f + getfield(PartiallyInitialized1(a, b, c), f, #=boundscheck=#false) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any); optimize=false) do a, b + getfield(PartiallyInitialized2(a, b), :b) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Symbol,); optimize=false) do a, b, f + getfield(PartiallyInitialized2(a, b), f, #=boundscheck=#false) +end |> !Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any); optimize=false) do a, b, c + getfield(PartiallyInitialized2(a, b, c), :c) +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any,Any,Symbol); optimize=false) do a, b, c, f + getfield(PartiallyInitialized2(a, b, c), f, #=boundscheck=#false) +end |> Core.Compiler.is_nothrow + +# isdefined-Conditionals +@test Base.infer_effects((Base.RefValue{Any},)) do x + if isdefined(x, :x) + return getfield(x, :x) + end +end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Base.RefValue{Any},)) do x + if isassigned(x) + return x[] + end +end |> Core.Compiler.is_nothrow + +# End to end test case for the partially initialized struct with `PartialStruct` +@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing) +@test fully_eliminated() do + broadcast_noescape1(Ref("x")) +end + # InterConditional rt with Vararg argtypes fcondvarargs(a, b, c, d) = isa(d, Int64) gcondvarargs(a, x...) = return fcondvarargs(a, x...) ? 
isa(a, Int64) : !isa(a, Int64) From aea63e6984ee11fcdd1cb8d94de25e93f0569fcc Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 31 Jul 2024 01:06:19 +0900 Subject: [PATCH 2/4] add docstring to `PartialStruct` --- base/compiler/typelattice.jl | 60 ++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 1be76f7d8bea3b..a84608ba72b1e1 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -6,17 +6,42 @@ # N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used # inside the global code cache. -# -# # The type of a value might be constant -# struct Const -# val -# end -# -# struct PartialStruct -# typ -# fields::Vector{Any} # elements are other type lattice members -# end + import Core: Const, PartialStruct + +""" + struct Const + val + end + +The type representing a constant value. +""" +:(Const) + +""" + struct PartialStruct + typ + fields::Vector{Any} # elements are other type lattice members + end + +This extended lattice element is introduced when we have information about an object's +fields beyond what can be obtained from the object type. E.g. it represents a tuple where +some elements are known to be constants or a struct whose `Any`-typed field is initialized +with `Int` values. + +- `typ` indicates the type of the object +- `fields` holds the lattice elements corresponding to each field of the object + +If `typ` is a struct, `fields` represents the fields of the struct that are guaranteed to be +initialized. For instance, if the length of `fields` of `PartialStruct` representing a +struct with 4 fields is 3, the 4th field may be uninitialized. If the length is four, all +fields are guaranteed to be initialized. + +If `typ` is a tuple, the last element of `fields` may be `Vararg`. 
In this case, it is +guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the +exact number of elements is unknown. +""" +:(PartialStruct) function PartialStruct(@nospecialize(typ), fields::Vector{Any}) for i = 1:length(fields) assert_nested_slotwrapper(fields[i]) @@ -57,8 +82,13 @@ end Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) = Conditional(slot_id(var), thentype, elsetype) +import Core: InterConditional """ - cnd::InterConditional + struct InterConditional + slot::Int + thentype + elsetype + end Similar to `Conditional`, but conveys inter-procedural constraints imposed on call arguments. This is separate from `Conditional` to catch logic errors: the lattice element name is `InterConditional` @@ -66,14 +96,6 @@ while processing a call, then `Conditional` everywhere else. Thus `InterConditio `CompilerTypes`—these type's usages are disjoint—though we define the lattice for `InterConditional`. """ :(InterConditional) -import Core: InterConditional -# struct InterConditional -# slot::Int -# thentype -# elsetype -# InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) = -# new(slot, thentype, elsetype) -# end InterConditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) = InterConditional(slot_id(var), thentype, elsetype) From 62384cef2f2640496570c72639a92f5dcd772066 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 2 Aug 2024 16:51:47 +0900 Subject: [PATCH 3/4] fix pkgeval --- base/compiler/abstractinterpretation.jl | 19 ++++++- base/compiler/tfuncs.jl | 13 ++--- base/compiler/typelattice.jl | 2 +- base/compiler/typelimits.jl | 34 +++++++---- test/compiler/inference.jl | 75 ++++++++++++++++--------- test/tuple.jl | 6 +- 6 files changed, 95 insertions(+), 54 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 9a474a4cbd878f..351f241878e7d8 100644 --- 
a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2030,7 +2030,13 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs else thentype = form_partially_defined_struct(argtype2, argtypes[3]) if thentype !== nothing - return Conditional(a, thentype, argtype2) + elsetype = argtype2 + if rt === Const(false) + thentype = Bottom + elseif rt === Const(true) + elsetype = Bottom + end + return Conditional(a, thentype, elsetype) end end end @@ -2045,10 +2051,16 @@ function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name)) name isa Const || return nothing objt0 = widenconst(obj) objt = unwrap_unionall(objt0) + objt isa DataType || return nothing isabstracttype(objt) && return nothing fldidx = try_compute_fieldidx(objt, name.val) fldidx === nothing && return nothing - fldidx ≀ datatype_min_ninitialized(objt) && return nothing + nminfld = datatype_min_ninitialized(objt) + if ismutabletype(objt) + fldidx == nminfld+1 || return nothing + else + fldidx > nminfld || return nothing + end return PartialStruct(objt0, Any[fieldtype(objt0, i) for i = 1:fldidx]) end @@ -3111,7 +3123,8 @@ end @nospecializeinfer function widenreturn_partials(𝕃ᡒ::PartialsLattice, @nospecialize(rt), info::BestguessInfo) if isa(rt, PartialStruct) fields = copy(rt.fields) - anyrefine = !isvarargtype(rt.fields[end]) && length(rt.fields) > datatype_min_ninitialized(rt.typ) + anyrefine = !isvarargtype(rt.fields[end]) && + length(rt.fields) > datatype_min_ninitialized(unwrap_unionall(rt.typ)) 𝕃 = typeinf_lattice(info.interp) ⊏ = strictpartialorder(𝕃) for i in 1:length(fields) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index e3cf031770c361..89874b9a6df10a 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -419,7 +419,7 @@ end else return Bottom end - if 1 <= idx <= datatype_min_ninitialized(a1) + if 1 ≀ idx ≀ datatype_min_ninitialized(a1) return Const(true) elseif a1.name === 
_NAMEDTUPLE_NAME if isconcretetype(a1) @@ -427,22 +427,19 @@ end else ns = a1.parameters[1] if isa(ns, Tuple) - return Const(1 <= idx <= length(ns)) + return Const(1 ≀ idx ≀ length(ns)) end end - elseif idx <= 0 || (!isvatuple(a1) && idx > fieldcount(a1)) + elseif idx ≀ 0 || (!isvatuple(a1) && idx > fieldcount(a1)) return Const(false) elseif isa(arg1, Const) if !ismutabletype(a1) || isconst(a1, idx) return Const(isdefined(arg1.val, idx)) end elseif isa(arg1, PartialStruct) - nflds = length(arg1.fields) if !isvarargtype(arg1.fields[end]) - if 1 ≀ idx ≀ nflds + if 1 ≀ idx ≀ length(arg1.fields) return Const(true) - elseif !ismutabletype(a1) || isconst(a1, idx) - return Const(false) end end elseif !isvatuple(a1) @@ -1005,7 +1002,7 @@ end nflds = nfields(sv) ismod = sv isa Module elseif isa(s00, PartialStruct) - sty = s00.typ + sty = unwrap_unionall(s00.typ) nflds = fieldcount_noerror(sty) ismod = false else diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index a84608ba72b1e1..7565740338a1dc 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -34,7 +34,7 @@ with `Int` values. If `typ` is a struct, `fields` represents the fields of the struct that are guaranteed to be initialized. For instance, if the length of `fields` of `PartialStruct` representing a -struct with 4 fields is 3, the 4th field may be uninitialized. If the length is four, all +struct with 4 fields is 3, the 4th field may not be initialized. If the length is 4, all fields are guaranteed to be initialized. If `typ` is a tuple, the last element of `fields` may be `Vararg`. 
In this case, it is diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index 318ac0b5c27e58..1a1aa7a35e840e 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -328,6 +328,9 @@ const issimpleenoughtupleelem = issimpleenoughtype typea === typeb && return true if typea isa PartialStruct aty = widenconst(typea) + if length(typea.fields) > datatype_min_ninitialized(unwrap_unionall(aty)) + return false # TODO more accuracy here? + end for i = 1:length(typea.fields) ai = unwrapva(typea.fields[i]) bi = fieldtype(aty, i) @@ -572,34 +575,43 @@ end # N.B. This can also be called with both typea::Const and typeb::Const to # to recover PartialStruct from `Const`s with overlapping fields. -@nospecializeinfer function tmerge_partial_struct(lattice::PartialsLattice, @nospecialize(typea), @nospecialize(typeb)) +@nospecializeinfer function tmerge_partial_struct(𝕃::PartialsLattice, @nospecialize(typea), @nospecialize(typeb)) aty = widenconst(typea) bty = widenconst(typeb) if aty === bty # must have egal here, since we do not create PartialStruct for non-concrete types - typea_nfields = nfields_tfunc(lattice, typea) - typeb_nfields = nfields_tfunc(lattice, typeb) + typea_nfields = nfields_tfunc(𝕃, typea) + typeb_nfields = nfields_tfunc(𝕃, typeb) isa(typea_nfields, Const) || return nothing isa(typeb_nfields, Const) || return nothing type_nfields = typea_nfields.val::Int - type_nfields === typeb_nfields.val::Int || return nothing + type_nfields == typeb_nfields.val::Int || return nothing type_nfields == 0 && return nothing + if typea isa PartialStruct + if typeb isa PartialStruct + length(typea.fields) == length(typeb.fields) || return nothing + else + length(typea.fields) == type_nfields || return nothing + end + elseif typeb isa PartialStruct + length(typeb.fields) == type_nfields || return nothing + end fields = Vector{Any}(undef, type_nfields) anyrefine = false for i = 1:type_nfields - ai = getfield_tfunc(lattice, typea, Const(i)) - bi 
= getfield_tfunc(lattice, typeb, Const(i)) + ai = getfield_tfunc(𝕃, typea, Const(i)) + bi = getfield_tfunc(𝕃, typeb, Const(i)) # N.B.: We're assuming here that !isType(aty), because that case # only arises when typea === typeb, which should have been caught # before calling this. ft = fieldtype(aty, i) - if is_lattice_equal(lattice, ai, bi) || is_lattice_equal(lattice, ai, ft) + if is_lattice_equal(𝕃, ai, bi) || is_lattice_equal(𝕃, ai, ft) # Since ai===bi, the given type has no restrictions on complexity. # and can be used to refine ft tyi = ai - elseif is_lattice_equal(lattice, bi, ft) + elseif is_lattice_equal(𝕃, bi, ft) tyi = bi - elseif (tyiβ€² = tmerge_field(lattice, ai, bi); tyiβ€² !== nothing) + elseif (tyiβ€² = tmerge_field(𝕃, ai, bi); tyiβ€² !== nothing) # allow external lattice implementation to provide a custom field-merge strategy tyi = tyiβ€² else @@ -621,8 +633,8 @@ end end fields[i] = tyi if !anyrefine - anyrefine = has_nontrivial_extended_info(lattice, tyi) || # extended information - β‹€(lattice, tyi, ft) # just a type-level information, but more precise than the declared type + anyrefine = has_nontrivial_extended_info(𝕃, tyi) || # extended information + β‹€(𝕃, tyi, ft) # just a type-level information, but more precise than the declared type end end anyrefine && return PartialStruct(aty, fields) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index cfe6712075b2f5..b5c1321c94eb37 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1538,7 +1538,7 @@ let nfields_tfunc(@nospecialize xs...) 
= @test sizeof_nothrow(String) @test !sizeof_nothrow(Type{String}) @test sizeof_tfunc(Type{Union{Int64, Int32}}) == Const(Core.sizeof(Union{Int64, Int32})) - let PT = Core.Compiler.PartialStruct(Tuple{Int64,UInt64}, Any[Const(10), UInt64]) + let PT = Core.PartialStruct(Tuple{Int64,UInt64}, Any[Const(10), UInt64]) @test sizeof_tfunc(PT) === Const(16) @test nfields_tfunc(PT) === Const(2) @test sizeof_nothrow(PT) @@ -4743,32 +4743,40 @@ end # issue #43784 @testset "issue #43784" begin - init = Base.ImmutableDict{Any,Any}() - a = Const(init) - b = Core.PartialStruct(typeof(init), Any[Const(init), Any, Any]) - c = Core.Compiler.tmerge(a, b) - @test βŠ‘(a, c) - @test βŠ‘(b, c) - - init = Base.ImmutableDict{Number,Number}() - a = Const(init) - b = Core.Compiler.PartialStruct(typeof(init), Any[Const(init), Any, ComplexF64]) - c = Core.Compiler.tmerge(a, b) - @test βŠ‘(a, c) && βŠ‘(b, c) - @test c === typeof(init) - - a = Core.Compiler.PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64]) - c = Core.Compiler.tmerge(a, b) - @test βŠ‘(a, c) && βŠ‘(b, c) - @test c.fields[2] === Any # or Number - @test c.fields[3] === ComplexF64 - - b = Core.Compiler.PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}]) - c = Core.Compiler.tmerge(a, b) - @test βŠ‘(a, c) - @test βŠ‘(b, c) - @test c.fields[2] === Complex - @test c.fields[3] === Complex + βŠ‘ = Core.Compiler.partialorder(Core.Compiler.fallback_lattice) + βŠ” = Core.Compiler.join(Core.Compiler.fallback_lattice) + Const, PartialStruct = Core.Const, Core.PartialStruct + + let init = Base.ImmutableDict{Any,Any}() + a = Const(init) + b = PartialStruct(typeof(init), Any[Const(init), Any, Any]) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c === typeof(init) + end + let init = Base.ImmutableDict{Number,Number}() + a = Const(init) + b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c === typeof(init) + end + let init = 
Base.ImmutableDict{Number,Number}() + a = PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64]) + b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c.fields[2] === Number + @test c.fields[3] === ComplexF64 + end + let init = Base.ImmutableDict{Number,Number}() + a = PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64]) + b = PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}]) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c.fields[2] === Complex + @test c.fields[3] === Complex + end global const ginit43784 = Base.ImmutableDict{Any,Any}() @test Base.return_types() do @@ -5887,7 +5895,11 @@ end end == Val{true} @test Base.infer_return_type((Any,Any,)) do a, b Val(isdefined(PartiallyInitialized1(a, b), :c)) -end == Val{false} +end >: Val{false} +@test Base.infer_return_type((PartiallyInitialized1,)) do x + @assert isdefined(x, :a) + return Val(isdefined(x, :c)) +end == Val @test Base.infer_return_type((Any,Any,Any)) do a, b, c Val(isdefined(PartiallyInitialized1(a, b, c), :c)) end == Val{true} @@ -5949,6 +5961,13 @@ end |> Core.Compiler.is_nothrow return x[] end end |> Core.Compiler.is_nothrow +@test Base.infer_effects((Any,Any); optimize=false) do a, c + x = PartiallyInitialized2(a) + x.c = c + if isdefined(x, :c) + return x.b + end +end |> !Core.Compiler.is_nothrow # End to end test case for the partially initialized struct with `PartialStruct` @noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing) diff --git a/test/tuple.jl b/test/tuple.jl index b1894bd2bb6ce1..355ad965f95840 100644 --- a/test/tuple.jl +++ b/test/tuple.jl @@ -533,7 +533,7 @@ end @test ntuple(identity, Val(n)) == ntuple(identity, n) end - @test Core.Compiler.return_type(ntuple, Tuple{typeof(identity), Val}) == Tuple{Vararg{Int}} + @test Base.infer_return_type(ntuple, Tuple{typeof(identity), Val}) == Tuple{Vararg{Int}} end struct A_15703{N} @@ -835,8 
+835,8 @@ end @test @inferred(Base.circshift(t3, 7)) == ('b', 'c', 'd', 'a') @test @inferred(Base.circshift(t3, -1)) == ('b', 'c', 'd', 'a') @test_throws MethodError circshift(t1, 'a') - @test Core.Compiler.return_type(circshift, Tuple{Tuple,Integer}) <: Tuple - @test Core.Compiler.return_type(circshift, Tuple{Tuple{Vararg{Any,10}},Integer}) <: Tuple{Vararg{Any,10}} + @test Base.infer_return_type(circshift, Tuple{Tuple,Integer}) <: Tuple + @test Base.infer_return_type(circshift, Tuple{Tuple{Vararg{Any,10}},Integer}) <: Tuple{Vararg{Any,10}} for len ∈ 0:5 v = 1:len t = Tuple(v) From 500a38007620ede6b7c65ec530a02cf69013b81b Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 16 Aug 2024 15:10:08 +0900 Subject: [PATCH 4/4] more accurate lattice implementation --- base/compiler/typelattice.jl | 14 +++++++++---- base/compiler/typelimits.jl | 36 +++++++++++++++++--------------- test/compiler/inference.jl | 40 ++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 20 deletions(-) diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 7565740338a1dc..f3b61116154c03 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -469,8 +469,13 @@ end @nospecializeinfer function βŠ‘(lattice::PartialsLattice, @nospecialize(a), @nospecialize(b)) if isa(a, PartialStruct) if isa(b, PartialStruct) - if !(length(a.fields) == length(b.fields) && a.typ <: b.typ) - return false + a.typ <: b.typ || return false + if length(a.fields) β‰  length(b.fields) + if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end])) + length(a.fields) β‰₯ length(b.fields) || return false + else + return false + end end for i in 1:length(b.fields) af = a.fields[i] @@ -493,8 +498,7 @@ end return isa(b, Type) && a.typ <: b elseif isa(b, PartialStruct) if isa(a, Const) - nf = nfields(a.val) - nf == length(b.fields) || return false + n_initialized(a) β‰₯ length(b.fields) || return false widea = widenconst(a)::DataType wideb = widenconst(b) 
widebβ€² = unwrap_unionall(wideb)::DataType @@ -504,8 +508,10 @@ end if widebβ€².name !== Tuple.name && !(widea <: wideb) return false end + nf = nfields(a.val) for i in 1:nf isdefined(a.val, i) || continue # since βˆ€ T Union{} βŠ‘ T + i > length(b.fields) && break bfα΅’ = b.fields[i] if i == nf bfα΅’ = unwrapva(bfα΅’) diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index 1a1aa7a35e840e..f11d3181fa1530 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -321,6 +321,11 @@ end # even after complicated recursion and other operations on it elsewhere const issimpleenoughtupleelem = issimpleenoughtype +function n_initialized(t::Const) + nf = nfields(t.val) + return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1 +end + # A simplified type_more_complex query over the extended lattice # (assumes typeb βŠ‘ typea) @nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb)) @@ -328,8 +333,12 @@ const issimpleenoughtupleelem = issimpleenoughtype typea === typeb && return true if typea isa PartialStruct aty = widenconst(typea) - if length(typea.fields) > datatype_min_ninitialized(unwrap_unionall(aty)) - return false # TODO more accuracy here? 
+ if typeb isa Const + @assert length(typea.fields) ≀ n_initialized(typeb) "typeb βŠ‘ typea is assumed" + elseif typeb isa PartialStruct + @assert length(typea.fields) ≀ length(typeb.fields) "typeb βŠ‘ typea is assumed" + else + return false end for i = 1:length(typea.fields) ai = unwrapva(typea.fields[i]) @@ -579,26 +588,21 @@ end aty = widenconst(typea) bty = widenconst(typeb) if aty === bty - # must have egal here, since we do not create PartialStruct for non-concrete types - typea_nfields = nfields_tfunc(𝕃, typea) - typeb_nfields = nfields_tfunc(𝕃, typeb) - isa(typea_nfields, Const) || return nothing - isa(typeb_nfields, Const) || return nothing - type_nfields = typea_nfields.val::Int - type_nfields == typeb_nfields.val::Int || return nothing - type_nfields == 0 && return nothing if typea isa PartialStruct if typeb isa PartialStruct - length(typea.fields) == length(typeb.fields) || return nothing + nflds = min(length(typea.fields), length(typeb.fields)) else - length(typea.fields) == type_nfields || return nothing + nflds = min(n_initialized(typeb::Const), length(typea.fields)) end elseif typeb isa PartialStruct - length(typeb.fields) == type_nfields || return nothing + nflds = min(n_initialized(typea::Const), length(typeb.fields)) + else + nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const)) end - fields = Vector{Any}(undef, type_nfields) - anyrefine = false - for i = 1:type_nfields + nflds == 0 && return nothing + fields = Vector{Any}(undef, nflds) + anyrefine = nflds > datatype_min_ninitialized(unwrap_unionall(aty)) + for i = 1:nflds ai = getfield_tfunc(𝕃, typea, Const(i)) bi = getfield_tfunc(𝕃, typeb, Const(i)) # N.B.: We're assuming here that !isType(aty), because that case diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index b5c1321c94eb37..af2d5ddf19f6c6 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4754,6 +4754,14 @@ end @test a βŠ‘ c && b βŠ‘ c @test c === typeof(init) end + let init 
= Base.ImmutableDict{Any,Any}(1,2) + a = Const(init) + b = PartialStruct(typeof(init), Any[Const(getfield(init,1)), Any, Any]) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c isa PartialStruct + @test length(c.fields) == 3 + end let init = Base.ImmutableDict{Number,Number}() a = Const(init) b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) @@ -4766,6 +4774,7 @@ end b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) c = a βŠ” b @test a βŠ‘ c && b βŠ‘ c + @test c isa PartialStruct @test c.fields[2] === Number @test c.fields[3] === ComplexF64 end @@ -4774,9 +4783,40 @@ end b = PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}]) c = a βŠ” b @test a βŠ‘ c && b βŠ‘ c + @test c isa PartialStruct @test c.fields[2] === Complex @test c.fields[3] === Complex end + let T = Base.ImmutableDict{Number,Number} + a = PartialStruct(T, Any[T]) + b = PartialStruct(T, Any[T, Number, Number]) + @test b βŠ‘ a + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c isa PartialStruct + @test length(c.fields) == 1 + end + let T = Base.ImmutableDict{Number,Number} + a = PartialStruct(T, Any[T]) + b = Const(T()) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c === T + end + let T = Base.ImmutableDict{Number,Number} + a = Const(T()) + b = PartialStruct(T, Any[T]) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c === T + end + let T = Base.ImmutableDict{Number,Number} + a = Const(T()) + b = Const(T(1,2)) + c = a βŠ” b + @test a βŠ‘ c && b βŠ‘ c + @test c === T + end global const ginit43784 = Base.ImmutableDict{Any,Any}() @test Base.return_types() do