Skip to content

Commit 62329a9

Browse files
committed
wip: scratch compiler plugin
1 parent 6c22265 commit 62329a9

File tree

5 files changed

+186
-2
lines changed

5 files changed

+186
-2
lines changed

base/compiler/compiler.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ function extrema(x::Array)
154154
return vmin, vmax
155155
end
156156

157+
include("compiler/plugin.jl")
158+
157159
include("compiler/bootstrap.jl")
158160
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
159161

base/compiler/plugin.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# this simple code assumes a plugin AbstractInterpreter manages its own code cache
2+
# in a way that it is totally separated from the native code cache
3+
4+
# TODO more composable way
5+
global PLUGIN_INTERPRETER::AbstractInterpreter = NativeInterpreter()
6+
isplugin(interp::AbstractInterpreter) = let
7+
global PLUGIN_INTERPRETER
8+
interp === PLUGIN_INTERPRETER
9+
end
10+
11+
function new_opaque_closure(src::CodeInfo, nargs::Int, @nospecialize(rt),
12+
@nospecialize(env...))
13+
@assert src.inferred "unoptimized IR unsupported"
14+
argt = argtypes_to_type((src.slottypes::Vector{Any})[2:nargs+1])
15+
16+
# M = src.parent.def
17+
# sig = Base.tuple_type_tail(src.parent.specTypes)
18+
19+
return ccall(:jl_new_opaque_closure_from_code_info, Any,
20+
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
21+
argt, Union{}, rt, @__MODULE__, src, 0, nothing, nargs, false, env)
22+
end
23+
24+
function execute_with_plugin(f, args...;
25+
world::UInt = get_world_counter(),
26+
interp::AbstractInterpreter = PLUGIN_INTERPRETER)
27+
global PLUGIN_INTERPRETER = interp
28+
tt = Tuple{Core.Typeof(f), Any[Core.Typeof(args[i]) for i = 1:length(args)]...}
29+
matches = _methods_by_ftype(tt, -1, world)
30+
matches !== nothing || throw(MethodError(f, args, world))
31+
length(matches) 1 && throw(MethodError(f, args, world))
32+
m = first(matches)::MethodMatch
33+
ci, rt = typeinf_code(interp,
34+
m.method, m.spec_types, m.sparams, true)
35+
if ci === nothing
36+
# builtin, or bad generated function
37+
return f(args...) # XXX
38+
end
39+
# TODO support varargs
40+
oc = new_opaque_closure(ci, length(args), rt)
41+
return oc(args...)
42+
end
43+
44+
module Plugin
45+
import ..execute_with_plugin
46+
export execute_with_plugin
47+
end

base/compiler/typeinfer.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,36 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
277277
cache_result!(interp, caller)
278278
end
279279
finish!(interp, caller)
280+
if isplugin(interp)
281+
fixup_plugin_entry!(interp, caller)
282+
end
280283
end
281284
return true
282285
end
283286

287+
function fixup_plugin_entry!(interp::AbstractInterpreter, caller::InferenceResult)
288+
src = caller.src
289+
if src isa CodeInfo
290+
fixup_plugin_entry!(src, collect(Any, caller.linfo.sparam_vals))
291+
elseif src isa OptimizationState
292+
fixup_plugin_entry!(src.src, src.sptypes)
293+
end
294+
end
295+
296+
function fixup_plugin_entry!(src::CodeInfo, sptypes::Vector{Any})
297+
@assert src.inferred
298+
for i = 1:length(src.code)
299+
stmt = src.code[i]
300+
if isexpr(stmt, :call)
301+
ft = argextype(stmt.args[1], src, sptypes)
302+
f = singleton_type(ft)
303+
if f === nothing || !(f isa Builtin)
304+
pushfirst!(stmt.args, GlobalRef(Core.Compiler, :execute_with_plugin))
305+
end
306+
end
307+
end
308+
end
309+
284310
function CodeInstance(
285311
result::InferenceResult, @nospecialize(inferred_result), valid_worlds::WorldRange)
286312
local const_flags::Int32

test/choosetests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ function choosetests(choices = [])
141141
# do subarray before sparse but after linalg
142142
filtertests!(tests, "subarray")
143143
filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation",
144-
"compiler/ssair", "compiler/irpasses", "compiler/codegen",
145-
"compiler/inline", "compiler/contextual", "compiler/AbstractInterpreter",
144+
"compiler/ssair", "compiler/inline", "compiler/irpasses", "compiler/codegen",
145+
"compiler/contextual", "compiler/AbstractInterpreter", "compiler/plugin",
146146
"compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"])
147147
filtertests!(tests, "compiler/EscapeAnalysis", [
148148
"compiler/EscapeAnalysis/local", "compiler/EscapeAnalysis/interprocedural"])

