diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 6991e2d38437b..a49344eac09fa 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -154,6 +154,8 @@ function extrema(x::Array) return vmin, vmax end +include("compiler/plugin.jl") + include("compiler/bootstrap.jl") ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel) diff --git a/base/compiler/plugin.jl b/base/compiler/plugin.jl new file mode 100644 index 0000000000000..5dcab732c470f --- /dev/null +++ b/base/compiler/plugin.jl @@ -0,0 +1,47 @@ +# this simple code assumes a plugin AbstractInterpreter manages its own code cache +# in a way that it is totally separated from the native code cache + +# TODO more composable way +global PLUGIN_INTERPRETER::AbstractInterpreter = NativeInterpreter() +isplugin(interp::AbstractInterpreter) = let + global PLUGIN_INTERPRETER + interp === PLUGIN_INTERPRETER +end + +function new_opaque_closure(src::CodeInfo, nargs::Int, @nospecialize(rt), + @nospecialize(env...)) + @assert src.inferred "unoptimized IR unsupported" + argt = argtypes_to_type((src.slottypes::Vector{Any})[2:nargs+1]) + + # M = src.parent.def + # sig = Base.tuple_type_tail(src.parent.specTypes) + + return ccall(:jl_new_opaque_closure_from_code_info, Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), + argt, Union{}, rt, @__MODULE__, src, 0, nothing, nargs, false, env) +end + +function execute_with_plugin(f, args...; + world::UInt = get_world_counter(), + interp::AbstractInterpreter = PLUGIN_INTERPRETER) + global PLUGIN_INTERPRETER = interp + tt = Tuple{Core.Typeof(f), Any[Core.Typeof(args[i]) for i = 1:length(args)]...} + matches = _methods_by_ftype(tt, -1, world) + matches !== nothing || throw(MethodError(f, args, world)) + length(matches) ≠ 1 && throw(MethodError(f, args, world)) + m = first(matches)::MethodMatch + src, rt = typeinf_code(interp, + m.method, m.spec_types, m.sparams, true) + if src === nothing + # builtin, or bad generated function + return f(args...) # XXX + end + # TODO support varargs + oc = new_opaque_closure(src, length(args), rt) + return oc(args...) +end + +module Plugin +import ..execute_with_plugin +export execute_with_plugin +end diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 94387b643b0b6..35f4f090208bd 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -277,10 +277,36 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState) cache_result!(interp, caller) end finish!(interp, caller) + if isplugin(interp) + fixup_plugin_entry!(interp, caller) + end end return true end +function fixup_plugin_entry!(interp::AbstractInterpreter, caller::InferenceResult) + src = caller.src + if src isa CodeInfo + fixup_plugin_entry!(src, collect(Any, caller.linfo.sparam_vals)) + elseif src isa OptimizationState + fixup_plugin_entry!(src.src, src.sptypes) + end +end + +function fixup_plugin_entry!(src::CodeInfo, sptypes::Vector{Any}) + @assert src.inferred + for i = 1:length(src.code) + stmt = src.code[i] + if isexpr(stmt, :call) + ft = argextype(stmt.args[1], src, sptypes) + f = singleton_type(ft) + if f === nothing || !(f isa Builtin) + pushfirst!(stmt.args, GlobalRef(Core.Compiler, :execute_with_plugin)) + end + end + end +end + function CodeInstance( result::InferenceResult, @nospecialize(inferred_result), valid_worlds::WorldRange) local const_flags::Int32 diff --git a/base/errorshow.jl b/base/errorshow.jl index e56a095d832fd..2f6fa6604b775 100644 --- a/base/errorshow.jl +++ b/base/errorshow.jl @@ -409,7 +409,11 @@ function show_method_candidates(io::IO, ex::MethodError, @nospecialize kwargs=() buf = IOBuffer() iob0 = iob = IOContext(buf, io) tv = Any[] - sig0 = method.sig + if func isa Core.OpaqueClosure + sig0 = signature_type(func, typeof(func).parameters[1]) + else + sig0 = method.sig + end while isa(sig0, UnionAll) push!(tv, sig0.var) iob = IOContext(iob, :unionall_env => sig0.var) diff --git a/base/methodshow.jl b/base/methodshow.jl index ba9911179fd19..2688434423f30 100644 --- a/base/methodshow.jl +++ b/base/methodshow.jl @@ -79,6 +79,9 @@ end # NOTE: second argument is deprecated and is no longer used function kwarg_decl(m::Method, kwtype = nothing) + if m.sig === Tuple # OpaqueClosure + return Symbol[] + end mt = get_methodtable(m) if isdefined(mt, :kwsorter) kwtype = typeof(mt.kwsorter) diff --git a/src/method.c b/src/method.c index 7325670bd76a4..6dd9546a8dde0 100644 --- a/src/method.c +++ b/src/method.c @@ -19,7 +19,7 @@ extern jl_value_t *jl_builtin_getfield; extern jl_value_t *jl_builtin_tuple; jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name, - jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva); + int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva); static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at) { @@ -51,11 +51,14 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve return jl_module_globalref(module, (jl_sym_t*)expr); } else if (jl_is_returnnode(expr)) { - jl_value_t *val = resolve_globals(jl_returnnode_value(expr), module, sparam_vals, binding_effects, eager_resolve); - if (val != jl_returnnode_value(expr)) { - JL_GC_PUSH1(&val); - expr = jl_new_struct(jl_returnnode_type, val); - JL_GC_POP(); + jl_value_t *retval = jl_returnnode_value(expr); + if (retval) { + jl_value_t *val = resolve_globals(retval, module, sparam_vals, binding_effects, eager_resolve); + if (val != retval) { + JL_GC_PUSH1(&val); + expr = jl_new_struct(jl_returnnode_type, val); + JL_GC_POP(); + } } return expr; } @@ -102,7 +105,7 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve if (!jl_is_code_info(ci)) { jl_error("opaque_closure_method: lambda should be a CodeInfo"); } - jl_method_t *m = jl_make_opaque_closure_method(module, name, nargs, functionloc, (jl_code_info_t*)ci, isva); + jl_method_t *m = jl_make_opaque_closure_method(module, name, jl_unbox_long(nargs), functionloc, (jl_code_info_t*)ci, isva); return (jl_value_t*)m; } if (e->head == jl_cfunction_sym) { @@ -782,7 +785,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module) // method definition ---------------------------------------------------------- jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name, - jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva) + int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva) { jl_method_t *m = jl_new_method_uninit(module); JL_GC_PUSH1(&m); @@ -796,7 +799,7 @@ jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name assert(jl_is_symbol(name)); m->name = (jl_sym_t*)name; } - m->nargs = jl_unbox_long(nargs) + 1; + m->nargs = nargs + 1; assert(jl_is_linenode(functionloc)); jl_value_t *file = jl_linenode_file(functionloc); m->file = jl_is_symbol(file) ? (jl_sym_t*)file : jl_empty_sym; diff --git a/src/opaque_closure.c b/src/opaque_closure.c index 3fceadf67a583..d34989181b7ad 100644 --- a/src/opaque_closure.c +++ b/src/opaque_closure.c @@ -22,8 +22,23 @@ JL_DLLEXPORT int jl_is_valid_oc_argtype(jl_tupletype_t *argt, jl_method_t *sourc return 1; } -jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub, - jl_value_t *source_, jl_value_t **env, size_t nenv) +static jl_value_t *prepend_type(jl_value_t *t0, jl_tupletype_t *t) +{ + jl_svec_t *sig_args = NULL; + JL_GC_PUSH1(&sig_args); + size_t nsig = 1 + jl_svec_len(t->parameters); + sig_args = jl_alloc_svec_uninit(nsig); + jl_svecset(sig_args, 0, t0); + for (size_t i = 0; i < nsig-1; ++i) { + jl_svecset(sig_args, 1+i, jl_tparam(t, i)); + } + jl_value_t *sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig); + JL_GC_POP(); + return sigtype; +} + +static jl_opaque_closure_t *new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub, + jl_value_t *source_, jl_value_t *captures) { if (!jl_is_tuple_type((jl_value_t*)argt)) { jl_error("OpaqueClosure argument tuple must be a tuple type"); @@ -40,26 +55,19 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_ } if (jl_nparams(argt) + 1 - jl_is_va_tuple(argt) < source->nargs - source->isva) jl_error("Argument type tuple has too few required arguments for method"); - jl_task_t *ct = jl_current_task; + jl_value_t *sigtype = NULL; + JL_GC_PUSH1(&sigtype); + sigtype = prepend_type(jl_typeof(captures), argt); + jl_value_t *oc_type JL_ALWAYS_LEAFTYPE; oc_type = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt, rt_ub); JL_GC_PROMISE_ROOTED(oc_type); - jl_value_t *captures = NULL, *sigtype = NULL; - jl_svec_t *sig_args = NULL; - JL_GC_PUSH3(&captures, &sigtype, &sig_args); - captures = jl_f_tuple(NULL, env, nenv); - size_t nsig = 1 + jl_svec_len(argt->parameters); - sig_args = jl_alloc_svec_uninit(nsig); - jl_svecset(sig_args, 0, jl_typeof(captures)); - for (size_t i = 0; i < nsig-1; ++i) { - jl_svecset(sig_args, 1+i, jl_tparam(argt, i)); - } - sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig); jl_method_instance_t *mi = jl_specializations_get_linfo(source, sigtype, jl_emptysvec); size_t world = jl_atomic_load_acquire(&jl_world_counter); jl_code_instance_t *ci = jl_compile_method_internal(mi, world); + jl_task_t *ct = jl_current_task; jl_opaque_closure_t *oc = (jl_opaque_closure_t*)jl_gc_alloc(ct->ptls, sizeof(jl_opaque_closure_t), oc_type); JL_GC_POP(); oc->source = source; @@ -82,6 +90,52 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_ return oc; } +jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub, + jl_value_t *source_, jl_value_t **env, size_t nenv) +{ + jl_value_t *captures = jl_f_tuple(NULL, env, nenv); + JL_GC_PUSH1(&captures); + jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, source_, captures); + JL_GC_POP(); + return oc; +} + +jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name, + int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva); + +JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst( + jl_method_instance_t *mi, jl_value_t *rettype, + jl_value_t *inferred_const, jl_value_t *inferred, + int32_t const_flags, size_t min_world, size_t max_world, + uint32_t ipo_effects, uint32_t effects, jl_value_t *argescapes, + uint8_t relocatability); + +JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT, + jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED); + +JL_DLLEXPORT jl_opaque_closure_t *jl_new_opaque_closure_from_code_info(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub, + jl_module_t *mod, jl_code_info_t *ci, int lineno, jl_value_t *file, int nargs, int isva, jl_value_t *env) +{ + if (!ci->inferred) + jl_error("CodeInfo must already be inferred"); + jl_value_t *root = NULL, *sigtype = NULL; + jl_code_instance_t *inst = NULL; + JL_GC_PUSH3(&root, &sigtype, &inst); + root = jl_box_long(lineno); + root = jl_new_struct(jl_linenumbernode_type, root, file); + root = (jl_value_t*)jl_make_opaque_closure_method(mod, jl_nothing, nargs, root, ci, isva); + + sigtype = prepend_type(jl_typeof(env), argt); + jl_method_instance_t *mi = jl_specializations_get_linfo((jl_method_t*)root, sigtype, jl_emptysvec); + inst = jl_new_codeinst(mi, rt_ub, NULL, (jl_value_t*)ci, + 0, ((jl_method_t*)root)->primary_world, -1, 0, 0, jl_nothing, 0); + jl_mi_cache_insert(mi, inst); + + jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, root, env); + JL_GC_POP(); + return oc; +} + JL_CALLABLE(jl_new_opaque_closure_jlcall) { if (nargs < 4) diff --git a/test/choosetests.jl b/test/choosetests.jl index 099dfa18a71c5..b5bf8f9b54b49 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -141,8 +141,8 @@ function choosetests(choices = []) # do subarray before sparse but after linalg filtertests!(tests, "subarray") filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation", - "compiler/ssair", "compiler/irpasses", "compiler/codegen", - "compiler/inline", "compiler/contextual", "compiler/AbstractInterpreter", + "compiler/ssair", "compiler/inline", "compiler/irpasses", "compiler/codegen", + "compiler/contextual", "compiler/AbstractInterpreter", "compiler/plugin", "compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"]) filtertests!(tests, "compiler/EscapeAnalysis", [ "compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"]) diff --git a/test/compiler/plugin.jl b/test/compiler/plugin.jl new file mode 100644 index 0000000000000..1f699f3c35dd3 --- /dev/null +++ b/test/compiler/plugin.jl @@ -0,0 +1,132 @@ +# %% +const CC = Core.Compiler +using Test, .CC.Plugin +import Core: MethodInstance, CodeInstance +import .CC: WorldRange, WorldView + +# %% +# SinCosRewriter +# -------------- + +struct SinCosRewriterCache + dict::IdDict{MethodInstance,CodeInstance} +end +struct SinCosRewriter <: CC.AbstractInterpreter + interp::CC.NativeInterpreter + cache::SinCosRewriterCache +end +let global_cache = SinCosRewriterCache(IdDict{MethodInstance,CodeInstance}()) + global function SinCosRewriter( + world = Base.get_world_counter(); + interp = CC.NativeInterpreter(world), + cache = global_cache) + return SinCosRewriter(interp, cache) + end +end +CC.InferenceParams(interp::SinCosRewriter) = CC.InferenceParams(interp.interp) +CC.OptimizationParams(interp::SinCosRewriter) = CC.OptimizationParams(interp.interp) +CC.get_world_counter(interp::SinCosRewriter) = CC.get_world_counter(interp.interp) +CC.get_inference_cache(interp::SinCosRewriter) = CC.get_inference_cache(interp.interp) +CC.code_cache(interp::SinCosRewriter) = WorldView(interp.cache, WorldRange(CC.get_world_counter(interp))) +CC.get(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance, default) = get(wvc.cache.dict, mi, default) +CC.getindex(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance) = getindex(wvc.cache.dict, mi) +CC.haskey(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance) = haskey(wvc.cache.dict, mi) +function CC.setindex!(wvc::WorldView{<:SinCosRewriterCache}, ci::CodeInstance, mi::MethodInstance) + # ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci) + setindex!(wvc.cache.dict, ci, mi) +end + +function CC.abstract_call_gf_by_type( + interp::SinCosRewriter, @nospecialize(f), + arginfo::CC.ArgInfo, @nospecialize(atype), + sv::CC.InferenceState, max_methods::Int) + if f === sin + f = cos + atype′ = CC.unwrap_unionall(atype)::DataType + atype′ = Tuple{typeof(cos), atype′.parameters[2:end]...} + atype = CC.rewrap_unionall(atype′, atype) + end + return Base.@invoke CC.abstract_call_gf_by_type( + interp::CC.AbstractInterpreter, f, + arginfo::CC.ArgInfo, atype, + sv::CC.InferenceState, max_methods::Int) +end + +global gv::Any = 42 +macro test_native_execution!() + return :(let + Base.Experimental.@force_compile + global gv + v1 = 42 + v2 = gv::Int + v3 = gv + @test isone(sin(v1)^2 + cos(v1)^2) + @test isone(@noinline sin(v2)^2 + cos(v2)^2) + @test isone(sin(v3)^2 + cos(v3)^2) + end) +end + +# simple +@test execute_with_plugin(; interp=SinCosRewriter()) do + sin(42) +end == cos(42) +@test execute_with_plugin(42; interp=SinCosRewriter()) do a + sin(a) +end == cos(42) +@test execute_with_plugin(sin, 42; interp=SinCosRewriter()) do f, a + f(a) +end == cos(42) +@test_native_execution! + +# global +nested(a) = sin(a) # => cos(a) +@test execute_with_plugin(42; interp=SinCosRewriter()) do a + nested(a) +end == cos(42) +@test_native_execution! + +# dynamic +global gv::Any = 42 +@test execute_with_plugin() do + sin(gv::Any) # dynamic dispatch +end == cos(42) +@test execute_with_plugin(sin) do f + f(gv) +end == cos(42) +global gf::Any = sin +@test execute_with_plugin() do + (gf::Any)(gv::Any) +end == cos(42) +@test_native_execution! + +# static dispatch +@noinline noninlined_sin() = sin(gv::Int) +@test_broken execute_with_plugin(; interp=SinCosRewriter()) do + noninlined_sin() +end == cos(42) +@test_native_execution! + +# invoke +@test_broken execute_with_plugin(42.0; interp=SinCosRewriter()) do a + Base.@invoke sin(a::Float64) +end == cos(42.0) +@test_native_execution! + +# end to end +function kernel(fs) + r = 0 + for i = 1:length(fs) + r += sum(fs[i](i)) + end + return r +end +let + fs = (sin, sin, cos) + gs = (cos, cos, cos) + @test execute_with_plugin(kernel, fs) == kernel(gs) + + fs = Any[sin, sin, cos] + gs = Any[cos, cos, cos] + @test execute_with_plugin(kernel, fs) == kernel(gs) +end +@test_native_execution! diff --git a/test/opaque_closure.jl b/test/opaque_closure.jl index 3525e8bcb03eb..d2190424481e7 100644 --- a/test/opaque_closure.jl +++ b/test/opaque_closure.jl @@ -239,3 +239,49 @@ end let oc = @opaque a->sin(a) @test length(code_typed(oc, (Int,))) == 1 end + +# constructing an opaque closure from IRCode +using Core.Compiler: IRCode +using Core: CodeInfo + +function OC(ir::IRCode, nargs::Int, isva::Bool, env...) + if (isva && nargs > length(ir.argtypes)) || (!isva && nargs != length(ir.argtypes)-1) + throw(ArgumentError("invalid argument count")) + end + src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ()) + src.slotflags = UInt8[] + src.slotnames = fill(:none, nargs+1) + Core.Compiler.replace_code_newstyle!(src, ir, nargs+1) + Core.Compiler.widen_all_consts!(src) + src.inferred = true + # NOTE: we need ir.argtypes[1] == typeof(env) + + ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), + Tuple{ir.argtypes[2:end]...}, Union{}, Any, @__MODULE__, src, 0, nothing, nargs, isva, env) +end + +function OC(src::CodeInfo, env...) + M = src.parent.def + sig = Base.tuple_type_tail(src.parent.specTypes) + + ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any), + sig, Union{}, Any, @__MODULE__, src, 0, nothing, M.nargs - 1, M.isva, env) +end + +let ci = code_typed(+, (Int, Int))[1][1] + ir = Core.Compiler.inflate_ir(ci) + @test OC(ir, 2, false)(40, 2) == 42 + @test OC(ci)(40, 2) == 42 +end + +let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1] + ir = Core.Compiler.inflate_ir(ci) + @test OC(ir, 2, true)(40, 2) === (40, (2,)) + @test OC(ci)(40, 2) === (40, (2,)) +end + +let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1] + ir = Core.Compiler.inflate_ir(ci) + @test_throws MethodError OC(ir, 2, true)(1, 2, 3) + @test_throws MethodError OC(ci)(1, 2, 3) +end