From d22ecaee820a396887ddf104a0559c134d789aa6 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 16 Aug 2024 15:10:08 +0900 Subject: [PATCH] more accurate lattice implementation --- base/compiler/typelattice.jl | 14 +++++++++---- base/compiler/typelimits.jl | 36 +++++++++++++++++--------------- test/compiler/inference.jl | 40 ++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 20 deletions(-) diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 7565740338a1dc..f3b61116154c03 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -469,8 +469,13 @@ end @nospecializeinfer function ⊑(lattice::PartialsLattice, @nospecialize(a), @nospecialize(b)) if isa(a, PartialStruct) if isa(b, PartialStruct) - if !(length(a.fields) == length(b.fields) && a.typ <: b.typ) - return false + a.typ <: b.typ || return false + if length(a.fields) ≠ length(b.fields) + if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end])) + length(a.fields) ≥ length(b.fields) || return false + else + return false + end end for i in 1:length(b.fields) af = a.fields[i] @@ -493,8 +498,7 @@ end return isa(b, Type) && a.typ <: b elseif isa(b, PartialStruct) if isa(a, Const) - nf = nfields(a.val) - nf == length(b.fields) || return false + n_initialized(a) ≥ length(b.fields) || return false widea = widenconst(a)::DataType wideb = widenconst(b) wideb′ = unwrap_unionall(wideb)::DataType @@ -504,8 +508,10 @@ end if wideb′.name !== Tuple.name && !(widea <: wideb) return false end + nf = nfields(a.val) for i in 1:nf isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T + i > length(b.fields) && break bfᵢ = b.fields[i] if i == nf bfᵢ = unwrapva(bfᵢ) diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index 1a1aa7a35e840e..91a44d3b117abd 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -321,6 +321,11 @@ end # even after complicated recursion and other operations on it elsewhere const issimpleenoughtupleelem = issimpleenoughtype +function n_initialized(t::Const) + nf = nfields(t.val) + return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1 +end + # A simplified type_more_complex query over the extended lattice # (assumes typeb ⊑ typea) @nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb)) @@ -328,8 +333,12 @@ const issimpleenoughtupleelem = issimpleenoughtype typea === typeb && return true if typea isa PartialStruct aty = widenconst(typea) - if length(typea.fields) > datatype_min_ninitialized(unwrap_unionall(aty)) - return false # TODO more accuracy here? + if typeb isa Const + @assert length(typea.fields) ≤ n_initialized(typeb) "typeb ⊑ typea is assumed" + elseif typeb isa PartialStruct + @assert length(typea.fields) ≤ length(typeb.fields) "typeb ⊑ typea is assumed" + else + return false end for i = 1:length(typea.fields) ai = unwrapva(typea.fields[i]) @@ -579,26 +588,21 @@ end aty = widenconst(typea) bty = widenconst(typeb) if aty === bty - # must have egal here, since we do not create PartialStruct for non-concrete types - typea_nfields = nfields_tfunc(𝕃, typea) - typeb_nfields = nfields_tfunc(𝕃, typeb) - isa(typea_nfields, Const) || return nothing - isa(typeb_nfields, Const) || return nothing - type_nfields = typea_nfields.val::Int - type_nfields == typeb_nfields.val::Int || return nothing - type_nfields == 0 && return nothing if typea isa PartialStruct if typeb isa PartialStruct - length(typea.fields) == length(typeb.fields) || return nothing + nflds = min(length(typea.fields), length(typeb.fields)) else - length(typea.fields) == type_nfields || return nothing + nflds = min(length(typea.fields), n_initialized(typeb::Const)) end elseif typeb isa PartialStruct - length(typeb.fields) == type_nfields || return nothing + nflds = min(n_initialized(typea::Const), length(typeb.fields)) + else + nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const)) end - fields = Vector{Any}(undef, type_nfields) - anyrefine = false - for i = 1:type_nfields + nflds == 0 && return nothing + fields = Vector{Any}(undef, nflds) + anyrefine = nflds > datatype_min_ninitialized(unwrap_unionall(aty)) + for i = 1:nflds ai = getfield_tfunc(𝕃, typea, Const(i)) bi = getfield_tfunc(𝕃, typeb, Const(i)) # N.B.: We're assuming here that !isType(aty), because that case diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index b5c1321c94eb37..af2d5ddf19f6c6 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4754,6 +4754,14 @@ end @test a ⊑ c && b ⊑ c @test c === typeof(init) end + let init = Base.ImmutableDict{Any,Any}(1,2) + a = Const(init) + b = PartialStruct(typeof(init), Any[Const(getfield(init,1)), Any, Any]) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c isa PartialStruct + @test length(c.fields) == 3 + end let init = Base.ImmutableDict{Number,Number}() a = Const(init) b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) @@ -4766,6 +4774,7 @@ end b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) c = a ⊔ b @test a ⊑ c && b ⊑ c + @test c isa PartialStruct @test c.fields[2] === Number @test c.fields[3] === ComplexF64 end @@ -4774,9 +4783,40 @@ end b = PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}]) c = a ⊔ b @test a ⊑ c && b ⊑ c + @test c isa PartialStruct @test c.fields[2] === Complex @test c.fields[3] === Complex end + let T = Base.ImmutableDict{Number,Number} + a = PartialStruct(T, Any[T]) + b = PartialStruct(T, Any[T, Number, Number]) + @test b ⊑ a + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c isa PartialStruct + @test length(c.fields) == 1 + end + let T = Base.ImmutableDict{Number,Number} + a = PartialStruct(T, Any[T]) + b = Const(T()) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c === T + end + let T = Base.ImmutableDict{Number,Number} + a = Const(T()) + b = PartialStruct(T, Any[T]) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c === T + end + let T = Base.ImmutableDict{Number,Number} + a = Const(T()) + b = Const(T(1,2)) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c === T + end global const ginit43784 = Base.ImmutableDict{Any,Any}() @test Base.return_types() do