diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index aa83cc1a7aac2..a057a1879412c 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -35,73 +35,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), add_remark!(interp, sv, "Skipped call in throw block") return CallMeta(Any, false) end - valid_worlds = WorldRange() - # NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type - splitunions = 1 < unionsplitcost(argtypes) <= InferenceParams(interp).MAX_UNION_SPLITTING - mts = Core.MethodTable[] - fullmatch = Bool[] - if splitunions - split_argtypes = switchtupleunion(argtypes) - applicable = Any[] - applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match - infos = MethodMatchInfo[] - for arg_n in split_argtypes - sig_n = argtypes_to_type(arg_n) - mt = ccall(:jl_method_table_for, Any, (Any,), sig_n) - if mt === nothing - add_remark!(interp, sv, "Could not identify method table for call") - return CallMeta(Any, false) - end - mt = mt::Core.MethodTable - matches = findall(sig_n, method_table(interp); limit=max_methods) - if matches === missing - add_remark!(interp, sv, "For one of the union split cases, too many methods matched") - return CallMeta(Any, false) - end - push!(infos, MethodMatchInfo(matches)) - for m in matches - push!(applicable, m) - push!(applicable_argtypes, arg_n) - end - valid_worlds = intersect(valid_worlds, matches.valid_worlds) - thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches) - found = false - for (i, mt′) in enumerate(mts) - if mt′ === mt - fullmatch[i] &= thisfullmatch - found = true - break - end - end - if !found - push!(mts, mt) - push!(fullmatch, thisfullmatch) - end - end - info = UnionSplitInfo(infos) - else - mt = ccall(:jl_method_table_for, Any, (Any,), atype) - if mt === nothing - add_remark!(interp, sv, "Could not identify method table for call") - return CallMeta(Any, false) - end - mt = mt::Core.MethodTable - matches = findall(atype, method_table(interp, sv); limit=max_methods) - if matches === missing - # this means too many methods matched - # (assume this will always be true, so we don't compute / update valid age in this case) - add_remark!(interp, sv, "Too many methods matched") - return CallMeta(Any, false) - end - push!(mts, mt) - push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches)) - info = MethodMatchInfo(matches) - applicable = matches.matches - valid_worlds = matches.valid_worlds - applicable_argtypes = nothing + + matches = find_matching_methods(argtypes, atype, method_table(interp, sv), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods) + if isa(matches, FailedMethodMatch) + add_remark!(interp, sv, matches.reason) + return CallMeta(Any, false) end + + (; valid_worlds, applicable, info) = matches update_valid_age!(sv, valid_worlds) - applicable = applicable::Array{Any,1} napplicable = length(applicable) rettype = Bottom edges = MethodInstance[] @@ -142,10 +84,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), if edge !== nothing push!(edges, edge) end - this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] - const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) - if const_rt !== rt && const_rt ⊑ rt - rt = const_rt + this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] + const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) + if const_result !== nothing + const_rt, const_result = const_result + if const_rt !== rt && const_rt ⊑ rt + rt = const_rt + end end push!(const_results, const_result) if const_result !== nothing @@ -164,10 +109,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end # try constant propagation with argtypes for this match # this is in preparation for inlining, or improving the return result - this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i] - const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) - if const_this_rt !== this_rt && const_this_rt ⊑ this_rt - this_rt = const_this_rt + this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] + const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) + if const_result !== nothing + const_this_rt, const_result = const_result + if const_this_rt !== this_rt && const_this_rt ⊑ this_rt + this_rt = const_this_rt + end end push!(const_results, const_result) if const_result !== nothing @@ -272,7 +220,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # and avoid keeping track of a more complex result type. rettype = Any end - add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv) + add_call_backedges!(interp, rettype, edges, matches, atype, sv) if !isempty(sv.pclimitations) # remove self, if present delete!(sv.pclimitations, sv) for caller in sv.callers_in_cycle @@ -283,24 +231,110 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(rettype, info) end -function add_call_backedges!(interp::AbstractInterpreter, - @nospecialize(rettype), - edges::Vector{MethodInstance}, - fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype), - sv::InferenceState) - if rettype === Any - # for `NativeInterpreter`, we don't add backedges when a new method couldn't refine - # (widen) this type - return +struct FailedMethodMatch + reason::String +end + +struct MethodMatches + applicable::Vector{Any} + info::MethodMatchInfo + valid_worlds::WorldRange + mt::Core.MethodTable + fullmatch::Bool +end + +struct UnionSplitMethodMatches + applicable::Vector{Any} + applicable_argtypes::Vector{Vector{Any}} + info::UnionSplitInfo + valid_worlds::WorldRange + mts::Vector{Core.MethodTable} + fullmatches::Vector{Bool} +end + +function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView, + union_split::Int, max_methods::Int) + # NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type + if 1 < unionsplitcost(argtypes) <= union_split + split_argtypes = switchtupleunion(argtypes) + infos = MethodMatchInfo[] + applicable = Any[] + applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match + valid_worlds = WorldRange() + mts = Core.MethodTable[] + fullmatches = Bool[] + for i in 1:length(split_argtypes) + arg_n = split_argtypes[i]::Vector{Any} + sig_n = argtypes_to_type(arg_n) + mt = ccall(:jl_method_table_for, Any, (Any,), sig_n) + mt === nothing && return FailedMethodMatch("Could not identify method table for call") + mt = mt::Core.MethodTable + matches = findall(sig_n, method_table; limit = max_methods) + if matches === missing + return FailedMethodMatch("For one of the union split cases, too many methods matched") + end + push!(infos, MethodMatchInfo(matches)) + for m in matches + push!(applicable, m) + push!(applicable_argtypes, arg_n) + end + valid_worlds = intersect(valid_worlds, matches.valid_worlds) + thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches) + found = false + for (i, mt′) in enumerate(mts) + if mt′ === mt + fullmatches[i] &= thisfullmatch + found = true + break + end + end + if !found + push!(mts, mt) + push!(fullmatches, thisfullmatch) + end + end + return UnionSplitMethodMatches(applicable, + applicable_argtypes, + UnionSplitInfo(infos), + valid_worlds, + mts, + fullmatches) + else + mt = ccall(:jl_method_table_for, Any, (Any,), atype) + if mt === nothing + return FailedMethodMatch("Could not identify method table for call") + end + mt = mt::Core.MethodTable + matches = findall(atype, method_table; limit = max_methods) + if matches === missing + # this means too many methods matched + # (assume this will always be true, so we don't compute / update valid age in this case) + return FailedMethodMatch("Too many methods matched") + end + fullmatch = _any(match->(match::MethodMatch).fully_covers, matches) + return MethodMatches(matches.matches, + MethodMatchInfo(matches), + matches.valid_worlds, + mt, + fullmatch) end +end + +function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), edges::Vector{MethodInstance}, + matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype), + sv::InferenceState) + # for `NativeInterpreter`, we don't add backedges when a new method couldn't refine (widen) this type + rettype === Any && return for edge in edges add_backedge!(edge, sv) end - for (thisfullmatch, mt) in zip(fullmatch, mts) - if !thisfullmatch - # also need an edge to the method table in case something gets - # added that did not intersect with any existing method - add_mt_backedge!(mt, atype, sv) + # also need an edge to the method table in case something gets + # added that did not intersect with any existing method + if isa(matches, MethodMatches) + matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv) + else + for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts) + thisfullmatch || add_mt_backedge!(mt, atype, sv) end end end @@ -492,33 +526,39 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, sv::InferenceState, va_override::Bool) mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv) - mi === nothing && return Any, nothing + mi === nothing && return nothing # try constant prop' inf_cache = get_inference_cache(interp) inf_result = cache_lookup(mi, argtypes, inf_cache) if inf_result === nothing # if there might be a cycle, check to make sure we don't end up # calling ourselves here. - if result.edgecycle && _any(InfStackUnwind(sv)) do infstate - # if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`) - # we can relax the cycle detection by comparing `MethodInstance`s and allow inference to - # propagate different constant elements if the recursion is finite over the lattice - return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) && - any(infstate.result.overridden_by_const) + let result = result # prevent capturing + if result.edgecycle && _any(InfStackUnwind(sv)) do infstate + # if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`) + # we can relax the cycle detection by comparing `MethodInstance`s and allow inference to + # propagate different constant elements if the recursion is finite over the lattice + return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) && + any(infstate.result.overridden_by_const) + end + add_remark!(interp, sv, "[constprop] Edge cycle encountered") + return nothing end - add_remark!(interp, sv, "[constprop] Edge cycle encountered") - return Any, nothing end inf_result = InferenceResult(mi, argtypes, va_override) + if !any(inf_result.overridden_by_const) + add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes") + return nothing + end frame = InferenceState(inf_result, #=cache=#false, interp) - frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it + frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it frame.parent = sv push!(inf_cache, inf_result) - typeinf(interp, frame) || return Any, nothing + typeinf(interp, frame) || return nothing end result = inf_result.result # if constant inference hits a cycle, just bail out - isa(result, InferenceState) && return Any, nothing + isa(result, InferenceState) && return nothing add_backedge!(mi, sv) return result, inf_result end @@ -1150,7 +1190,8 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv: nargtype === Bottom && return CallMeta(Bottom, false) nargtype isa DataType || return CallMeta(Any, false) # other cases are not implemented below isdispatchelem(ft) || return CallMeta(Any, false) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below - types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types) + ft = ft::DataType + types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type nargtype = Tuple{ft, nargtype.parameters...} argtype = Tuple{ft, argtype.parameters...} result = findsup(types, method_table(interp)) @@ -1172,12 +1213,14 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv: # t, a = ti.parameters[i], argtypes′[i] # argtypes′[i] = t ⊑ a ? t : a # end - const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false) - if const_rt !== rt && const_rt ⊑ rt - return CallMeta(collect_limitations!(const_rt, sv), InvokeCallInfo(match, const_result)) - else - return CallMeta(collect_limitations!(rt, sv), InvokeCallInfo(match, nothing)) + const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false) + if const_result !== nothing + const_rt, const_result = const_result + if const_rt !== rt && const_rt ⊑ rt + return CallMeta(collect_limitations!(const_rt, sv), InvokeCallInfo(match, const_result)) + end end + return CallMeta(collect_limitations!(rt, sv), InvokeCallInfo(match, nothing)) end # call where the function is known exactly @@ -1279,19 +1322,20 @@ end function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState) pushfirst!(argtypes, closure.env) sig = argtypes_to_type(argtypes) - (; rt, edge) = result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv) + (; rt, edge) = result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv) edge !== nothing && add_backedge!(edge, sv) tt = closure.typ - sigT = unwrap_unionall(tt).parameters[1] - match = MethodMatch(sig, Core.svec(), closure.source::Method, sig <: rewrap_unionall(sigT, tt)) + sigT = (unwrap_unionall(tt)::DataType).parameters[1] + match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt)) info = OpaqueClosureCallInfo(match) if !result.edgecycle - const_rettype, const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes, + const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes, match, sv, closure.isva) - if const_rettype ⊑ rt - rt = const_rettype - end if const_result !== nothing + const_rettype, const_result = const_result + if const_rettype ⊑ rt + rt = const_rettype + end info = ConstCallInfo(info, Union{Nothing,InferenceResult}[const_result]) end end @@ -1301,7 +1345,7 @@ end function most_general_argtypes(closure::PartialOpaque) ret = Any[] cc = widenconst(closure) - argt = unwrap_unionall(cc).parameters[1] + argt = (unwrap_unionall(cc)::DataType).parameters[1] if !isa(argt, DataType) || argt.name !== typename(Tuple) argt = Tuple end @@ -1316,8 +1360,8 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{ f = argtype_to_function(ft) if isa(ft, PartialOpaque) return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv) - elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure) - return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false) + elseif (uft = unwrap_unionall(ft); isa(uft, DataType) && uft.name === typename(Core.OpaqueClosure)) + return CallMeta(rewrap_unionall((uft::DataType).parameters[2], ft), false) elseif f === nothing # non-constant function, but the number of arguments is known # and the ft is not a Builtin or IntrinsicFunction @@ -1513,12 +1557,12 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t) at = abstract_eval_value(interp, e.args[2], vtypes, sv) n = fieldcount(t) - if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val) && - let t = t; _all(i->getfield(at.val, i) isa fieldtype(t, i), 1:n); end + if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) && + let t = t; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val)) - elseif isa(at, PartialStruct) && at ⊑ Tuple && n == length(at.fields) && - let t = t, at = at; _all(i->at.fields[i] ⊑ fieldtype(t, i), 1:n); end - t = PartialStruct(t, at.fields) + elseif isa(at, PartialStruct) && at ⊑ Tuple && n == length(at.fields::Vector{Any}) && + let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] ⊑ fieldtype(t, i), 1:n); end + t = PartialStruct(t, at.fields::Vector{Any}) end end elseif ehead === :new_opaque_closure @@ -1566,7 +1610,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), sym = e.args[1] t = Bool if isa(sym, SlotNumber) - vtyp = vtypes[slot_id(sym)] + vtyp = vtypes[slot_id(sym)]::VarState if vtyp.typ === Bottom t = Const(false) # never assigned previously elseif !vtyp.undef @@ -1581,7 +1625,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), t = Const(true) end elseif isa(sym, Expr) && sym.head === :static_parameter - n = sym.args[1] + n = sym.args[1]::Int if 1 <= n <= length(sv.sptypes) spty = sv.sptypes[n] if isa(spty, Const) @@ -1616,7 +1660,7 @@ function abstract_eval_global(M::Module, s::Symbol) end function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) - typ = src.ssavaluetypes[s.id] + typ = (src.ssavaluetypes::Vector{Any})[s.id] if typ === NOT_FOUND return Bottom end @@ -1704,6 +1748,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) isva = isa(def, Method) && def.isva nslots = nargs - isva slottypes = frame.slottypes + ssavaluetypes = frame.src.ssavaluetypes::Vector{Any} while frame.pc´´ <= n # make progress on the active ip set local pc::Int = frame.pc´´ @@ -1804,7 +1849,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) for (caller, caller_pc) in frame.cycle_backedges # notify backedges of updated type information typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before - if !(caller.src.ssavaluetypes[caller_pc] === Any) + if !((caller.src.ssavaluetypes::Vector{Any})[caller_pc] === Any) # no reason to revisit if that call-site doesn't affect the final result if caller_pc < caller.pc´´ caller.pc´´ = caller_pc @@ -1814,6 +1859,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end end elseif hd === :enter + stmt = stmt::Expr l = stmt.args[1]::Int # propagate type info to exception handler old = states[l] @@ -1829,16 +1875,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) elseif hd === :leave else if hd === :(=) + stmt = stmt::Expr t = abstract_eval_statement(interp, stmt.args[2], changes, frame) if t === Bottom break end - frame.src.ssavaluetypes[pc] = t + ssavaluetypes[pc] = t lhs = stmt.args[1] if isa(lhs, SlotNumber) changes = StateUpdate(lhs, VarState(t, false), changes, false) end elseif hd === :method + stmt = stmt::Expr fname = stmt.args[1] if isa(fname, SlotNumber) changes = StateUpdate(fname, VarState(Any, false), changes, false) @@ -1853,7 +1901,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) if !isempty(frame.ssavalue_uses[pc]) record_ssa_assign(pc, t, frame) else - frame.src.ssavaluetypes[pc] = t + ssavaluetypes[pc] = t end end if isa(changes, StateUpdate) @@ -1880,7 +1928,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) if t === nothing # mark other reached expressions as `Any` to indicate they don't throw - frame.src.ssavaluetypes[pc] = Any + ssavaluetypes[pc] = Any end pc´ > n && break # can't proceed with the fast-path fall-through diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index 327ab85d104f3..483e2f38d9ee8 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -13,33 +13,35 @@ end # for the provided `linfo` and `given_argtypes`. The purpose of this function is # to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`, # so that we can construct cache-correct `InferenceResult`s in the first place. -function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override) +function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool) @assert isa(linfo.def, Method) # ensure the next line works nargs::Int = linfo.def.nargs - @assert length(given_argtypes) >= (nargs - 1) given_argtypes = anymap(widenconditional, given_argtypes) - if va_override || linfo.def.isva + isva = va_override || linfo.def.isva + if isva || isvarargtype(given_argtypes[end]) isva_given_argtypes = Vector{Any}(undef, nargs) - for i = 1:(nargs - 1) + for i = 1:(nargs - isva) isva_given_argtypes[i] = argtype_by_index(given_argtypes, i) end - if length(given_argtypes) >= nargs || !isvarargtype(given_argtypes[end]) - isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[nargs:end]) - else - isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[end:end]) + if isva + if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end]) + last = length(given_argtypes) + else + last = nargs + end + isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end]) end given_argtypes = isva_given_argtypes end + @assert length(given_argtypes) == nargs cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override) - if nargs === length(given_argtypes) - for i in 1:nargs - given_argtype = given_argtypes[i] - cache_argtype = cache_argtypes[i] - if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i]) - # prefer the argtype we were given over the one computed from `linfo` - cache_argtypes[i] = given_argtype - overridden_by_const[i] = true - end + for i in 1:nargs + given_argtype = given_argtypes[i] + cache_argtype = cache_argtypes[i] + if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i]) + # prefer the argtype we were given over the one computed from `linfo` + cache_argtypes[i] = given_argtype + overridden_by_const[i] = true end end return cache_argtypes, overridden_by_const diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index aa9a3ad1f0094..f13622edb23fe 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -265,7 +265,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance) while temp isa UnionAll temp = temp.body end - sigtypes = temp.parameters + sigtypes = (temp::DataType).parameters for j = 1:length(sigtypes) tj = sigtypes[j] if isType(tj) && tj.parameters[1] === Pi diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index ad0426860ece9..b8ec9610e0739 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -196,10 +196,11 @@ function stmt_affects_purity(@nospecialize(stmt), ir) return true end -# Convert IRCode back to CodeInfo and compute inlining cost and sideeffects +# compute inlining cost and sideeffects function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result)) - def = opt.linfo.def - nargs = Int(opt.nargs) - 1 + (; src, nargs, linfo) = opt + (; def, specTypes) = linfo + nargs = Int(nargs) - 1 force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta) @@ -221,7 +222,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt end end if proven_pure - for fl in opt.src.slotflags + for fl in src.slotflags if (fl & SLOT_USEDUNDEF) != 0 proven_pure = false break @@ -230,7 +231,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt end end if proven_pure - opt.src.pure = true + src.pure = true end if proven_pure @@ -243,7 +244,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt if !(isa(result, Const) && !is_inlineable_constant(result.val)) opt.const_api = true end - force_noinline || (opt.src.inlineable = true) + force_noinline || (src.inlineable = true) end end @@ -252,7 +253,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt # determine and cache inlineability union_penalties = false if !force_noinline - sig = unwrap_unionall(opt.linfo.specTypes) + sig = unwrap_unionall(specTypes) if isa(sig, DataType) && sig.name === Tuple.name for P in sig.parameters P = unwrap_unionall(P) @@ -264,25 +265,25 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt else force_noinline = true end - if !opt.src.inlineable && result === Union{} + if !src.inlineable && result === Union{} force_noinline = true end end if force_noinline - opt.src.inlineable = false + src.inlineable = false elseif isa(def, Method) - if opt.src.inlineable && isdispatchtuple(opt.linfo.specTypes) + if src.inlineable && isdispatchtuple(specTypes) # obey @inline declaration if a dispatch barrier would not help else bonus = 0 if result ⊑ Tuple && !isconcretetype(widenconst(result)) bonus = params.inline_tupleret_bonus end - if opt.src.inlineable + if src.inlineable # For functions declared @inline, increase the cost threshold 20x bonus += params.inline_cost_threshold*19 end - opt.src.inlineable = isinlineable(def, opt, params, union_penalties, bonus) + src.inlineable = isinlineable(def, opt, params, union_penalties, bonus) end end diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 05ed9511b23d8..077a1f105d3d8 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -313,8 +313,10 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector push!(linetable, LineInfoNode(entry.module, entry.method, entry.file, entry.line, (entry.inlined_at > 0 ? entry.inlined_at + linetable_offset : inlined_at))) end - nargs_def = item.mi.def.nargs - isva = nargs_def > 0 && item.mi.def.isva + (; def, sparam_vals) = item.mi + nargs_def = def.nargs::Int32 + isva = nargs_def > 0 && def.isva + sig = def.sig if isva vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], compact.result[idx][:line]) argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg] @@ -347,7 +349,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector # face of rename_arguments! mutating in place - should figure out # something better eventually. inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact) if isa(stmt′, ReturnNode) isa(stmt′.val, SSAValue) && (compact.used_ssas[stmt′.val.id] += 1) return_value = SSAValue(idx′) @@ -374,7 +376,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact) if isa(stmt′, ReturnNode) if isdefined(stmt′, :val) val = stmt′.val @@ -709,9 +711,8 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, match::Meth return mi end -function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::InferenceResult) - mi = specialize_method(result.linfo.def::Method, result.linfo.specTypes, - result.linfo.sparam_vals, false, true) +function compileable_specialization(et::Union{EdgeTracker, Nothing}, (; linfo)::InferenceResult) + mi = specialize_method(linfo.def::Method, linfo.specTypes, linfo.sparam_vals, false, true) mi !== nothing && et !== nothing && push!(et, mi::MethodInstance) return mi end @@ -1065,9 +1066,9 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result): pushfirst!(atypes, atype0) if isa(result, InferenceResult) - item = InliningTodo(result, atypes, calltype) - validate_sparams(item.mi.sparam_vals) || return nothing - if argtypes_to_type(atypes) <: item.mi.def.sig + (; mi) = item = InliningTodo(result, atypes, calltype) + validate_sparams(mi.sparam_vals) || return nothing + if argtypes_to_type(atypes) <: mi.def.sig state.mi_cache !== nothing && (item = resolve_todo(item, state)) handle_single_case!(ir, stmt, idx, item, true, todo) return nothing @@ -1195,7 +1196,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int for i in 1:length(infos) info = infos[i] meth = info.results - if meth === missing || meth.ambig + if meth.ambig # Too many applicable methods # Or there is a (partial?) ambiguity too_many = true @@ -1213,8 +1214,9 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int only_method = false end for match in meth - signature_union = Union{signature_union, match.spec_types} - if !isdispatchtuple(match.spec_types) + spec_types = match.spec_types + signature_union = Union{signature_union, spec_types} + if !isdispatchtuple(spec_types) fully_covered = false continue end @@ -1222,10 +1224,10 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int if case === nothing fully_covered = false continue - elseif _any(p->p[1] === match.spec_types, cases) + elseif _any(p->p[1] === spec_types, cases) continue end - push!(cases, Pair{Any,Any}(match.spec_types, case)) + push!(cases, Pair{Any,Any}(spec_types, case)) end end diff --git a/base/compiler/ssair/legacy.jl b/base/compiler/ssair/legacy.jl index 49d9aef973e29..e9fddd1d12a02 100644 --- a/base/compiler/ssair/legacy.jl +++ b/base/compiler/ssair/legacy.jl @@ -47,7 +47,7 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int) for metanode in ir.meta push!(ci.code, metanode) push!(ci.codelocs, 1) - push!(ci.ssavaluetypes, Any) + push!(ci.ssavaluetypes::Vector{Any}, Any) push!(ci.ssaflags, 0x00) end # Translate BB Edges to statement edges diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index ce20a61a4e983..7c8964d371122 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1064,7 +1064,7 @@ function type_lift_pass!(ir::IRCode) if haskey(processed, id) val = processed[id] else - push!(worklist, (id, up_id, new_phi, i)) + push!(worklist, (id, up_id, new_phi::SSAValue, i)) continue end else diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index 21c0bf00ec755..91543835c8c06 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -871,7 +871,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg changed = false for new_idx in type_refine_phi node = new_nodes.stmts[new_idx] - new_typ = recompute_type(node[:inst], ci, ir, ir.sptypes, slottypes) + new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes) if !(node[:type] ⊑ new_typ) || !(new_typ ⊑ node[:type]) node[:type] = new_typ changed = true diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index a6ffee299c4f5..0c54e9359fa1a 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -9,7 +9,7 @@ to re-consult the method table. This info is illegal on any statement that is not a call to a generic function. """ struct MethodMatchInfo - results::Union{Missing, MethodLookupResult} + results::MethodLookupResult end """ diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 511af138883f4..ca03710bbbd47 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -1627,7 +1627,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp if length(argtypes) - 1 == tf[2] argtypes = argtypes[1:end-1] else - vatype = argtypes[end] + vatype = argtypes[end]::Core.TypeofVararg argtypes = argtypes[1:end-1] while length(argtypes) < tf[1] push!(argtypes, unwrapva(vatype)) @@ -1733,7 +1733,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s aft = argtypes[2] if isa(aft, Const) || (isType(aft) && !has_free_typevars(aft)) || (isconcretetype(aft) && !(aft <: Builtin)) - af_argtype = isa(tt, Const) ? tt.val : tt.parameters[1] + af_argtype = isa(tt, Const) ? tt.val : (tt::DataType).parameters[1] if isa(af_argtype, DataType) && af_argtype <: Tuple argtypes_vec = Any[aft, af_argtype.parameters...] if contains_is(argtypes_vec, Union{}) diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 25a07fbb5ee7d..ef6e5a161864a 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -1,6 +1,6 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license -# build (and start inferring) the inference frame for the linfo +# build (and start inferring) the inference frame for the top-level MethodInstance function typeinf(interp::AbstractInterpreter, result::InferenceResult, cached::Bool) frame = InferenceState(result, cached, interp) frame === nothing && return false @@ -243,7 +243,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState) # collect results for the new expanded frame results = Tuple{InferenceResult, Vector{Any}, Bool}[ ( frames[i].result, - frames[i].stmt_edges[1], + frames[i].stmt_edges[1]::Vector{Any}, frames[i].cached ) for i in 1:length(frames) ] empty!(frames) @@ -341,7 +341,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta if cache_the_tree if may_compress(interp) nslots = length(ci.slotflags) - resize!(ci.slottypes, nslots) + resize!(ci.slottypes::Vector{Any}, nslots) resize!(ci.slotnames, nslots) return ccall(:jl_compress_ir, Any, (Any, Any), def, ci) else @@ -386,17 +386,18 @@ function cache_result!(interp::AbstractInterpreter, result::InferenceResult) end # check if the existing linfo metadata is also sufficient to describe the current inference result # to decide if it is worth caching this - already_inferred = already_inferred_quick_test(interp, result.linfo) - if !already_inferred && haskey(WorldView(code_cache(interp), valid_worlds), result.linfo) + linfo = result.linfo + already_inferred = already_inferred_quick_test(interp, linfo) + if !already_inferred && haskey(WorldView(code_cache(interp), valid_worlds), linfo) already_inferred = true end # TODO: also don't store inferred code if we've previously decided to interpret this function if !already_inferred - inferred_result = transform_result_for_cache(interp, result.linfo, valid_worlds, result.src) - code_cache(interp)[result.linfo] = CodeInstance(result, inferred_result, valid_worlds) + inferred_result = transform_result_for_cache(interp, linfo, valid_worlds, result.src) + code_cache(interp)[linfo] = CodeInstance(result, inferred_result, valid_worlds) end - unlock_mi_inference(interp, result.linfo) + unlock_mi_inference(interp, linfo) nothing end @@ -437,7 +438,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter) empty!(edges) end if me.src.edges !== nothing - append!(s_edges, me.src.edges) + append!(s_edges, me.src.edges::Vector) me.src.edges = nothing end # inspect whether our inference had a limited result accuracy, @@ -446,7 +447,7 @@ function finish(me::InferenceState, interp::AbstractInterpreter) limited_ret = me.bestguess isa LimitedAccuracy limited_src = false if !limited_ret - gt = me.src.ssavaluetypes + gt = me.src.ssavaluetypes::Vector{Any} for j = 1:length(gt) gt[j] = gtj = cycle_fix_limited(gt[j], me) if gtj isa LimitedAccuracy && me.parent !== nothing @@ -510,8 +511,9 @@ end # widen all Const elements in type annotations function widen_all_consts!(src::CodeInfo) - for i = 1:length(src.ssavaluetypes) - src.ssavaluetypes[i] = widenconst(src.ssavaluetypes[i]) + ssavaluetypes = src.ssavaluetypes::Vector{Any} + for i = 1:length(ssavaluetypes) + ssavaluetypes[i] = widenconst(ssavaluetypes[i]) end for i = 1:length(src.code) @@ -576,6 +578,7 @@ function record_slot_assign!(sv::InferenceState) states = sv.stmt_types body = sv.src.code::Vector{Any} slottypes = sv.slottypes::Vector{Any} + ssavaluetypes = sv.src.ssavaluetypes::Vector{Any} for i = 1:length(body) expr = body[i] st_i = states[i] @@ -584,7 +587,7 @@ function record_slot_assign!(sv::InferenceState) lhs = expr.args[1] rhs = expr.args[2] if isa(lhs, SlotNumber) - vt = widenconst(sv.src.ssavaluetypes[i]) + vt = widenconst(ssavaluetypes[i]) if vt !== Bottom id = slot_id(lhs) otherTy = slottypes[id] @@ -607,12 +610,11 @@ function type_annotate!(sv::InferenceState, run_optimizer::Bool) # (otherwise, we'll perhaps run the optimization passes later, outside of inference) # remove all unused ssa values - gt = sv.src.ssavaluetypes - for j = 1:length(gt) - if gt[j] === NOT_FOUND - gt[j] = Union{} - end - gt[j] = widenconditional(gt[j]) + src = sv.src + ssavaluetypes = src.ssavaluetypes::Vector{Any} + for j = 1:length(ssavaluetypes) + t = ssavaluetypes[j] + ssavaluetypes[j] = t === NOT_FOUND ? Union{} : widenconditional(t) end # compute the required type for each slot @@ -625,7 +627,6 @@ function type_annotate!(sv::InferenceState, run_optimizer::Bool) # annotate variables load types # remove dead code optimization # and compute which variables may be used undef - src = sv.src states = sv.stmt_types nargs = sv.nargs nslots = length(states[1]::VarTable) @@ -668,7 +669,7 @@ function type_annotate!(sv::InferenceState, run_optimizer::Bool) elseif run_optimizer deleteat!(body, i) deleteat!(states, i) - deleteat!(src.ssavaluetypes, i) + deleteat!(ssavaluetypes, i) deleteat!(src.codelocs, i) deleteat!(sv.stmt_info, i) nexpr -= 1 diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 2d65211c273b2..2f026d41efb35 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -309,7 +309,7 @@ function smerge(sa::Union{NotFound,VarState}, sb::Union{NotFound,VarState}) end @inline tchanged(@nospecialize(n), @nospecialize(o)) = o === NOT_FOUND || (n !== NOT_FOUND && !(n ⊑ o)) -@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n, o))) +@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n::VarState, o::VarState))) widenconditional(@nospecialize typ) = typ function widenconditional(typ::AnyConditional) @@ -406,7 +406,7 @@ function stupdate1!(state::VarTable, change::StateUpdate) if isa(oldtypetyp, Conditional) && slot_id(oldtypetyp.var) == changeid oldtypetyp = widenconditional(oldtypetyp) if oldtype.typ isa LimitedAccuracy - oldtypetyp = LimitedAccuracy(oldtypetyp, oldtype.typ.causes) + oldtypetyp = LimitedAccuracy(oldtypetyp, (oldtype.typ::LimitedAccuracy).causes) end state[i] = VarState(oldtypetyp, oldtype.undef) end diff --git a/base/compiler/validation.jl b/base/compiler/validation.jl index 6e0f81114744b..02fb1b02c6ef0 100644 --- a/base/compiler/validation.jl +++ b/base/compiler/validation.jl @@ -1,7 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license # Expr head => argument count bounds -const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange}( +const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}( :call => 1:typemax(Int), :invoke => 2:typemax(Int), :invoke_modify => 3:typemax(Int), @@ -182,10 +182,11 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_ !is_top_level && nslotnames == 0 && push!(errors, InvalidCodeError(EMPTY_SLOTNAMES)) nslotnames < nslotflags && push!(errors, InvalidCodeError(SLOTFLAGS_MISMATCH, (nslotnames, nslotflags))) if c.inferred - nssavaluetypes = length(c.ssavaluetypes) + nssavaluetypes = length(c.ssavaluetypes::Vector{Any}) nssavaluetypes < nssavals && push!(errors, InvalidCodeError(SSAVALUETYPES_MISMATCH, (nssavals, nssavaluetypes))) else - c.ssavaluetypes != nssavals && push!(errors, InvalidCodeError(SSAVALUETYPES_MISMATCH_UNINFERRED, (nssavals, c.ssavaluetypes))) + ssavaluetypes = c.ssavaluetypes::Int + ssavaluetypes != nssavals && push!(errors, InvalidCodeError(SSAVALUETYPES_MISMATCH_UNINFERRED, (nssavals, ssavaluetypes))) end return errors end @@ -207,7 +208,7 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, mi::Core.MethodInsta else m = mi.def::Method mnargs = m.nargs - n_sig_params = length(Core.Compiler.unwrap_unionall(m.sig).parameters) + n_sig_params = length((unwrap_unionall(m.sig)::DataType).parameters) if (m.isva ? (n_sig_params < (mnargs - 1)) : (n_sig_params != mnargs)) push!(errors, InvalidCodeError(SIGNATURE_NARGS_MISMATCH, (m.isva, n_sig_params, mnargs))) end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index a2432169b09ad..d4d0f6700c179 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2238,12 +2238,10 @@ code28279 = code_lowered(f28279, (Bool,))[1].code oldcode28279 = deepcopy(code28279) ssachangemap = fill(0, length(code28279)) labelchangemap = fill(0, length(code28279)) -worklist = Int[] let i for i in 1:length(code28279) stmt = code28279[i] if isa(stmt, GotoIfNot) - push!(worklist, i) ssachangemap[i] = 1 if i < length(code28279) labelchangemap[i + 1] = 1 @@ -3499,3 +3497,10 @@ end end return x end) === Union{Int, Float64, Char} + +# issue #42097 +struct Foo42097{F} end +Foo42097(f::F, args) where {F} = Foo42097{F}() +Foo42097(A) = Foo42097(Base.inferencebarrier(+), Base.inferencebarrier(1)...) +foo42097() = Foo42097([1]...) +@test foo42097() isa Foo42097{typeof(+)}