diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 23333b30cdce1..22f31ad1f3656 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -186,7 +186,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir) return false end if isa(stmt, GotoIfNot) - t = argextype(stmt.cond, ir, ir.sptypes) + t = argextype(stmt.cond, ir) return !(t ⊑ Bool) end if isa(stmt, Expr) @@ -195,6 +195,127 @@ function stmt_affects_purity(@nospecialize(stmt), ir) return true end +""" + stmt_effect_free(stmt, rt, src::Union{IRCode,IncrementalCompact}) + +Determine whether a `stmt` is "side-effect-free", i.e. may be removed if it has no uses. +""" +function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact}) + isa(stmt, PiNode) && return true + isa(stmt, PhiNode) && return true + isa(stmt, ReturnNode) && return false + isa(stmt, GotoNode) && return false + isa(stmt, GotoIfNot) && return false + isa(stmt, Slot) && return false # Slots shouldn't occur in the IR at this point, but let's be defensive here + isa(stmt, GlobalRef) && return isdefined(stmt.mod, stmt.name) + if isa(stmt, Expr) + (; head, args) = stmt + if head === :static_parameter + etyp = (isa(src, IRCode) ? src.sptypes : src.ir.sptypes)[args[1]::Int] + # if we aren't certain enough about the type, it might be an UndefVarError at runtime + return isa(etyp, Const) + end + if head === :call + f = argextype(args[1], src) + f = singleton_type(f) + f === nothing && return false + is_return_type(f) && return true + if isa(f, IntrinsicFunction) + intrinsic_effect_free_if_nothrow(f) || return false + return intrinsic_nothrow(f, + Any[argextype(args[i], src) for i = 2:length(args)]) + end + contains_is(_PURE_BUILTINS, f) && return true + contains_is(_PURE_OR_ERROR_BUILTINS, f) || return false + rt === Bottom && return false + return _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt) + elseif head === :new + typ = argextype(args[1], src) + # `Expr(:new)` of unknown type could raise arbitrary TypeError. + typ, isexact = instanceof_tfunc(typ) + isexact || return false + isconcretedispatch(typ) || return false + typ = typ::DataType + fieldcount(typ) >= length(args) - 1 || return false + for fld_idx in 1:(length(args) - 1) + eT = argextype(args[fld_idx + 1], src) + fT = fieldtype(typ, fld_idx) + eT ⊑ fT || return false + end + return true + elseif head === :new_opaque_closure + length(args) < 5 && return false + typ = argextype(args[1], src) + typ, isexact = instanceof_tfunc(typ) + isexact || return false + typ ⊑ Tuple || return false + isva = argextype(args[2], src) + rt_lb = argextype(args[3], src) + rt_ub = argextype(args[4], src) + src = argextype(args[5], src) + if !(isva ⊑ Bool && rt_lb ⊑ Type && rt_ub ⊑ Type && src ⊑ Method) + return false + end + return true + elseif head === :isdefined || head === :the_exception || head === :copyast || head === :inbounds || head === :boundscheck + return true + else + # e.g. :loopinfo + return false + end + end + return true +end + +""" + argextype(x, src::Union{IRCode,IncrementalCompact}) -> t + argextype(x, src::CodeInfo, sptypes::Vector{Any}) -> t + +Return the type of value `x` in the context of inferred source `src`. +Note that `t` might be an extended lattice element. +Use `widenconst(t)` to get the native Julia type of `x`. +""" +argextype(@nospecialize(x), ir::IRCode, sptypes::Vector{Any} = ir.sptypes) = + argextype(x, ir, sptypes, ir.argtypes) +function argextype(@nospecialize(x), compact::IncrementalCompact, sptypes::Vector{Any} = compact.ir.sptypes) + isa(x, AnySSAValue) && return types(compact)[x] + return argextype(x, compact, sptypes, compact.ir.argtypes) +end +argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{Any}) = argextype(x, src, sptypes, src.slottypes::Vector{Any}) +function argextype( + @nospecialize(x), src::Union{IRCode,IncrementalCompact,CodeInfo}, + sptypes::Vector{Any}, slottypes::Vector{Any}) + if isa(x, Expr) + if x.head === :static_parameter + return sptypes[x.args[1]::Int] + elseif x.head === :boundscheck + return Bool + elseif x.head === :copyast + return argextype(x.args[1], src, sptypes, slottypes) + end + @assert false "argextype only works on argument-position values" + elseif isa(x, SlotNumber) + return slottypes[x.id] + elseif isa(x, TypedSlot) + return x.typ + elseif isa(x, SSAValue) + return abstract_eval_ssavalue(x, src) + elseif isa(x, Argument) + return slottypes[x.n] + elseif isa(x, QuoteNode) + return Const(x.value) + elseif isa(x, GlobalRef) + return abstract_eval_global(x.mod, x.name) + elseif isa(x, PhiNode) + return Any + elseif isa(x, PiNode) + return x.typ + else + return Const(x) + end +end +abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s] + # compute inlining cost and sideeffects function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result)) (; src, linfo) = opt @@ -214,7 +335,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt for i in 1:length(ir.stmts) node = ir.stmts[i] stmt = node[:inst] - if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir, ir.sptypes) + if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir) proven_pure = false break end @@ -432,20 +553,19 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y) isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T)) function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any}, - slottypes::Vector{Any}, union_penalties::Bool, - params::OptimizationParams, error_path::Bool = false) + union_penalties::Bool, params::OptimizationParams, error_path::Bool = false) head = ex.head if is_meta_expr_head(head) return 0 elseif head === :call farg = ex.args[1] - ftyp = argextype(farg, src, sptypes, slottypes) + ftyp = argextype(farg, src, sptypes) if ftyp === IntrinsicFunction && farg isa SSAValue # if this comes from code that was already inlined into another function, # Consts have been widened. try to recover in simple cases. farg = isa(src, CodeInfo) ? src.code[farg.id] : src.stmts[farg.id][:inst] if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter) - ftyp = argextype(farg, src, sptypes, slottypes) + ftyp = argextype(farg, src, sptypes) end end f = singleton_type(ftyp) @@ -467,15 +587,15 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp # return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty) return 0 elseif (f === Core.arrayref || f === Core.const_arrayref || f === Core.arrayset) && length(ex.args) >= 3 - atyp = argextype(ex.args[3], src, sptypes, slottypes) + atyp = argextype(ex.args[3], src, sptypes) return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty - elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes, slottypes))) + elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes))) return 1 elseif f === Core.isa # If we're in a union context, we penalize type computations # on union types. In such cases, it is usually better to perform # union splitting on the outside. - if union_penalties && isa(argextype(ex.args[2], src, sptypes, slottypes), Union) + if union_penalties && isa(argextype(ex.args[2], src, sptypes), Union) return params.inline_nonleaf_penalty end end @@ -487,7 +607,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp end return T_FFUNC_COST[fidx] end - extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes) + extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes) if extyp === Union{} return 0 end @@ -498,7 +618,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp # run-time of the function, we omit them from # consideration. This way, non-inlined error branches do not # prevent inlining. - extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes) + extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes) return extyp === Union{} ? 0 : 20 elseif head === :(=) if ex.args[1] isa GlobalRef @@ -508,7 +628,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp end a = ex.args[2] if a isa Expr - cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, union_penalties, params, error_path)) + cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, union_penalties, params, error_path)) end return cost elseif head === :copyast @@ -524,11 +644,11 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp end function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any}, - slottypes::Vector{Any}, union_penalties::Bool, params::OptimizationParams) + union_penalties::Bool, params::OptimizationParams) thiscost = 0 dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt if stmt isa Expr - thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params, + thiscost = statement_cost(stmt, line, src, sptypes, union_penalties, params, is_stmt_throw_block(isa(src, IRCode) ? src.stmts.flag[line] : src.ssaflags[line]))::Int elseif stmt isa GotoNode # loops are generally always expensive @@ -546,7 +666,7 @@ function inline_worthy(ir::IRCode, bodycost::Int = 0 for line = 1:length(ir.stmts) stmt = ir.stmts[line][:inst] - thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, ir.argtypes, union_penalties, params) + thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, union_penalties, params) bodycost = plus_saturate(bodycost, thiscost) bodycost > cost_threshold && return false end @@ -558,7 +678,6 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI for line = 1:length(body) stmt = body[line] thiscost = statement_or_branch_cost(stmt, line, src, sptypes, - src isa CodeInfo ? src.slottypes : src.argtypes, unionpenalties, params) cost[line] = thiscost if thiscost > maxcost @@ -568,14 +687,6 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI return maxcost end -function is_known_call(e::Expr, @nospecialize(func), src, sptypes::Vector{Any}, slottypes::Vector{Any} = EMPTY_SLOTTYPES) - if e.head !== :call - return false - end - f = argextype(e.args[1], src, sptypes, slottypes) - return isa(f, Const) && f.val === func -end - function renumber_ir_elements!(body::Vector{Any}, changemap::Vector{Int}) return renumber_ir_elements!(body, changemap, changemap) end diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index 9a6071766271e..e54a09fe351b3 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -14,7 +14,6 @@ include("compiler/ssair/basicblock.jl") include("compiler/ssair/domtree.jl") include("compiler/ssair/ir.jl") include("compiler/ssair/slot2ssa.jl") -include("compiler/ssair/queries.jl") include("compiler/ssair/passes.jl") include("compiler/ssair/inlining.jl") include("compiler/ssair/verify.jl") diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 85bf5a64474f7..55445f5c8032b 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -371,7 +371,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector return_value = SSAValue(idx′) inline_compact[idx′] = val inline_compact.result[idx′][:type] = - compact_exprtype(isa(val, Argument) || isa(val, Expr) ? compact : inline_compact, val) + argextype(val, isa(val, Argument) || isa(val, Expr) ? compact : inline_compact) break end inline_compact[idx′] = stmt′ @@ -400,7 +400,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector if isa(val, GlobalRef) || isa(val, Expr) stmt′ = val inline_compact.result[idx′][:type] = - compact_exprtype(isa(val, Expr) ? compact : inline_compact, val) + argextype(val, isa(val, Expr) ? compact : inline_compact) insert_node_here!(inline_compact, NewInstruction(GotoNode(post_bb_id), Any, compact.result[idx′][:line]), true) @@ -435,7 +435,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector return_value = pn.values[1] else return_value = insert_node_here!(compact, - NewInstruction(pn, compact_exprtype(compact, SSAValue(idx)), compact.result[idx][:line])) + NewInstruction(pn, argextype(SSAValue(idx), compact), compact.result[idx][:line])) end end return_value @@ -580,7 +580,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect for aidx in 1:length(argexprs) aexpr = argexprs[aidx] if isa(aexpr, Expr) || isa(aexpr, GlobalRef) - ninst = effect_free(NewInstruction(aexpr, compact_exprtype(compact, aexpr), compact.result[idx][:line])) + ninst = effect_free(NewInstruction(aexpr, argextype(aexpr, compact), compact.result[idx][:line])) argexprs[aidx] = insert_node_here!(compact, ninst) end end @@ -886,7 +886,7 @@ function inline_splatnew!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(rt)) if nf isa Const eargs = stmt.args tup = eargs[2] - tt = argextype(tup, ir, ir.sptypes) + tt = argextype(tup, ir) tnf = nfields_tfunc(tt) # TODO: hoisting this tnf.val === nf.val check into codegen # would enable us to almost always do this transform @@ -908,7 +908,7 @@ end function call_sig(ir::IRCode, stmt::Expr) isempty(stmt.args) && return nothing - ft = argextype(stmt.args[1], ir, ir.sptypes) + ft = argextype(stmt.args[1], ir) has_free_typevars(ft) && return nothing f = singleton_type(ft) f === Core.Intrinsics.llvmcall && return nothing @@ -916,7 +916,7 @@ function call_sig(ir::IRCode, stmt::Expr) argtypes = Vector{Any}(undef, length(stmt.args)) argtypes[1] = ft for i = 2:length(stmt.args) - a = argextype(stmt.args[i], ir, ir.sptypes) + a = argextype(stmt.args[i], ir) (a === Bottom || isvarargtype(a)) && return nothing argtypes[i] = a end @@ -1025,10 +1025,10 @@ end function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), state::InliningState) if isa(info, OpaqueClosureCreateInfo) - lbt = argextype(stmt.args[3], ir, ir.sptypes) + lbt = argextype(stmt.args[3], ir) lb, exact = instanceof_tfunc(lbt) exact || return - ubt = argextype(stmt.args[4], ir, ir.sptypes) + ubt = argextype(stmt.args[4], ir) ub, exact = instanceof_tfunc(ubt) exact || return # Narrow opaque closure type @@ -1046,7 +1046,7 @@ end # For primitives, we do that right here. For proper calls, we will # discover this when we consult the caches. function check_effect_free!(ir::IRCode, idx::Int, @nospecialize(stmt), @nospecialize(rt)) - if stmt_effect_free(stmt, rt, ir, ir.sptypes) + if stmt_effect_free(stmt, rt, ir) ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE end end @@ -1346,7 +1346,7 @@ end function mk_tuplecall!(compact::IncrementalCompact, args::Vector{Any}, line_idx::Int32) e = Expr(:call, TOP_TUPLE, args...) - etyp = tuple_tfunc(Any[compact_exprtype(compact, args[i]) for i in 1:length(args)]) + etyp = tuple_tfunc(Any[argextype(args[i], compact) for i in 1:length(args)]) return insert_node_here!(compact, NewInstruction(e, etyp, line_idx)) end diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 3838f8d6ec6ab..35f976756dcdd 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -520,7 +520,7 @@ function insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after:: node[:line] = something(inst.line, ir.stmts[pos][:line]) flag = inst.flag if !inst.effect_free_computed - if stmt_effect_free(inst.stmt, inst.type, ir, ir.sptypes) + if stmt_effect_free(inst.stmt, inst.type, ir) flag |= IR_FLAG_EFFECT_FREE end end @@ -765,7 +765,7 @@ function insert_node_here!(compact::IncrementalCompact, inst::NewInstruction, re resize!(compact, result_idx) end flag = inst.flag - if !inst.effect_free_computed && stmt_effect_free(inst.stmt, inst.type, compact, compact.ir.sptypes) + if !inst.effect_free_computed && stmt_effect_free(inst.stmt, inst.type, compact) flag |= IR_FLAG_EFFECT_FREE end node = compact.result[result_idx] @@ -1316,7 +1316,7 @@ function maybe_erase_unused!( callback = null_dce_callback) stmt = compact.result[idx][:inst] stmt === nothing && return false - if compact_exprtype(compact, SSAValue(idx)) === Bottom + if argextype(SSAValue(idx), compact) === Bottom effect_free = false else effect_free = compact.result[idx][:flag] & IR_FLAG_EFFECT_FREE != 0 @@ -1466,8 +1466,3 @@ function iterate(x::BBIdxIter, (idx, bb)::Tuple{Int, Int}=(1, 1)) end return (bb, idx), (idx + 1, next_bb) end - -is_known_call(e::Expr, @nospecialize(func), ir::IRCode) = - is_known_call(e, func, ir, ir.sptypes, ir.argtypes) - -argextype(@nospecialize(x), ir::IRCode) = argextype(x, ir, ir.sptypes, ir.argtypes) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 4c79bc9d47c63..000bb1849edea 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1,5 +1,11 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact}) + isexpr(x, :call) || return false + ft = argextype(x.args[1], ir) + return singleton_type(ft) === func +end + """ du::SSADefUse @@ -31,7 +37,7 @@ function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr elseif isa(field, Int) # try to resolve other constants, e.g. global reference else - field = isa(ir, IncrementalCompact) ? compact_exprtype(ir, field) : argextype(field, ir) + field = argextype(field, ir) if isa(field, Const) field = field.val else @@ -242,7 +248,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe if is_old(compact, defssa) && isa(val, SSAValue) val = OldSSAValue(val.id) end - edge_typ = widenconst(compact_exprtype(compact, val)) + edge_typ = widenconst(argextype(val, compact)) hasintersect(edge_typ, typeconstraint) || continue push!(possible_predecessors, n) end @@ -286,9 +292,9 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe return leaves, visited_phinodes end -function record_immutable_preserve!(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr) +function record_immutable_preserve!(new_preserves::Vector{Any}, def::Expr, compact::IncrementalCompact) for arg in (isexpr(def, :new) ? def.args : def.args[2:end]) - if !isbitstype(widenconst(compact_exprtype(compact, arg))) + if !isbitstype(widenconst(argextype(arg, compact))) push!(new_preserves, arg) end end @@ -316,10 +322,10 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact) isa(def, Expr) || return false length(def.args) >= 3 || return false is_known_call(def, getfield, compact) || return false - which = compact_exprtype(compact, def.args[3]) + which = argextype(def.args[3], compact) isa(which, Const) || return false which.val === :captures || return false - oc = compact_exprtype(compact, def.args[2]) + oc = argextype(def.args[2], compact) return oc ⊑ Core.OpaqueClosure end @@ -340,7 +346,7 @@ function lift_leaves(compact::IncrementalCompact, cache_key = leaf if isa(leaf, AnySSAValue) (def, leaf) = walk_to_def(compact, leaf) - if is_tuple_call(compact, def) && 1 ≤ field < length(def.args) + if is_known_call(def, tuple, compact) && 1 ≤ field < length(def.args) lift_arg!(compact, leaf, cache_key, def, 1+field, lifted_leaves) continue elseif isexpr(def, :new) @@ -388,7 +394,7 @@ function lift_leaves(compact::IncrementalCompact, end return nothing else - typ = compact_exprtype(compact, leaf) + typ = argextype(leaf, compact) if !isa(typ, Const) # TODO: (disabled since #27126) # If the leaf is an old ssa value, insert a getfield here @@ -431,7 +437,7 @@ function lift_arg!( lifted = OldSSAValue(lifted.id) end if isa(lifted, GlobalRef) || isa(lifted, Expr) - lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, compact_exprtype(compact, lifted)))) + lifted = insert_node!(compact, leaf, effect_free(NewInstruction(lifted, argextype(lifted, compact)))) stmt.args[argidx] = lifted if isa(leaf, SSAValue) && leaf.id < compact.result_idx push!(compact.late_fixup, leaf.id) @@ -481,8 +487,8 @@ function lift_comparison!(compact::IncrementalCompact, length(args) == 3 || return lhs, rhs = args[2], args[3] - vl = compact_exprtype(compact, lhs) - vr = compact_exprtype(compact, rhs) + vl = argextype(lhs, compact) + vr = argextype(rhs, compact) if isa(vl, Const) isa(vr, Const) && return cmp = vl @@ -496,7 +502,7 @@ function lift_comparison!(compact::IncrementalCompact, return end - valtyp = widenconst(compact_exprtype(compact, val)) + valtyp = widenconst(argextype(val, compact)) isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting leaves, visited_phinodes = collect_leaves(compact, val, valtyp) @@ -505,7 +511,7 @@ function lift_comparison!(compact::IncrementalCompact, # Let's check if we evaluate the comparison for each one of the leaves lifted_leaves = nothing for leaf in leaves - r = egal_tfunc(compact_exprtype(compact, leaf), cmp) + r = egal_tfunc(argextype(leaf, compact), cmp) if isa(r, Const) if lifted_leaves === nothing lifted_leaves = LiftedLeaves() @@ -646,14 +652,14 @@ function sroa_pass!(ir::IRCode) 4 <= length(stmt.args) <= 5 || continue is_setfield = true if length(stmt.args) == 5 - field_ordering = compact_exprtype(compact, stmt.args[5]) + field_ordering = argextype(stmt.args[5], compact) end elseif is_known_call(stmt, getfield, compact) 3 <= length(stmt.args) <= 5 || continue if length(stmt.args) == 5 - field_ordering = compact_exprtype(compact, stmt.args[5]) + field_ordering = argextype(stmt.args[5], compact) elseif length(stmt.args) == 4 - field_ordering = compact_exprtype(compact, stmt.args[4]) + field_ordering = argextype(stmt.args[4], compact) widenconst(field_ordering) === Bool && (field_ordering = :unspecified) end elseif isexpr(stmt, :foreigncall) @@ -672,17 +678,17 @@ function sroa_pass!(ir::IRCode) isa(def, SSAValue) || continue defidx = def.id def = compact[defidx] - if is_tuple_call(compact, def) - record_immutable_preserve!(new_preserves, compact, def) + if is_known_call(def, tuple, compact) + record_immutable_preserve!(new_preserves, def, compact) push!(preserved, preserved_arg.id) continue elseif isexpr(def, :new) - typ = widenconst(compact_exprtype(compact, SSAValue(defidx))) + typ = widenconst(argextype(SSAValue(defidx), compact)) if isa(typ, UnionAll) typ = unwrap_unionall(typ) end if typ isa DataType && !ismutabletype(typ) - record_immutable_preserve!(new_preserves, compact, def) + record_immutable_preserve!(new_preserves, def, compact) push!(preserved, preserved_arg.id) continue end @@ -722,7 +728,7 @@ function sroa_pass!(ir::IRCode) val = stmt.args[2] - struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, val))) + struct_typ = unwrap_unionall(widenconst(argextype(val, compact))) if isa(struct_typ, Union) && struct_typ <: Tuple struct_typ = unswitchtupleunion(struct_typ) end @@ -768,7 +774,7 @@ function sroa_pass!(ir::IRCode) leaves, visited_phinodes = collect_leaves(compact, val, struct_typ) isempty(leaves) && continue - result_t = compact_exprtype(compact, SSAValue(idx)) + result_t = argextype(SSAValue(idx), compact) lifted_result = lift_leaves(compact, result_t, field, leaves) lifted_result === nothing && continue lifted_leaves, any_undef = lifted_result @@ -1032,13 +1038,11 @@ function adce_pass!(ir::IRCode) for ((_, idx), stmt) in compact if isa(stmt, PhiNode) push!(all_phis, idx) - elseif isexpr(stmt, :call) + elseif is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3 # nullify safe `typeassert` calls - if is_known_call(stmt, typeassert, compact) && length(stmt.args) == 3 - ty, isexact = instanceof_tfunc(compact_exprtype(compact, stmt.args[3])) - if isexact && compact_exprtype(compact, stmt.args[2]) ⊑ ty - compact[idx] = nothing - end + ty, isexact = instanceof_tfunc(argextype(stmt.args[3], compact)) + if isexact && argextype(stmt.args[2], compact) ⊑ ty + compact[idx] = nothing end end end diff --git a/base/compiler/ssair/queries.jl b/base/compiler/ssair/queries.jl deleted file mode 100644 index 503db9b7d8774..0000000000000 --- a/base/compiler/ssair/queries.jl +++ /dev/null @@ -1,103 +0,0 @@ -# This file is a part of Julia. License is MIT: https://julialang.org/license - -""" -Determine whether a statement is side-effect-free, i.e. may be removed if it has no uses. -""" -function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src, sptypes::Vector{Any}) - isa(stmt, PiNode) && return true - isa(stmt, PhiNode) && return true - isa(stmt, ReturnNode) && return false - isa(stmt, GotoNode) && return false - isa(stmt, GotoIfNot) && return false - isa(stmt, Slot) && return false # Slots shouldn't occur in the IR at this point, but let's be defensive here - isa(stmt, GlobalRef) && return isdefined(stmt.mod, stmt.name) - if isa(stmt, Expr) - e = stmt::Expr - head = e.head - if head === :static_parameter - etyp = sptypes[e.args[1]] - # if we aren't certain enough about the type, it might be an UndefVarError at runtime - return isa(etyp, Const) - end - ea = e.args - if head === :call - f = argextype(ea[1], src, sptypes) - f = singleton_type(f) - f === nothing && return false - is_return_type(f) && return true - if isa(f, IntrinsicFunction) - intrinsic_effect_free_if_nothrow(f) || return false - return intrinsic_nothrow(f, - Any[argextype(ea[i], src, sptypes) for i = 2:length(ea)]) - end - contains_is(_PURE_BUILTINS, f) && return true - contains_is(_PURE_OR_ERROR_BUILTINS, f) || return false - rt === Bottom && return false - return _builtin_nothrow(f, Any[argextype(ea[i], src, sptypes) for i = 2:length(ea)], rt) - elseif head === :new - a = ea[1] - typ = argextype(a, src, sptypes) - # `Expr(:new)` of unknown type could raise arbitrary TypeError. - typ, isexact = instanceof_tfunc(typ) - isexact || return false - isconcretedispatch(typ) || return false - typ = typ::DataType - fieldcount(typ) >= length(ea) - 1 || return false - for fld_idx in 1:(length(ea) - 1) - eT = argextype(ea[fld_idx + 1], src, sptypes) - fT = fieldtype(typ, fld_idx) - eT ⊑ fT || return false - end - return true - elseif head === :new_opaque_closure - length(ea) < 5 && return false - a = ea[1] - typ = argextype(a, src, sptypes) - typ, isexact = instanceof_tfunc(typ) - isexact || return false - typ ⊑ Tuple || return false - isva = argextype(ea[2], src, sptypes) - rt_lb = argextype(ea[3], src, sptypes) - rt_ub = argextype(ea[4], src, sptypes) - src = argextype(ea[5], src, sptypes) - if !(isva ⊑ Bool && rt_lb ⊑ Type && rt_ub ⊑ Type && src ⊑ Method) - return false - end - return true - elseif head === :isdefined || head === :the_exception || head === :copyast || head === :inbounds || head === :boundscheck - return true - else - # e.g. :loopinfo - return false - end - end - return true -end - -function abstract_eval_ssavalue(s::SSAValue, src::IRCode) - return types(src)[s] -end - -function abstract_eval_ssavalue(s::SSAValue, src::IncrementalCompact) - return types(src)[s] -end - -function compact_exprtype(compact::IncrementalCompact, @nospecialize(value)) - if isa(value, AnySSAValue) - return types(compact)[value] - elseif isa(value, Argument) - return compact.ir.argtypes[value.n] - end - return argextype(value, compact.ir, compact.ir.sptypes) -end -argextype(@nospecialize(value), compact::IncrementalCompact, sptypes::Vector{Any}) = compact_exprtype(compact, value) - -is_tuple_call(ir::IRCode, @nospecialize(def)) = isa(def, Expr) && is_known_call(def, tuple, ir, ir.sptypes) -is_tuple_call(compact::IncrementalCompact, @nospecialize(def)) = isa(def, Expr) && is_known_call(def, tuple, compact) -function is_known_call(e::Expr, @nospecialize(func), src::IncrementalCompact) - if e.head !== :call - return false - end - f = compact_exprtype(src, e.args[1]) - return singleton_type(f) === func -end diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 8d6a3a3eddb70..2a3a975a4551d 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -226,43 +226,6 @@ end # types # ######### -argextype(@nospecialize(x), state) = argextype(x, state.src, state.sptypes, state.slottypes) - -const EMPTY_SLOTTYPES = Any[] - -function argextype(@nospecialize(x), src, sptypes::Vector{Any}, slottypes::Vector{Any} = EMPTY_SLOTTYPES) - if isa(x, Expr) - if x.head === :static_parameter - return sptypes[x.args[1]::Int] - elseif x.head === :boundscheck - return Bool - elseif x.head === :copyast - return argextype(x.args[1], src, sptypes, slottypes) - end - @assert false "argextype only works on argument-position values" - elseif isa(x, SlotNumber) - return slottypes[(x::SlotNumber).id] - elseif isa(x, TypedSlot) - return (x::TypedSlot).typ - elseif isa(x, SSAValue) - return abstract_eval_ssavalue(x::SSAValue, src) - elseif isa(x, Argument) - return isa(src, IncrementalCompact) ? src.ir.argtypes[x.n] : - isa(src, IRCode) ? src.argtypes[x.n] : - slottypes[x.n] - elseif isa(x, QuoteNode) - return Const((x::QuoteNode).value) - elseif isa(x, GlobalRef) - return abstract_eval_global(x.mod, (x::GlobalRef).name) - elseif isa(x, PhiNode) - return Any - elseif isa(x, PiNode) - return x.typ - else - return Const(x) - end -end - function singleton_type(@nospecialize(ft)) if isa(ft, Const) return ft.val diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 53a7c9b35fb38..83780ca8b1ac5 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -381,7 +381,7 @@ f_oc_getfield(x) = (@opaque ()->x)() @test fully_eliminated(f_oc_getfield, Tuple{Int}) import Core.Compiler: argextype, singleton_type -const EMPTY_SPTYPES = Core.Compiler.EMPTY_SLOTTYPES +const EMPTY_SPTYPES = Any[] code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 2151d938b525f..dbffa41edc7ae 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -70,7 +70,7 @@ end # Tests for SROA import Core.Compiler: argextype, singleton_type -const EMPTY_SPTYPES = Core.Compiler.EMPTY_SLOTTYPES +const EMPTY_SPTYPES = Any[] code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code @@ -627,7 +627,7 @@ let # `sroa_pass!` should work with constant globals end @test !any(src.code) do @nospecialize(stmt) Meta.isexpr(stmt, :call) || return false - ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes) + ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES) return Core.Compiler.widenconst(ft) == typeof(getfield) end @test !any(src.code) do @nospecialize(stmt) @@ -645,7 +645,7 @@ let # `sroa_pass!` should work with constant globals end @test !any(src.code) do @nospecialize(stmt) Meta.isexpr(stmt, :call) || return false - ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes) + ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES) return Core.Compiler.widenconst(ft) == typeof(getfield) end @test !any(src.code) do @nospecialize(stmt) @@ -668,7 +668,7 @@ let # eliminate `typeassert(x2.x, Foo)` @test all(src.code) do @nospecialize stmt Meta.isexpr(stmt, :call) || return true - ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes) + ft = Core.Compiler.argextype(stmt.args[1], src, EMPTY_SPTYPES) return Core.Compiler.widenconst(ft) !== typeof(typeassert) end end