Skip to content

wip: scratch compiler plugin #44950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ function extrema(x::Array)
return vmin, vmax
end

include("compiler/plugin.jl")

include("compiler/bootstrap.jl")
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)

Expand Down
47 changes: 47 additions & 0 deletions base/compiler/plugin.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# this simple code assumes a plugin AbstractInterpreter manages its own code cache
# in a way that it is totally separated from the native code cache

# TODO more composable way
global PLUGIN_INTERPRETER::AbstractInterpreter = NativeInterpreter()
isplugin(interp::AbstractInterpreter) = let
Copy link
Contributor

@MasonProtter MasonProtter May 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we just write

isplugin(interp::AbstractInterpreter) = false

and then when someone writes a plugin, they add an overload

CC.isplugin(interp::MyInterpreter) = true

I tested this locally and it appears to work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see the problem, this doesn't work with your way of dealing with dynamic dispatch. Is that the background for #46220 then?

global PLUGIN_INTERPRETER
interp === PLUGIN_INTERPRETER
end

function new_opaque_closure(src::CodeInfo, nargs::Int, @nospecialize(rt),
@nospecialize(env...))
@assert src.inferred "unoptimized IR unsupported"
argt = argtypes_to_type((src.slottypes::Vector{Any})[2:nargs+1])

# M = src.parent.def
# sig = Base.tuple_type_tail(src.parent.specTypes)

return ccall(:jl_new_opaque_closure_from_code_info, Any,
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
argt, Union{}, rt, @__MODULE__, src, 0, nothing, nargs, false, env)
end

function execute_with_plugin(f, args...;
world::UInt = get_world_counter(),
interp::AbstractInterpreter = PLUGIN_INTERPRETER)
global PLUGIN_INTERPRETER = interp
tt = Tuple{Core.Typeof(f), Any[Core.Typeof(args[i]) for i = 1:length(args)]...}
matches = _methods_by_ftype(tt, -1, world)
matches !== nothing || throw(MethodError(f, args, world))
length(matches) ≠ 1 && throw(MethodError(f, args, world))
m = first(matches)::MethodMatch
src, rt = typeinf_code(interp,
m.method, m.spec_types, m.sparams, true)
if src === nothing
# builtin, or bad generated function
return f(args...) # XXX
end
# TODO support varargs
oc = new_opaque_closure(src, length(args), rt)
return oc(args...)
end

module Plugin
import ..execute_with_plugin
export execute_with_plugin
end
26 changes: 26 additions & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,36 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
cache_result!(interp, caller)
end
finish!(interp, caller)
if isplugin(interp)
fixup_plugin_entry!(interp, caller)
end
end
return true
end

function fixup_plugin_entry!(interp::AbstractInterpreter, caller::InferenceResult)
src = caller.src
if src isa CodeInfo
fixup_plugin_entry!(src, collect(Any, caller.linfo.sparam_vals))
elseif src isa OptimizationState
fixup_plugin_entry!(src.src, src.sptypes)
end
end

function fixup_plugin_entry!(src::CodeInfo, sptypes::Vector{Any})
@assert src.inferred
for i = 1:length(src.code)
stmt = src.code[i]
if isexpr(stmt, :call)
ft = argextype(stmt.args[1], src, sptypes)
f = singleton_type(ft)
if f === nothing || !(f isa Builtin)
pushfirst!(stmt.args, GlobalRef(Core.Compiler, :execute_with_plugin))
end
end
end
end

