diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl
index ae1e269af597d..b1655686acc89 100644
--- a/base/compiler/abstractinterpretation.jl
+++ b/base/compiler/abstractinterpretation.jl
@@ -16,7 +16,6 @@ const _REF_NAME = Ref.body.name
 
 call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
     isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc])
-
 function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
                                   max_methods::Int = InferenceParams(interp).MAX_METHODS)
     if sv.params.unoptimize_throw_blocks && sv.currpc in sv.throw_blocks
@@ -1380,7 +1379,11 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
                     if isa(fname, Slot)
                         changes = StateUpdate(fname, VarState(Any, false), changes)
                     end
-                elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
+                elseif hd === :meta
+                    if stmt.args[1] == :noinline
+                        frame.saw_noinline = true
+                    end
+                elseif hd === :inbounds || hd === :loopinfo || hd === :code_coverage_effect
                     # these do not generate code
                 else
                     t = abstract_eval_statement(interp, stmt, changes, frame)
diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl
index 2d5fce04c0454..12671e89f855a 100644
--- a/base/compiler/inferencestate.jl
+++ b/base/compiler/inferencestate.jl
@@ -42,6 +42,7 @@ mutable struct InferenceState
     limited::Bool
     inferred::Bool
     dont_work_on_me::Bool
+    saw_noinline::Bool
 
     # The place to look up methods while working on this function.
     # In particular, we cache method lookup results for the same function to
@@ -113,7 +114,7 @@ mutable struct InferenceState
             Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
             Vector{InferenceState}(), # callers_in_cycle
             #=parent=#nothing,
-            cached, false, false, false,
+            cached, false, false, false, false,
            CachedMethodTable(method_table(interp)),
             interp)
         result.result = frame
diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl
index 90324b9665175..071d6cce7ab01 100644
--- a/base/compiler/optimize.jl
+++ b/base/compiler/optimize.jl
@@ -21,10 +21,24 @@ function push!(et::EdgeTracker, ci::CodeInstance)
     push!(et, ci.def)
 end
 
-struct InferenceCaches{T, S}
-    inf_cache::T
+# An mi_cache that overlays some base cache, but also caches
+# temporary results while we're working on a cycle
+struct CycleInferenceCache{S}
+    cycle_mis::IdDict{MethodInstance, InferenceResult}
     mi_cache::S
 end
+CycleInferenceCache(mi_cache) = CycleInferenceCache(IdDict{MethodInstance, InferenceResult}(), mi_cache)
+
+function setindex!(cic::CycleInferenceCache, v::InferenceResult, mi::MethodInstance)
+    cic.cycle_mis[mi] = v
+    return cic
+end
+
+function get(cic::CycleInferenceCache, mi::MethodInstance, @nospecialize(default))
+    result = get(cic.cycle_mis, mi, nothing)
+    result !== nothing && return result
+    return get(cic.mi_cache, mi, default)
+end
 
 struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCaches, Nothing}, V <: Union{Nothing, MethodTableView}}
     params::OptimizationParams
@@ -42,7 +56,9 @@ mutable struct OptimizationState
     sptypes::Vector{Any} # static parameters
     slottypes::Vector{Any}
     const_api::Bool
-    inlining::InliningState
+    params::OptimizationParams
+    et::Union{Nothing, EdgeTracker}
+    mt::Union{Nothing, MethodTableView}
     function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
         s_edges = frame.stmt_edges[1]
         if s_edges === nothing
@@ -50,16 +66,12 @@ mutable struct OptimizationState
             frame.stmt_edges[1] = s_edges
         end
         src = frame.src
-        inlining = InliningState(params,
-            EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
-            InferenceCaches(
-                get_inference_cache(interp),
-                WorldView(code_cache(interp), frame.world)),
-            method_table(interp))
         return new(frame.linfo,
                    src, frame.stmt_info, frame.mod, frame.nargs,
                    frame.sptypes, frame.slottypes, false,
-                   inlining)
+                   params,
+                   EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
+                   method_table(interp))
     end
     function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
         # prepare src for running optimization passes
@@ -86,16 +98,13 @@ mutable struct OptimizationState
         end
         # Allow using the global MI cache, but don't track edges.
         # This method is mostly used for unit testing the optimizer
-        inlining = InliningState(params,
-            nothing,
-            InferenceCaches(
-                get_inference_cache(interp),
-                WorldView(code_cache(interp), get_world_counter())),
-            method_table(interp))
+
         return new(linfo,
                    src, stmt_info, inmodule, nargs,
                    sptypes_from_meth_instance(linfo), slottypes, false,
-                   inlining)
+                   params,
+                   nothing,
+                   method_table(interp))
     end
 end
 
@@ -180,10 +189,10 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
 end
 
 # run the optimization work
-function optimize(opt::OptimizationState, params::OptimizationParams, @nospecialize(result))
+function optimize(opt::OptimizationState, params::OptimizationParams, caches::InferenceCaches, @nospecialize(result))
     def = opt.linfo.def
     nargs = Int(opt.nargs) - 1
-    @timeit "optimizer" ir = run_passes(opt.src, nargs, opt)
+    @timeit "optimizer" ir = run_passes(opt.src, nargs, caches, opt)
     force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
 
     # compute inlining and other related optimizations
diff --git a/base/compiler/ssair/domtree.jl b/base/compiler/ssair/domtree.jl
index 1ab2876b769da..2afed390db6c8 100644
--- a/base/compiler/ssair/domtree.jl
+++ b/base/compiler/ssair/domtree.jl
@@ -109,9 +109,14 @@ end
 
 length(D::DFSTree) = length(D.from_pre)
 
-function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
+succs(bb::BasicBlock) = bb.succs
+
+function DFS!(D::DFSTree, blocks::Vector; roots = 1)
     copy!(D, DFSTree(length(blocks)))
-    to_visit = Tuple{BBNumber, PreNumber, Bool}[(1, 0, false)]
+    to_visit = Tuple{BBNumber, PreNumber, Bool}[]
+    for root in roots
+        push!(to_visit, (root, 0, false))
+    end
     pre_num = 1
     post_num = 1
     while !isempty(to_visit)
@@ -144,7 +149,7 @@ function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
             to_visit[end] = (current_node_bb, parent_pre, true)
 
             # Push children to the stack
-            for succ_bb in blocks[current_node_bb].succs
+            for succ_bb in succs(blocks[current_node_bb])
                 push!(to_visit, (succ_bb, pre_num, false))
             end
 
@@ -161,7 +166,7 @@ function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
     return D
 end
 
-DFS(blocks::Vector{BasicBlock}) = DFS!(DFSTree(0), blocks)
+DFS(blocks::Vector; roots = 1) = DFS!(DFSTree(0), blocks; roots)
 
 """
 Keeps the per-BB state of the Semi NCA algorithm. In the original formulation,
diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl
index 83205033342d6..bbed8edc218b9 100644
--- a/base/compiler/ssair/driver.jl
+++ b/base/compiler/ssair/driver.jl
@@ -118,14 +118,15 @@ function slot2reg(ir::IRCode, ci::CodeInfo, nargs::Int, sv::OptimizationState)
     return ir
 end
 
-function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
+function run_passes(ci::CodeInfo, nargs::Int, caches::InferenceCaches, sv::OptimizationState)
     preserve_coverage = coverage_enabled(sv.mod)
     ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, sv)
     ir = slot2reg(ir, ci, nargs, sv)
     #@Base.show ("after_construct", ir)
     # TODO: Domsorting can produce an updated domtree - no need to recompute here
     @timeit "compact 1" ir = compact!(ir)
-    @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
+    inlining = InliningState(sv.params, sv.et, caches, sv.mt)
+    @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, inlining, ci.propagate_inbounds)
     #@timeit "verify 2" verify_ir(ir)
     ir = compact!(ir)
     #@Base.show ("before_sroa", ir)
diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl
index 0e95f812e5eb6..02a4be1074e95 100644
--- a/base/compiler/ssair/inlining.jl
+++ b/base/compiler/ssair/inlining.jl
@@ -1419,8 +1419,18 @@ function find_inferred(mi::MethodInstance, atypes::Vector{Any}, caches::Inferenc
             return svec(true, quoted(linfo.rettype_const))
         end
         return svec(false, linfo.inferred)
-    else
-        # `linfo` may be `nothing` or an IRCode here
-        return svec(false, linfo)
+    elseif isa(linfo, InferenceResult)
+        let inferred_src = linfo.src
+            if isa(inferred_src, CodeInfo)
+                return svec(false, inferred_src)
+            end
+            if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val)
+                return svec(true, quoted(inferred_src.val),)
+            end
+        end
+        linfo = nothing
     end
+
+    # `linfo` may be `nothing` or an IRCode here
+    return svec(false, linfo)
 end
diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl
index 04c0edb9a0fde..ddf2f5c813781 100644
--- a/base/compiler/typeinfer.jl
+++ b/base/compiler/typeinfer.jl
@@ -208,6 +208,36 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
     end
 end
 
+struct ForwardEdges
+    edges::Vector{Int}
+end
+ForwardEdges() = ForwardEdges(Int[])
+succs(f::ForwardEdges) = f.edges
+push!(f::ForwardEdges, i::Int) = push!(f.edges, i)
+
+function postorder_sort_frames(frames)
+    length(frames) == 1 && return frames
+    roots = Int[i for i in 1:length(frames) if frames[i].saw_noinline]
+    # If there are no noinline annotations, just leave the default order
+    isempty(roots) && return frames
+
+    # Number frames
+    numbering = IdDict{Any, Int}(frames[i] => i for i in 1:length(frames))
+
+    # Compute forward edges
+    forward_edges = ForwardEdges[ForwardEdges() for i in 1:length(frames)]
+    for i in 1:length(frames)
+        frame = frames[i]
+        for (edge, _) in frame.cycle_backedges
+            push!(forward_edges[numbering[edge]], i)
+        end
+    end
+
+    # Compute postorder
+    dfs_tree = DFS(forward_edges; roots=roots)
+    return InferenceState[frames[i] for i in dfs_tree.from_post]
+end
+
 function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
     typeinf_nocycle(interp, frame) || return false # frame is now part of a higher cycle
     # with no active ip's, frame is done
@@ -220,19 +250,26 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
     for caller in frames
         finish(caller, interp)
     end
+    # We postorder sort frames rooted on any frames marked noinline (if any).
+    # This makes sure that the inliner has the maximum opportunity to inline.
+    frames = postorder_sort_frames(frames)
     # collect results for the new expanded frame
     results = Tuple{InferenceResult, Bool}[ ( frames[i].result,
         frames[i].cached || frames[i].parent !== nothing ) for i in 1:length(frames) ]
     # empty!(frames)
     valid_worlds = frame.valid_worlds
     cached = frame.cached
+    caches = InferenceCaches(interp)
+    cycle_cache = CycleInferenceCache(caches.mi_cache)
+    caches = InferenceCaches(caches.inf_cache, cycle_cache)
     if cached || frame.parent !== nothing
         for (caller, doopt) in results
             opt = caller.src
             if opt isa OptimizationState
                 run_optimizer = doopt && may_optimize(interp)
                 if run_optimizer
-                    optimize(opt, OptimizationParams(interp), caller.result)
+                    cycle_cache[opt.linfo] = caller
+                    optimize(opt, OptimizationParams(interp), caches, caller.result)
                     finish(opt.src, interp)
                     # finish updating the result struct
                     validate_code_in_debug_mode(opt.linfo, opt.src, "optimized")
@@ -251,7 +288,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
                 end
                 # As a hack the et reuses frame_edges[1] to push any optimization
                 # edges into, so we don't need to handle them specially here
-                valid_worlds = intersect(valid_worlds, opt.inlining.et.valid_worlds[])
+                valid_worlds = intersect(valid_worlds, opt.et.valid_worlds[])
             end
         end
     end
@@ -768,7 +805,7 @@ function typeinf_code(interp::AbstractInterpreter, method::Method, @nospecialize
     if typeinf(interp, frame) && run_optimizer
         opt_params = OptimizationParams(interp)
         opt = OptimizationState(frame, opt_params, interp)
-        optimize(opt, opt_params, result.result)
+        optimize(opt, opt_params, InferenceCaches(interp), result.result)
         opt.src.inferred = true
     end
     ccall(:jl_typeinf_end, Cvoid, ())
diff --git a/base/compiler/types.jl b/base/compiler/types.jl
index 3ca6cff20ccd6..446df5719be4a 100644
--- a/base/compiler/types.jl
+++ b/base/compiler/types.jl
@@ -134,6 +134,11 @@ struct InferenceParams
     end
 end
 
+struct InferenceCaches{T, S}
+    inf_cache::T
+    mi_cache::S
+end
+
 """
     NativeInterpreter
 
@@ -187,6 +192,11 @@ get_inference_cache(ni::NativeInterpreter) = ni.cache
 
 code_cache(ni::NativeInterpreter) = WorldView(GLOBAL_CI_CACHE, ni.world)
 
+InferenceCaches(ni::NativeInterpreter) =
+    InferenceCaches(
+        get_inference_cache(ni),
+        WorldView(code_cache(ni), ni.world))
+
 """
     lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance)
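
Note on the CycleInferenceCache added in optimize.jl and threaded through in typeinfer.jl: it is a two-layer lookup. setindex! only writes the per-cycle IdDict, while get consults that layer first and falls back to the wrapped mi_cache, so while a cycle is being optimized the inliner can see the cycle members' InferenceResults before they reach the global code cache. A minimal standalone sketch of that overlay pattern, using toy Symbol keys and a plain Dict base rather than the compiler's MethodInstance/InferenceResult types:

    # Overlay cache: writes go to a temporary layer, reads prefer that layer
    # and otherwise fall back to the wrapped base cache.
    struct OverlayCache{S}
        temp::IdDict{Symbol,Any}
        base::S
    end
    OverlayCache(base) = OverlayCache(IdDict{Symbol,Any}(), base)

    Base.setindex!(c::OverlayCache, v, k::Symbol) = (c.temp[k] = v; c)
    function Base.get(c::OverlayCache, k::Symbol, default)
        v = get(c.temp, k, nothing)
        v !== nothing && return v
        return get(c.base, k, default)
    end

    base = Dict(:f => "finished result for f")
    cache = OverlayCache(base)
    cache[:g] = "in-progress result for g"      # visible only through the overlay
    @assert get(cache, :g, nothing) == "in-progress result for g"
    @assert get(cache, :f, nothing) == "finished result for f"   # falls back to base
    @assert get(cache, :h, nothing) === nothing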
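Similarly, postorder_sort_frames reuses the domtree DFS by handing it per-frame forward-edge lists (built from each frame's cycle_backedges) and explicit roots; the result is a postorder over the cycle's forward call edges seeded at the frames that saw a :noinline meta, which, per the comment in the diff, gives the inliner the maximum opportunity to inline. A rough standalone equivalent of that root-seeded postorder, using plain adjacency vectors instead of ForwardEdges/DFSTree (names here are illustrative only):

    # Postorder DFS over an adjacency list, seeded from the given roots.
    function postorder(succs::Vector{Vector{Int}}, roots::Vector{Int})
        seen = falses(length(succs))
        order = Int[]
        function visit(v)
            seen[v] && return
            seen[v] = true
            for w in succs[v]
                visit(w)
            end
            push!(order, v)    # a node is emitted only after everything it reaches
        end
        foreach(visit, roots)
        return order
    end

    # 1 calls 2, 2 calls 3, 3 calls back into 1 (a cycle); rooted at node 1:
    @assert postorder([[2], [3], [1]], [1]) == [3, 2, 1]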