From 0f7dbf7c40e22e4971cdd647f1b5c4c167c5f9ff Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Wed, 17 Feb 2021 01:34:36 +0900 Subject: [PATCH] improve ~many~ some type stabilities in `Core.Compiler.typeinf` (#39549) All of them are detected by JET.jl's self-profiling. The following code will print type-instabilities/type-errors for all code paths reachable from `typeinf(::NativeInterpreter, ::InferenceState)`. ```julia julia> using JET julia> report_call(Core.Compiler.typeinf, (Core.Compiler.NativeInterpreter, Core.Compiler.InferenceState); annotate_types = true) ``` The remaining error reports (e.g. `variable Core.Compiler.string is not defined`) are because of missing functionality on error paths. (cherry picked from part of commit 1bc7f43c946269886983f165a5598e0c16adc63b) --- base/compiler/abstractinterpretation.jl | 36 ++++++++++++++----------- base/compiler/optimize.jl | 4 +-- base/compiler/ssair/driver.jl | 7 ++--- base/compiler/ssair/inlining.jl | 28 ++++++++++--------- base/compiler/ssair/ir.jl | 4 ++- base/compiler/ssair/slot2ssa.jl | 2 +- base/compiler/ssair/verify.jl | 4 +-- base/compiler/typeinfer.jl | 2 +- base/compiler/typelattice.jl | 6 ++--- base/compiler/typeutils.jl | 6 ++--- base/compiler/validation.jl | 2 +- 11 files changed, 56 insertions(+), 45 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index c8b4ff5052e47..2b64c1728f80c 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -434,7 +434,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp # Under direct self-recursion, permit much greater use of reducers. # here we assume that complexity(specTypes) :>= complexity(sig) comparison = sv.linfo.specTypes - l_comparison = length(unwrap_unionall(comparison).parameters) + l_comparison = length(unwrap_unionall(comparison).parameters)::Int spec_len = max(spec_len, l_comparison) else comparison = method.sig @@ -679,16 +679,20 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe res = Union{} nargs = length(aargtypes) splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM - ctypes = Any[Any[aft]] + ctypes = [Any[aft]] infos = [Union{Nothing, AbstractIterationInfo}[]] for i = 1:nargs - ctypes´ = [] - infos′ = [] + ctypes´ = Vector{Any}[] + infos′ = Vector{Union{Nothing, AbstractIterationInfo}}[] for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]]) if !isvarargtype(ti) - cti, info = precise_container_type(interp, itft, ti, sv) + cti_info = precise_container_type(interp, itft, ti, sv) + cti = cti_info[1]::Vector{Any} + info = cti_info[2]::Union{Nothing,AbstractIterationInfo} else - cti, info = precise_container_type(interp, itft, unwrapva(ti), sv) + cti_info = precise_container_type(interp, itft, unwrapva(ti), sv) + cti = cti_info[1]::Vector{Any} + info = cti_info[2]::Union{Nothing,AbstractIterationInfo} # We can't represent a repeating sequence of the same types, # so tmerge everything together to get one type that represents # everything. @@ -705,7 +709,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe continue end for j = 1:length(ctypes) - ct = ctypes[j] + ct = ctypes[j]::Vector{Any} if isvarargtype(ct[end]) # This is vararg, we're not gonna be able to do any inling, # drop the info @@ -1325,7 +1329,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) delete!(W, pc) frame.currpc = pc frame.cur_hand = frame.handler_at[pc] - frame.stmt_edges[pc] === nothing || empty!(frame.stmt_edges[pc]) + edges = frame.stmt_edges[pc] + edges === nothing || empty!(edges) stmt = frame.src.code[pc] changes = s[pc]::VarTable t = nothing @@ -1338,7 +1343,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) elseif isa(stmt, GotoNode) pc´ = (stmt::GotoNode).label elseif isa(stmt, GotoIfNot) - condt = abstract_eval_value(interp, stmt.cond, s[pc], frame) + condt = abstract_eval_value(interp, stmt.cond, changes, frame) if condt === Bottom empty!(frame.pclimitations) break @@ -1369,7 +1374,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end end newstate_else = stupdate!(s[l], changes_else) - if newstate_else !== false + if newstate_else !== nothing # add else branch to active IP list if l < frame.pc´´ frame.pc´´ = l @@ -1380,7 +1385,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end elseif isa(stmt, ReturnNode) pc´ = n + 1 - rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame)) + rt = widenconditional(abstract_eval_value(interp, stmt.val, changes, frame)) if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) # only propagate information we know we can store # and is valid inter-procedurally @@ -1414,9 +1419,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand) # propagate type info to exception handler old = s[l] - new = s[pc]::Array{Any,1} - newstate_catch = stupdate!(old, new) - if newstate_catch !== false + newstate_catch = stupdate!(old, changes) + if newstate_catch !== nothing if l < frame.pc´´ frame.pc´´ = l end @@ -1483,12 +1487,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) # (such as a terminator for a loop, if-else, or try block), # consider whether we should jump to an older backedge first, # to try to traverse the statements in approximate dominator order - if newstate !== false + if newstate !== nothing s[pc´] = newstate end push!(W, pc´) pc = frame.pc´´ - elseif newstate !== false + elseif newstate !== nothing s[pc´] = newstate pc = pc´ elseif pc´ in W diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 9d9bc45dc1e9f..f08b05dffe286 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -371,7 +371,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any} end a = ex.args[2] if a isa Expr - cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params, error_path)) + cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, union_penalties, params, error_path)) end return cost elseif head === :copyast @@ -392,7 +392,7 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::CodeInfo, thiscost = 0 if stmt isa Expr thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params, - params.unoptimize_throw_blocks && line in throw_blocks)::Int + throw_blocks !== nothing && line in throw_blocks)::Int elseif stmt isa GotoNode # loops are generally always expensive # but assume that forward jumps are already counted for from diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index 83205033342d6..7de0ceba6bee6 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -43,13 +43,14 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg labelmap = coverage ? fill(0, length(code)) : changemap prevloc = zero(eltype(ci.codelocs)) stmtinfo = sv.stmt_info + ssavaluetypes = ci.ssavaluetypes::Vector{Any} while idx <= length(code) codeloc = ci.codelocs[idx] if coverage && codeloc != prevloc && codeloc != 0 # insert a side-effect instruction before the current instruction in the same basic block insert!(code, idx, Expr(:code_coverage_effect)) insert!(ci.codelocs, idx, codeloc) - insert!(ci.ssavaluetypes, idx, Nothing) + insert!(ssavaluetypes, idx, Nothing) insert!(stmtinfo, idx, nothing) changemap[oldidx] += 1 if oldidx < length(labelmap) @@ -58,12 +59,12 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg idx += 1 prevloc = codeloc end - if code[idx] isa Expr && ci.ssavaluetypes[idx] === Union{} + if code[idx] isa Expr && ssavaluetypes[idx] === Union{} if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val)) # insert unreachable in the same basic block after the current instruction (splitting it) insert!(code, idx + 1, ReturnNode()) insert!(ci.codelocs, idx + 1, ci.codelocs[idx]) - insert!(ci.ssavaluetypes, idx + 1, Union{}) + insert!(ssavaluetypes, idx + 1, Union{}) insert!(stmtinfo, idx + 1, nothing) if oldidx < length(changemap) changemap[oldidx + 1] += 1 diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 0e95f812e5eb6..ae42b59bee28a 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -630,7 +630,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: call = thisarginfo.each[i] new_stmt = Expr(:call, argexprs[2], def, state...) state1 = insert_node!(ir, idx, call.rt, new_stmt) - new_sig = with_atype(call_sig(ir, new_stmt)) + new_sig = with_atype(call_sig(ir, new_stmt)::Signature) if isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo) info = isa(call.info, MethodMatchInfo) ? MethodMatchInfo[call.info] : call.info.matches @@ -680,7 +680,7 @@ function resolve_todo(todo::InliningTodo, et::Union{EdgeTracker, Nothing}, cache spec = todo.spec::DelayedInliningSpec isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype) - if isconst + if isconst && et !== nothing push!(et, todo.mi) return ConstantCase(src) end @@ -988,9 +988,12 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok sig.atype, method.sig)::SimpleVector methsp = methsp::SimpleVector match = MethodMatch(metharg, methsp, method, true) - result = analyze_method!(match, sig.atypes, state.et, state.caches, state.params, calltype) + et = state.et + result = analyze_method!(match, sig.atypes, et, state.caches, state.params, calltype) handle_single_case!(ir, stmt, idx, result, true, todo) - intersect!(state.et, WorldRange(invoke_data.min_valid, invoke_data.max_valid)) + if et !== nothing + intersect!(et, WorldRange(invoke_data.min_valid, invoke_data.max_valid)) + end return nothing end @@ -1118,6 +1121,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int sig.atype, only_method.sig)::SimpleVector match = MethodMatch(metharg, methsp, only_method, true) else + meth = meth::MethodLookupResult @assert length(meth) == 1 match = meth[1] end @@ -1145,6 +1149,8 @@ end function assemble_inline_todo!(ir::IRCode, state::InliningState) # todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie) todo = Pair{Int, Any}[] + et = state.et + method_table = state.method_table for idx in 1:length(ir.stmts) r = process_simple!(ir, todo, idx, state) r === nothing && continue @@ -1176,20 +1182,18 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) nu = unionsplitcost(sig.atypes) if nu == 1 || nu > state.params.MAX_UNION_SPLITTING if !isa(info, MethodMatchInfo) - if state.method_table === nothing - continue - end - info = recompute_method_matches(sig.atype, state.params, state.et, state.method_table) + method_table === nothing && continue + et === nothing && continue + info = recompute_method_matches(sig.atype, state.params, et, method_table) end infos = MethodMatchInfo[info] else if !isa(info, UnionSplitInfo) - if state.method_table === nothing - continue - end + method_table === nothing && continue + et === nothing && continue infos = MethodMatchInfo[] for union_sig in UnionSplitSignature(sig.atypes) - push!(infos, recompute_method_matches(argtypes_to_type(union_sig), state.params, state.et, state.method_table)) + push!(infos, recompute_method_matches(argtypes_to_type(union_sig), state.params, et, method_table)) end else infos = info.matches diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 3960ab44649b1..cc3b3e1ad245f 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -139,6 +139,7 @@ function compute_basic_blocks(stmts::Vector{Any}) return CFG(blocks, basic_block_index) end +# this function assumes insert position exists function first_insert_for_bb(code, cfg::CFG, block::Int) for idx in cfg.blocks[block].stmts stmt = code[idx] @@ -146,6 +147,7 @@ function first_insert_for_bb(code, cfg::CFG, block::Int) return idx end end + error("any insert position isn't found") end # SSA-indexed nodes @@ -890,7 +892,7 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to:: # Check if the block is now dead if length(preds) == 0 for succ in copy(compact.result_bbs[compact.bb_rename_succ[to]].succs) - kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred)) + kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred)::Int) end if to < active_bb # Kill all statements in the block diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index 057bb72ff1152..46d727605bab4 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -764,7 +764,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg # Having undef_token appear on the RHS is possible if we're on a dead branch. # Do something reasonable here, by marking the LHS as undef as well. if val !== undef_token - incoming_vals[id] = SSAValue(make_ssa!(ci, code, idx, id, typ)) + incoming_vals[id] = SSAValue(make_ssa!(ci, code, idx, id, typ)::Int) else code[idx] = nothing incoming_vals[id] = undef_token diff --git a/base/compiler/ssair/verify.jl b/base/compiler/ssair/verify.jl index 0365383a576f7..0f29b79b417e6 100644 --- a/base/compiler/ssair/verify.jl +++ b/base/compiler/ssair/verify.jl @@ -14,13 +14,13 @@ end function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, use_idx::Int, print::Bool) if isa(op, SSAValue) if op.id > length(ir.stmts) - def_bb = block_for_inst(ir.cfg, ir.new_nodes[op.id - length(ir.stmts)].pos) + def_bb = block_for_inst(ir.cfg, ir.new_nodes.info[op.id - length(ir.stmts)].pos) else def_bb = block_for_inst(ir.cfg, op.id) end if (def_bb == use_bb) if op.id > length(ir.stmts) - @assert ir.new_nodes[op.id - length(ir.stmts)].pos <= use_idx + @assert ir.new_nodes.info[op.id - length(ir.stmts)].pos <= use_idx else if op.id >= use_idx @verify_error "Def ($(op.id)) does not dominate use ($(use_idx)) in same BB" diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 9a456660807da..510eec52ed623 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -712,8 +712,8 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi add_cycle_backedge!(child, parent, parent.currpc) union_caller_cycle!(ancestor, child) child = parent - parent = child.parent child === ancestor && break + parent = child.parent::InferenceState end end diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 5df8dfd411ba3..4f96a883b150c 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -299,7 +299,7 @@ function stupdate!(state::VarTable, changes::StateUpdate) if !isa(changes.var, Slot) return stupdate!(state, changes.state) end - newstate = false + newstate = nothing changeid = slot_id(changes.var::Slot) for i = 1:length(state) if i == changeid @@ -328,7 +328,7 @@ function stupdate!(state::VarTable, changes::StateUpdate) end function stupdate!(state::VarTable, changes::VarTable) - newstate = false + newstate = nothing for i = 1:length(state) newtype = changes[i] oldtype = state[i] @@ -342,7 +342,7 @@ end stupdate!(state::Nothing, changes::VarTable) = copy(changes) -stupdate!(state::Nothing, changes::Nothing) = false +stupdate!(state::Nothing, changes::Nothing) = nothing function stupdate1!(state::VarTable, change::StateUpdate) if !isa(change.var, Slot) diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index 954f14cfbfbbc..2600cfeafea24 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -187,17 +187,17 @@ end # unioncomplexity estimates the number of calls to `tmerge` to obtain the given type by # counting the Union instances, taking also into account those hidden in a Tuple or UnionAll function unioncomplexity(u::Union) - return unioncomplexity(u.a) + unioncomplexity(u.b) + 1 + return unioncomplexity(u.a)::Int + unioncomplexity(u.b)::Int + 1 end function unioncomplexity(t::DataType) t.name === Tuple.name || isvarargtype(t) || return 0 c = 0 for ti in t.parameters - c = max(c, unioncomplexity(ti)) + c = max(c, unioncomplexity(ti)::Int) end return c end -unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body), unioncomplexity(u.var.ub)) +unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body)::Int, unioncomplexity(u.var.ub)::Int) unioncomplexity(@nospecialize(x)) = 0 function improvable_via_constant_propagation(@nospecialize(t)) diff --git a/base/compiler/validation.jl b/base/compiler/validation.jl index df618d0033f60..4bf4447e39e94 100644 --- a/base/compiler/validation.jl +++ b/base/compiler/validation.jl @@ -1,7 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license # Expr head => argument count bounds -const VALID_EXPR_HEADS = IdDict{Any,Any}( +const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange}( :call => 1:typemax(Int), :invoke => 2:typemax(Int), :static_parameter => 1:1,