test/compiler/plugin.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# %%
2+
const CC = Core.Compiler
3+
using Test, .CC.Plugin
4+
import Core: MethodInstance, CodeInstance
5+
import .CC: WorldRange, WorldView
6+
7+
# %%
8+
# SinCosRewriter
9+
# --------------
10+
11+
struct SinCosRewriterCache
12+
dict::IdDict{MethodInstance,CodeInstance}
13+
end
14+
struct SinCosRewriter <: CC.AbstractInterpreter
15+
interp::CC.NativeInterpreter
16+
cache::SinCosRewriterCache
17+
end
18+
let global_cache = SinCosRewriterCache(IdDict{MethodInstance,CodeInstance}())
19+
global function SinCosRewriter(
20+
world = Base.get_world_counter();
21+
interp = CC.NativeInterpreter(world),
22+
cache = global_cache)
23+
return SinCosRewriter(interp, cache)
24+
end
25+
end
26+
CC.InferenceParams(interp::SinCosRewriter) = CC.InferenceParams(interp.interp)
27+
CC.OptimizationParams(interp::SinCosRewriter) = CC.OptimizationParams(interp.interp)
28+
CC.get_world_counter(interp::SinCosRewriter) = CC.get_world_counter(interp.interp)
29+
CC.get_inference_cache(interp::SinCosRewriter) = CC.get_inference_cache(interp.interp)
30+
CC.code_cache(interp::SinCosRewriter) = WorldView(interp.cache, WorldRange(CC.get_world_counter(interp)))
31+
CC.get(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance, default) = get(wvc.cache.dict, mi, default)
32+
CC.getindex(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance) = getindex(wvc.cache.dict, mi)
33+
CC.haskey(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance) = haskey(wvc.cache.dict, mi)
34+
CC.setindex!(wvc::WorldView{<:SinCosRewriterCache}, ci::CodeInstance, mi::MethodInstance) = setindex!(wvc.cache.dict, ci, mi)
35+
36+
function CC.abstract_call_gf_by_type(
37+
interp::SinCosRewriter, @nospecialize(f),
38+
arginfo::CC.ArgInfo, @nospecialize(atype),
39+
sv::CC.InferenceState, max_methods::Int)
40+
if f === sin
41+
f = cos
42+
atype′ = CC.unwrap_unionall(atype)::DataType
43+
atype′ = Tuple{typeof(cos), atype′.parameters[2:end]...}
44+
atype = CC.rewrap_unionall(atype′, atype)
45+
end
46+
return Base.@invoke CC.abstract_call_gf_by_type(
47+
interp::CC.AbstractInterpreter, f,
48+
arginfo::CC.ArgInfo, atype,
49+
sv::CC.InferenceState, max_methods::Int)
50+
end
51+
52+
# simple
53+
@test execute_with_plugin(; interp=SinCosRewriter()) do
54+
sin(42)
55+
end == cos(42)
56+
@test execute_with_plugin(42; interp=SinCosRewriter()) do a
57+
sin(a)
58+
end == cos(42)
59+
@test execute_with_plugin(sin, 42; interp=SinCosRewriter()) do f, a
60+
f(a)
61+
end == cos(42)
62+
63+
# global
64+
nested(a) = sin(a) # => cos(a)
65+
@test execute_with_plugin(42; interp=SinCosRewriter()) do a
66+
nested(a)
67+
end == cos(42)
68+
69+
# dynamic
70+
gv::Any = 42
71+
@test execute_with_plugin() do
72+
sin(gv::Any) # dynamic dispatch
73+
end == cos(42)
74+
@test execute_with_plugin(sin) do f
75+
f(gv)
76+
end == cos(42)
77+
gf::Any = sin
78+
@test execute_with_plugin() do
79+
(gf::Any)(gv::Any)
80+
end == cos(42)
81+
82+
# static dispatch
83+
@noinline noninlined_sin() = sin(gv::Int)
84+
@test_broken execute_with_plugin(; interp=SinCosRewriter()) do
85+
noninlined_sin()
86+
end == cos(42)
87+
88+
# invoke
89+
@test_broken execute_with_plugin(42.0; interp=SinCosRewriter()) do a
90+
Base.@invoke sin(a::Float64)
91+
end == cos(42.0)
92+
93+
# end to end
94+
function kernel(fs)
95+
r = 0
96+
for i = 1:length(fs)
97+
r += sum(fs[i](i))
98+
end
99+
return r
100+
end
101+
let
102+
fs = (sin, sin, cos)
103+
gs = (cos, cos, cos)
104+
@test execute_with_plugin(kernel, fs) == kernel(gs)
105+
106+
fs = Any[sin, sin, cos]
107+
gs = Any[cos, cos, cos]
108+
@test execute_with_plugin(kernel, fs) == kernel(gs)
109+
end

0 commit comments

Comments
 (0)