function CodeInstance(
result::InferenceResult, @nospecialize(inferred_result), valid_worlds::WorldRange)
local const_flags::Int32
Expand Down
6 changes: 5 additions & 1 deletion base/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ function show_method_candidates(io::IO, ex::MethodError, @nospecialize kwargs=()
buf = IOBuffer()
iob0 = iob = IOContext(buf, io)
tv = Any[]
sig0 = method.sig
if func isa Core.OpaqueClosure
sig0 = signature_type(func, typeof(func).parameters[1])
else
sig0 = method.sig
end
while isa(sig0, UnionAll)
push!(tv, sig0.var)
iob = IOContext(iob, :unionall_env => sig0.var)
Expand Down
3 changes: 3 additions & 0 deletions base/methodshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ end

# NOTE: second argument is deprecated and is no longer used
function kwarg_decl(m::Method, kwtype = nothing)
if m.sig === Tuple # OpaqueClosure
return Symbol[]
end
mt = get_methodtable(m)
if isdefined(mt, :kwsorter)
kwtype = typeof(mt.kwsorter)
Expand Down
21 changes: 12 additions & 9 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ extern jl_value_t *jl_builtin_getfield;
extern jl_value_t *jl_builtin_tuple;

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);

static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
{
Expand Down Expand Up @@ -51,11 +51,14 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
return jl_module_globalref(module, (jl_sym_t*)expr);
}
else if (jl_is_returnnode(expr)) {
jl_value_t *val = resolve_globals(jl_returnnode_value(expr), module, sparam_vals, binding_effects, eager_resolve);
if (val != jl_returnnode_value(expr)) {
JL_GC_PUSH1(&val);
expr = jl_new_struct(jl_returnnode_type, val);
JL_GC_POP();
jl_value_t *retval = jl_returnnode_value(expr);
if (retval) {
jl_value_t *val = resolve_globals(retval, module, sparam_vals, binding_effects, eager_resolve);
if (val != retval) {
JL_GC_PUSH1(&val);
expr = jl_new_struct(jl_returnnode_type, val);
JL_GC_POP();
}
}
return expr;
}
Expand Down Expand Up @@ -102,7 +105,7 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
if (!jl_is_code_info(ci)) {
jl_error("opaque_closure_method: lambda should be a CodeInfo");
}
jl_method_t *m = jl_make_opaque_closure_method(module, name, nargs, functionloc, (jl_code_info_t*)ci, isva);
jl_method_t *m = jl_make_opaque_closure_method(module, name, jl_unbox_long(nargs), functionloc, (jl_code_info_t*)ci, isva);
return (jl_value_t*)m;
}
if (e->head == jl_cfunction_sym) {
Expand Down Expand Up @@ -782,7 +785,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
// method definition ----------------------------------------------------------

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
{
jl_method_t *m = jl_new_method_uninit(module);
JL_GC_PUSH1(&m);
Expand All @@ -796,7 +799,7 @@ jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name
assert(jl_is_symbol(name));
m->name = (jl_sym_t*)name;
}
m->nargs = jl_unbox_long(nargs) + 1;
m->nargs = nargs + 1;
assert(jl_is_linenode(functionloc));
jl_value_t *file = jl_linenode_file(functionloc);
m->file = jl_is_symbol(file) ? (jl_sym_t*)file : jl_empty_sym;
Expand Down
82 changes: 68 additions & 14 deletions src/opaque_closure.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,23 @@ JL_DLLEXPORT int jl_is_valid_oc_argtype(jl_tupletype_t *argt, jl_method_t *sourc
return 1;
}

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t **env, size_t nenv)
static jl_value_t *prepend_type(jl_value_t *t0, jl_tupletype_t *t)
{
jl_svec_t *sig_args = NULL;
JL_GC_PUSH1(&sig_args);
size_t nsig = 1 + jl_svec_len(t->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, t0);
for (size_t i = 0; i < nsig-1; ++i) {
jl_svecset(sig_args, 1+i, jl_tparam(t, i));
}
jl_value_t *sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
JL_GC_POP();
return sigtype;
}

static jl_opaque_closure_t *new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t *captures)
{
if (!jl_is_tuple_type((jl_value_t*)argt)) {
jl_error("OpaqueClosure argument tuple must be a tuple type");
Expand All @@ -40,26 +55,19 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
}
if (jl_nparams(argt) + 1 - jl_is_va_tuple(argt) < source->nargs - source->isva)
jl_error("Argument type tuple has too few required arguments for method");
jl_task_t *ct = jl_current_task;
jl_value_t *sigtype = NULL;
JL_GC_PUSH1(&sigtype);
sigtype = prepend_type(jl_typeof(captures), argt);

jl_value_t *oc_type JL_ALWAYS_LEAFTYPE;
oc_type = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt, rt_ub);
JL_GC_PROMISE_ROOTED(oc_type);
jl_value_t *captures = NULL, *sigtype = NULL;
jl_svec_t *sig_args = NULL;
JL_GC_PUSH3(&captures, &sigtype, &sig_args);
captures = jl_f_tuple(NULL, env, nenv);

size_t nsig = 1 + jl_svec_len(argt->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
jl_svecset(sig_args, 0, jl_typeof(captures));
for (size_t i = 0; i < nsig-1; ++i) {
jl_svecset(sig_args, 1+i, jl_tparam(argt, i));
}
sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
jl_method_instance_t *mi = jl_specializations_get_linfo(source, sigtype, jl_emptysvec);
size_t world = jl_atomic_load_acquire(&jl_world_counter);
jl_code_instance_t *ci = jl_compile_method_internal(mi, world);

jl_task_t *ct = jl_current_task;
jl_opaque_closure_t *oc = (jl_opaque_closure_t*)jl_gc_alloc(ct->ptls, sizeof(jl_opaque_closure_t), oc_type);
JL_GC_POP();
oc->source = source;
Expand All @@ -82,6 +90,52 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
return oc;
}

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source_, jl_value_t **env, size_t nenv)
{
jl_value_t *captures = jl_f_tuple(NULL, env, nenv);
JL_GC_PUSH1(&captures);
jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, source_, captures);
JL_GC_POP();
return oc;
}

jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);

JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
int32_t const_flags, size_t min_world, size_t max_world,
uint32_t ipo_effects, uint32_t effects, jl_value_t *argescapes,
uint8_t relocatability);

JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED);

JL_DLLEXPORT jl_opaque_closure_t *jl_new_opaque_closure_from_code_info(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_module_t *mod, jl_code_info_t *ci, int lineno, jl_value_t *file, int nargs, int isva, jl_value_t *env)
{
if (!ci->inferred)
jl_error("CodeInfo must already be inferred");
jl_value_t *root = NULL, *sigtype = NULL;
jl_code_instance_t *inst = NULL;
JL_GC_PUSH3(&root, &sigtype, &inst);
root = jl_box_long(lineno);
root = jl_new_struct(jl_linenumbernode_type, root, file);
root = (jl_value_t*)jl_make_opaque_closure_method(mod, jl_nothing, nargs, root, ci, isva);

sigtype = prepend_type(jl_typeof(env), argt);
jl_method_instance_t *mi = jl_specializations_get_linfo((jl_method_t*)root, sigtype, jl_emptysvec);
inst = jl_new_codeinst(mi, rt_ub, NULL, (jl_value_t*)ci,
0, ((jl_method_t*)root)->primary_world, -1, 0, 0, jl_nothing, 0);
jl_mi_cache_insert(mi, inst);

jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, root, env);
JL_GC_POP();
return oc;
}

JL_CALLABLE(jl_new_opaque_closure_jlcall)
{
if (nargs < 4)
Expand Down
4 changes: 2 additions & 2 deletions test/choosetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ function choosetests(choices = [])
# do subarray before sparse but after linalg
filtertests!(tests, "subarray")
filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation",
"compiler/ssair", "compiler/irpasses", "compiler/codegen",
"compiler/inline", "compiler/contextual", "compiler/AbstractInterpreter",
"compiler/ssair", "compiler/inline", "compiler/irpasses", "compiler/codegen",
"compiler/contextual", "compiler/AbstractInterpreter", "compiler/plugin",
"compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"])
filtertests!(tests, "compiler/EscapeAnalysis", [
"compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"])
Expand Down
Loading