Skip to content

Purity modeling broke overlay tables #44174

Closed
@maleadt

Description

@maleadt

Minimal working example (MWE), extracted from GPUCompiler.jl:

## custom interpreter

# extracted from GPUCompiler.jl

using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams, MethodInstance, WorldView, CodeInstance, OverlayMethodTable

"""
    CustomInterpreter(mt, world)

Minimal `AbstractInterpreter`, modelled on GPUCompiler.jl's, that runs inference at the
fixed world age `world` and looks methods up through the (optional) method table `mt`.
`world` must not exceed the current world counter.
"""
struct CustomInterpreter <: AbstractInterpreter
    # method table handed to the overlay lookup, or `nothing`
    method_table::Union{Nothing,Core.MethodTable}
    # inference results owned by this interpreter instance only
    local_cache::Vector{InferenceResult}
    # world age that inference for this interpreter is pinned to
    world::UInt
    # settings returned to Core.Compiler through the accessor overrides
    inf_params::InferenceParams
    opt_params::OptimizationParams

    function CustomInterpreter(mt::Union{Nothing,Core.MethodTable}, world::UInt)
        @assert world <= Base.get_world_counter()

        inference_settings = InferenceParams(unoptimize_throw_blocks=false)
        # Before 1.8.0-DEV.486 the throw-block setting is passed to OptimizationParams
        # as well; from that version on OptimizationParams is built with its defaults.
        optimization_settings = if VERSION >= v"1.8.0-DEV.486"
            OptimizationParams()
        else
            OptimizationParams(unoptimize_throw_blocks=false)
        end

        return new(mt, InferenceResult[], world, inference_settings, optimization_settings)
    end
end

# Cache of CodeInstances keyed by MethodInstance. Several CodeInstances may be stored
# per MethodInstance; each carries its own min/max world bounds.
struct CodeCache
    dict::Dict{MethodInstance,Vector{CodeInstance}}

    function CodeCache()
        storage = Dict{MethodInstance,Vector{CodeInstance}}()
        return new(storage)
    end
end

# Single process-wide cache; every CustomInterpreter views it through `code_cache`.
const global_code_cache = CodeCache()

# Store `ci` as a cached result for `mi` in `cache`, making sure `mi` carries a callback
# that forwards invalidation to `invalidate(cache, mi, max_world)`.
# Returns the vector of CodeInstances now stored for `mi`.
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    # make sure the invalidation callback is attached to the method instance
    # NOTE(review): `invalidate` is not defined in this snippet; it must be provided
    # (as in GPUCompiler.jl) before a method redefinition fires this callback.
    callback(mi, max_world) = invalidate(cache, mi, max_world)
    if !isdefined(mi, :callbacks)
        mi.callbacks = Any[callback]
    elseif !in(callback, mi.callbacks)
        # presumably relies on closures over the same `cache` comparing equal, so
        # repeated stores do not register duplicate callbacks
        push!(mi.callbacks, callback)
    end

    # Function form of `get!`: the empty vector is only allocated on first insertion,
    # not on every store.
    cis = get!(() -> CodeInstance[], cache.dict, mi)
    push!(cis, ci)
    return cis
end

# True when `Core.Compiler.get` finds a cached CodeInstance for `mi` in the view `wvc`.
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    found = Core.Compiler.get(wvc, mi, nothing)
    return !isnothing(found)
end

# Return a CodeInstance cached for `mi` whose [min_world, max_world] range covers the
# whole world range of the view `wvc`, or `default` if there is none.
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    # Non-mutating lookup: a query must not insert (and allocate) an empty entry for
    # every MethodInstance that was never cached.
    cis = get(wvc.cache.dict, mi, nothing)
    cis === nothing && return default

    for ci in cis
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # TODO: if (code && (code == jl_nothing || jl_ir_flag_inferred((jl_array_t*)code)))
            # The inferred source used to be decompressed here, but the result was never
            # used; reintroduce it together with the check above if that TODO is done.
            return ci
        end
    end

    return default
end

# Like `Core.Compiler.get`, but throw a `KeyError` when nothing is cached for `mi`.
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    found = Core.Compiler.get(wvc, mi, nothing)
    if isnothing(found)
        throw(KeyError(mi))
    end
    return found::CodeInstance
end

