|
| 1 | +# %% |
| 2 | +const CC = Core.Compiler |
| 3 | +using Test, .CC.Plugin |
| 4 | +import Core: MethodInstance, CodeInstance |
| 5 | +import .CC: WorldRange, WorldView |
| 6 | + |
| 7 | +# %% |
| 8 | +# SinCosRewriter |
| 9 | +# -------------- |
| 10 | + |
| 11 | +struct SinCosRewriterCache |
| 12 | + dict::IdDict{MethodInstance,CodeInstance} |
| 13 | +end |
| 14 | +struct SinCosRewriter <: CC.AbstractInterpreter |
| 15 | + interp::CC.NativeInterpreter |
| 16 | + cache::SinCosRewriterCache |
| 17 | +end |
| 18 | +let global_cache = SinCosRewriterCache(IdDict{MethodInstance,CodeInstance}()) |
| 19 | + global function SinCosRewriter( |
| 20 | + world = Base.get_world_counter(); |
| 21 | + interp = CC.NativeInterpreter(world), |
| 22 | + cache = global_cache) |
| 23 | + return SinCosRewriter(interp, cache) |
| 24 | + end |
| 25 | +end |
| 26 | +CC.InferenceParams(interp::SinCosRewriter) = CC.InferenceParams(interp.interp) |
| 27 | +CC.OptimizationParams(interp::SinCosRewriter) = CC.OptimizationParams(interp.interp) |
| 28 | +CC.get_world_counter(interp::SinCosRewriter) = CC.get_world_counter(interp.interp) |
| 29 | +CC.get_inference_cache(interp::SinCosRewriter) = CC.get_inference_cache(interp.interp) |
| 30 | +CC.code_cache(interp::SinCosRewriter) = WorldView(interp.cache, WorldRange(CC.get_world_counter(interp))) |
| 31 | +CC.get(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance, default) = get(wvc.cache.dict, mi, default) |
| 32 | +CC.getindex(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance) = getindex(wvc.cache.dict, mi) |
| 33 | +CC.haskey(wvc::WorldView{<:SinCosRewriterCache}, mi::MethodInstance) = haskey(wvc.cache.dict, mi) |
| 34 | +CC.setindex!(wvc::WorldView{<:SinCosRewriterCache}, ci::CodeInstance, mi::MethodInstance) = setindex!(wvc.cache.dict, ci, mi) |
| 35 | + |
| 36 | +function CC.abstract_call_gf_by_type( |
| 37 | + interp::SinCosRewriter, @nospecialize(f), |
| 38 | + arginfo::CC.ArgInfo, @nospecialize(atype), |
| 39 | + sv::CC.InferenceState, max_methods::Int) |
| 40 | + if f === sin |
| 41 | + f = cos |
| 42 | + atype′ = CC.unwrap_unionall(atype)::DataType |
| 43 | + atype′ = Tuple{typeof(cos), atype′.parameters[2:end]...} |
| 44 | + atype = CC.rewrap_unionall(atype′, atype) |
| 45 | + end |
| 46 | + return Base.@invoke CC.abstract_call_gf_by_type( |
| 47 | + interp::CC.AbstractInterpreter, f, |
| 48 | + arginfo::CC.ArgInfo, atype, |
| 49 | + sv::CC.InferenceState, max_methods::Int) |
| 50 | +end |
| 51 | + |
| 52 | +# simple |
| 53 | +@test execute_with_plugin(; interp=SinCosRewriter()) do |
| 54 | + sin(42) |
| 55 | +end == cos(42) |
| 56 | +@test execute_with_plugin(42; interp=SinCosRewriter()) do a |
| 57 | + sin(a) |
| 58 | +end == cos(42) |
| 59 | +@test execute_with_plugin(sin, 42; interp=SinCosRewriter()) do f, a |
| 60 | + f(a) |
| 61 | +end == cos(42) |
| 62 | + |
| 63 | +# global |
| 64 | +nested(a) = sin(a) # => cos(a) |
| 65 | +@test execute_with_plugin(42; interp=SinCosRewriter()) do a |
| 66 | + nested(a) |
| 67 | +end == cos(42) |
| 68 | + |
| 69 | +# dynamic |
| 70 | +gv::Any = 42 |
| 71 | +@test execute_with_plugin() do |
| 72 | + sin(gv::Any) # dynamic dispatch |
| 73 | +end == cos(42) |
| 74 | +@test execute_with_plugin(sin) do f |
| 75 | + f(gv) |
| 76 | +end == cos(42) |
| 77 | +gf::Any = sin |
| 78 | +@test execute_with_plugin() do |
| 79 | + (gf::Any)(gv::Any) |
| 80 | +end == cos(42) |
| 81 | + |
| 82 | +# static dispatch |
| 83 | +@noinline noninlined_sin() = sin(gv::Int) |
| 84 | +@test_broken execute_with_plugin(; interp=SinCosRewriter()) do |
| 85 | + noninlined_sin() |
| 86 | +end == cos(42) |
| 87 | + |
| 88 | +# invoke |
| 89 | +@test_broken execute_with_plugin(42.0; interp=SinCosRewriter()) do a |
| 90 | + Base.@invoke sin(a::Float64) |
| 91 | +end == cos(42.0) |
| 92 | + |
| 93 | +# end to end |
| 94 | +function kernel(fs) |
| 95 | + r = 0 |
| 96 | + for i = 1:length(fs) |
| 97 | + r += sum(fs[i](i)) |
| 98 | + end |
| 99 | + return r |
| 100 | +end |
| 101 | +let |
| 102 | + fs = (sin, sin, cos) |
| 103 | + gs = (cos, cos, cos) |
| 104 | + @test execute_with_plugin(kernel, fs) == kernel(gs) |
| 105 | + |
| 106 | + fs = Any[sin, sin, cos] |
| 107 | + gs = Any[cos, cos, cos] |
| 108 | + @test execute_with_plugin(kernel, fs) == kernel(gs) |
| 109 | +end |
0 commit comments