Closed
Description
MWE, extracted from GPUCompiler.jl:
## custom interpreter
# extracted from GPUCompiler.jl
using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance, WorldView, CodeInstance, OverlayMethodTable
struct CustomInterpreter <: AbstractInterpreter
method_table::Union{Nothing,Core.MethodTable}
# Cache of inference results for this particular interpreter
local_cache::Vector{InferenceResult}
# The world age we're working inside of
world::UInt
# Parameters for inference and optimization
inf_params::InferenceParams
opt_params::OptimizationParams
function CustomInterpreter(mt::Union{Nothing,Core.MethodTable}, world::UInt)
@assert world <= Base.get_world_counter()
return new(
mt,
# Initially empty cache
Vector{InferenceResult}(),
# world age counter
world,
# parameters for inference and optimization
InferenceParams(unoptimize_throw_blocks=false),
VERSION >= v"1.8.0-DEV.486" ? OptimizationParams() :
OptimizationParams(unoptimize_throw_blocks=false),
)
end
end
struct CodeCache
dict::Dict{MethodInstance,Vector{CodeInstance}}
CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
end
const global_code_cache = CodeCache()
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
# make sure the invalidation callback is attached to the method instance
callback(mi, max_world) = invalidate(cache, mi, max_world)
if !isdefined(mi, :callbacks)
mi.callbacks = Any[callback]
elseif !in(callback, mi.callbacks)
push!(mi.callbacks, callback)
end
cis = get!(cache.dict, mi, CodeInstance[])
push!(cis, ci)
end
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
Core.Compiler.get(wvc, mi, nothing) !== nothing
end
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
# check the cache
for ci in get!(wvc.cache.dict, mi, CodeInstance[])
if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
# TODO: if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code)))
src = if ci.inferred isa Vector{UInt8}
ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
mi.def, C_NULL, ci.inferred)
else
ci.inferred
end
return ci
end
end
return default
end
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
r = Core.Compiler.get(wvc, mi, nothing)
r === nothing && throw(KeyError(mi))
return r::CodeInstance
end
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
src = if ci.inferred isa Vector{UInt8}
ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
mi.def, C_NULL, ci.inferred)
else
ci.inferred
end
Core.Compiler.setindex!(wvc.cache, ci, mi)
end
Core.Compiler.lock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.may_optimize(interp::CustomInterpreter) = true
Core.Compiler.may_compress(interp::CustomInterpreter) = true
Core.Compiler.may_discard_trees(interp::CustomInterpreter) = true
if VERSION >= v"1.7.0-DEV.577"
Core.Compiler.verbose_stmt_info(interp::CustomInterpreter) = false
end
Core.Compiler.add_remark!(interp::CustomInterpreter, sv::InferenceState, msg) = nothing
Core.Compiler.InferenceParams(interp::CustomInterpreter) = interp.inf_params
Core.Compiler.OptimizationParams(interp::CustomInterpreter) = interp.opt_params
Core.Compiler.get_world_counter(interp::CustomInterpreter) = interp.world
Core.Compiler.get_inference_cache(interp::CustomInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::CustomInterpreter) =
WorldView(global_code_cache, interp.world)
Core.Compiler.method_table(interp::CustomInterpreter, sv::InferenceState) =
OverlayMethodTable(interp.world, interp.method_table)
## kernel
using Base.Experimental: @MethodTable, @overlay
kernel() = child()
child() = 0
@MethodTable(method_table)
@overlay method_table child() = 1
## main
using InteractiveUtils
interp = CustomInterpreter(method_table, Base.get_world_counter())
InteractiveUtils.code_warntype(kernel, Tuple{}; interp)
On 1.7 this correctly results in child
being called (and constant propagated):
MethodInstance for kernel()
from kernel() in Main at /home/tim/Julia/pkg/GPUCompiler/wip.jl:122
Arguments
#self#::Core.Const(kernel)
Body::Int64
1 ─ %1 = Main.child()::Core.Const(1)
└── return %1
Whereas since #43852, the override doesn't take:
MethodInstance for kernel()
from kernel() in Main at /home/tim/Julia/pkg/GPUCompiler/wip.jl:122
Arguments
#self#::Core.Const(kernel)
Body::Int64
1 ─ %1 = Main.child()::Core.Const(0)
└── return %1