# Store `ci` for `mi` in the CodeCache underlying the view `wvc`. The view's world range
# is not used here; lookups filter on the CodeInstance's own min/max world.
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance,
                                 mi::MethodInstance)
    # The inferred source used to be decompressed into an unused local here; that was
    # pure overhead, so the store is forwarded directly.
    return Core.Compiler.setindex!(wvc.cache, ci, mi)
end

# AbstractInterpreter hook overrides: the switches Core.Compiler queries while inferring
# and caching code for a CustomInterpreter.

# Locking/unlocking a MethodInstance for inference is a no-op for this interpreter.
Core.Compiler.lock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(interp::CustomInterpreter, mi::MethodInstance) = nothing
# Answer `true` to the may-optimize / may-compress / may-discard-trees queries.
Core.Compiler.may_optimize(interp::CustomInterpreter) = true
Core.Compiler.may_compress(interp::CustomInterpreter) = true
Core.Compiler.may_discard_trees(interp::CustomInterpreter) = true
# Only overridden from 1.7.0-DEV.577 on (NOTE(review): presumably the hook does not
# exist on older versions).
if VERSION >= v"1.7.0-DEV.577"
Core.Compiler.verbose_stmt_info(interp::CustomInterpreter) = false
end
# Inference remarks are discarded.
Core.Compiler.add_remark!(interp::CustomInterpreter, sv::InferenceState, msg) = nothing

# Accessors through which Core.Compiler reads this interpreter's settings and caches.
function Core.Compiler.InferenceParams(interp::CustomInterpreter)
    return interp.inf_params
end

function Core.Compiler.OptimizationParams(interp::CustomInterpreter)
    return interp.opt_params
end

# Inference runs at the world age stored in the interpreter.
function Core.Compiler.get_world_counter(interp::CustomInterpreter)
    return interp.world
end

# Per-interpreter list of inference results.
function Core.Compiler.get_inference_cache(interp::CustomInterpreter)
    return interp.local_cache
end

# The shared global code cache, viewed at the interpreter's world.
function Core.Compiler.code_cache(interp::CustomInterpreter)
    return WorldView(global_code_cache, interp.world)
end

# Method table inference uses for method lookup: the interpreter's overlay table, so
# that `@overlay` definitions in it are found. If no table was supplied, fall back to
# the internal (global) method table at the interpreter's world.
function Core.Compiler.method_table(interp::CustomInterpreter, sv::InferenceState)
    mt = interp.method_table
    # The field and constructor admit `nothing`, which must not reach
    # `OverlayMethodTable` (NOTE(review): assumed to require a real `Core.MethodTable`).
    mt === nothing && return Core.Compiler.InternalMethodTable(interp.world)
    return OverlayMethodTable(interp.world, mt)
end


## kernel

using Base.Experimental: @MethodTable, @overlay

# `kernel` calls `child`, whose regular definition returns 0.
kernel() = child()
child() = 0

# Declare a separate method table `method_table` and overlay `child` in it to return 1.
# If inference consults this table, the `child()` call in `kernel()` infers as
# `Core.Const(1)`; if the overlay is ignored it infers as `Core.Const(0)`.
@MethodTable(method_table)
@overlay method_table child() = 1


## main

using InteractiveUtils

# Interpreter pinned to the current world that uses the overlay `method_table`.
interp = CustomInterpreter(method_table, Base.get_world_counter())
# Print the inferred code of `kernel()` as seen by that interpreter; the inferred type of
# the `child()` call shows whether the overlay was used (Const(1)) or not (Const(0)).
InteractiveUtils.code_warntype(kernel, Tuple{}; interp)

On Julia 1.7 this correctly dispatches to the overlaid `child` method, whose result is constant-propagated (`Core.Const(1)`):

MethodInstance for kernel()
  from kernel() in Main at /home/tim/Julia/pkg/GPUCompiler/wip.jl:122
Arguments
  #self#::Core.Const(kernel)
Body::Int64
1 ─ %1 = Main.child()::Core.Const(1)
└──      return %1

However, since #43852 the overlay no longer takes effect, and inference uses the original `child` definition (`Core.Const(0)`):

MethodInstance for kernel()
  from kernel() in Main at /home/tim/Julia/pkg/GPUCompiler/wip.jl:122
Arguments
  #self#::Core.Const(kernel)
Body::Int64
1 ─ %1 = Main.child()::Core.Const(0)
└──      return %1

cc @Keno @vchuravy

Metadata

Metadata

Assignees

Labels

gpu: Affects running Julia on a GPU
regression: Regression in behavior compared to a previous version

Type

No type

Projects

No projects

Milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions