
Overlay of Core.throw_inexacterror results in bad codegen of kwarg call #48097


Description

@maleadt

Backstory: CUDA.jl and other GPU back-ends use method overlays to replace GPU-incompatible functionality, such as several exception-throwing functions. While we generally do support exceptions, replacing them at the LLVM level with custom reporting functions, some are inherently GPU-incompatible because of string interpolation or untyped fields (which result in a GC frame we cannot support). One such example is InexactError (from boot.jl):

struct InexactError <: Exception
    func::Symbol
    T  # Type
    val
    InexactError(f::Symbol, @nospecialize(T), @nospecialize(val)) = (@noinline; new(f, T, val))
end
throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} = (@noinline; throw(InexactError(f, T, val)))
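
As an aside, the GC frame caused by those untyped fields is easy to observe: throw_inexacterror has to box its arguments to construct the error, and the boxes need to be rooted. A quick way to check (not part of the report; any concrete argument types will do):

using InteractiveUtils
code_llvm(Core.throw_inexacterror, Tuple{Symbol, Type{Int8}, Int64}; debuginfo=:none)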

To support this, we override throw_inexacterror:

@device_override @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    @gpu_print "Inexact conversion"

This used to work fine, but for specific kernels (those containing a kwarg function call whose kwargs are heterogeneously typed) we now get LLVM IR that performs dynamic function calls:

child(; kwargs...) = return
function parent()
    child(; a=1f0, b=1.0)
    return
end

@overlay method_table @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return

define void @good() #0 {
top:
  %0 = call {}*** @julia.get_pgcstack()
  %1 = bitcast {}*** %0 to {}**
  %current_task = getelementptr inbounds {}*, {}** %1, i64 -12
  %2 = bitcast {}** %current_task to i64*
  %world_age = getelementptr inbounds i64, i64* %2, i64 13
  ret void
}

define void @bad() #0 {
top:
  ; blah blah blah
  %43 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @jl_f_apply_type, {} addrspace(10)* null, {} addrspace(10)* %36, {} addrspace(10)* %38, {} addrspace(10)* %28, {} addrspace(10)* %40, {} addrspace(10)* %42)
  ret void
}

I bisected this to #44224, which is a bit weird, since that PR was backported to 1.8 too. It's possible that there's some interplay with other functionality that's only on 1.9, but that seems unlikely, as the PR was merged very shortly after 1.9 branched.

Anyway, the full MWE (reduced from GPUCompiler.jl):

# pieces lifted from GPUCompiler.jl

using LLVM

# helper type to deal with the transition from Context to ThreadSafeContext
export JuliaContext
if VERSION >= v"1.9.0-DEV.516"
    const JuliaContextType = ThreadSafeContext
else
    const JuliaContextType = Context
end
function JuliaContext()
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType()
    else
        isboxed_ref = Ref{Bool}()
        typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
                       (Any, Ptr{Bool}), Any, isboxed_ref))
        context(typ)
    end
end
function JuliaContext(f)
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType(f)
    else
        f(JuliaContext())
    end
end
if VERSION >= v"1.9.0-DEV.516"
unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
unwrap_context(ctx::Context) = ctx

# CI cache
using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams
struct CodeCache
    dict::IdDict{MethodInstance,Vector{CodeInstance}}
    CodeCache() = new(IdDict{MethodInstance,Vector{CodeInstance}}())
end
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    cis = get!(cache.dict, mi, CodeInstance[])
    push!(cis, ci)
end
const GLOBAL_CI_CACHE = CodeCache()

using Core.Compiler: WorldView
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    Core.Compiler.get(wvc, mi, nothing) !== nothing
end
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            src = if ci.inferred isa Vector{UInt8}
                ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                       mi.def, C_NULL, ci.inferred)
            else
                ci.inferred
            end
            return ci
        end
    end
    return default
end
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    r = Core.Compiler.get(wvc, mi, nothing)
    r === nothing && throw(KeyError(mi))
    return r::CodeInstance
end
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    src = if ci.inferred isa Vector{UInt8}
        ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any),
                mi.def, C_NULL, ci.inferred)
    else
        ci.inferred
    end
    Core.Compiler.setindex!(wvc.cache, ci, mi)
end
function ci_cache_populate(interp, mi, min_world, max_world)
    src = Core.Compiler.typeinf_ext_toplevel(interp, mi)
    wvc = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    @assert Core.Compiler.haskey(wvc, mi)
    ci = Core.Compiler.getindex(wvc, mi)
    if ci !== nothing && ci.inferred === nothing
        @static if VERSION >= v"1.9.0-DEV.1115"
            @atomic ci.inferred = src
        else
            ci.inferred = src
        end
    end
    return ci
end
function ci_cache_lookup(mi, min_world, max_world)
    wvc = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    ci = Core.Compiler.get(wvc, mi, nothing)
    if ci !== nothing && ci.inferred === nothing
        return nothing
    end
    return ci
end

# interpreter
using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams
struct GPUInterpreter <: AbstractInterpreter
    local_cache::Vector{InferenceResult}
    GPUInterpreter() = new(Vector{InferenceResult}())
end
Core.Compiler.InferenceParams(interp::GPUInterpreter) = InferenceParams()
Core.Compiler.OptimizationParams(interp::GPUInterpreter) = OptimizationParams()
Core.Compiler.get_world_counter(interp::GPUInterpreter) = Base.get_world_counter()
Core.Compiler.get_inference_cache(interp::GPUInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::GPUInterpreter) = WorldView(GLOBAL_CI_CACHE, Base.get_world_counter())

using Base.Experimental: @overlay, @MethodTable
@MethodTable(GLOBAL_METHOD_TABLE)

using Core.Compiler: OverlayMethodTable
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::GPUInterpreter) =
    OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
else
Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
    OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
end

# disable ir interpretation due to issues with overlay tables
@static if VERSION >= v"1.9.0-DEV.1248"
function Core.Compiler.concrete_eval_eligible(interp::GPUInterpreter,
    @nospecialize(f), result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    ret = @invoke Core.Compiler.concrete_eval_eligible(interp::AbstractInterpreter,
        f::Any, result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
    ret === false && return nothing
    return ret
end
end

function codegen(sig)
    mi, _ = emit_julia(sig)

    JuliaContext() do ctx
        ir, _ = emit_llvm(mi; ctx)
        strip_debuginfo!(ir)
        string(ir)
    end
end

function emit_julia(sig)
    meth = which(sig)
    (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                        (Any, Any), sig, meth.sig)::Core.SimpleVector
    meth = Base.func_for_method_checked(meth, ti, env)
    method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                            (Any, Any, Any, UInt), meth, ti, env, Base.get_world_counter())
    return method_instance, ()
end

function emit_llvm(@nospecialize(method_instance); ctx::JuliaContextType)
    InitializeAllTargets()
    InitializeAllTargetInfos()
    InitializeAllAsmPrinters()
    InitializeAllAsmParsers()
    InitializeAllTargetMCs()

    ir, compiled = irgen(method_instance; ctx)
    entry_fn = compiled[method_instance].specfunc
    entry = functions(ir)[entry_fn]
    return ir, (; entry, compiled)
end

function irgen(method_instance::Core.MethodInstance;
               ctx::JuliaContextType)
    mod, compiled = compile_method_instance(method_instance; ctx)
    entry_fn = compiled[method_instance].specfunc
    entry = functions(mod)[entry_fn]
    return mod, compiled
end

function compile_method_instance(method_instance::MethodInstance; ctx::JuliaContextType)
    world = Base.get_world_counter()
    interp =  GPUInterpreter()
    if ci_cache_lookup(method_instance, world, typemax(Cint)) === nothing
        ci_cache_populate(interp, method_instance, world, typemax(Cint))
    end

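    # Codegen will call back into `lookup_fun` for every method instance it needs;
    # we record them so they can be mapped to their LLVM functions afterwards.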
    method_instances = []
    function lookup_fun(mi, min_world, max_world)
        push!(method_instances, mi)
        ci_cache_lookup(mi, min_world, max_world)
    end
    lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
    params = Base.CodegenParams(; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
    GC.@preserve lookup_cb begin
        native_code = if VERSION >= v"1.9.0-DEV.516"
            mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
            ts_mod = ThreadSafeModule(mod; ctx)
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], ts_mod, Ref(params),  1)
        elseif VERSION >= v"1.9.0-DEV.115"
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], ctx, Ref(params),  1)
        elseif VERSION >= v"1.8.0-DEV.661"
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], Ref(params),  1)
        else
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Base.CodegenParams, Cint),
                  [method_instance], params,  1)
        end
        @assert native_code != C_NULL
        llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
            ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
                  (Ptr{Cvoid},), native_code)
        else
            ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                  (Ptr{Cvoid},), native_code)
        end
        @assert llvm_mod_ref != C_NULL
        if VERSION >= v"1.9.0-DEV.516"
            llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
            llvm_mod = nothing
            llvm_ts_mod() do mod
                llvm_mod = mod
            end
        else
            llvm_mod = LLVM.Module(llvm_mod_ref)
        end
    end

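    # Map each visited method instance to its CodeInstance and the names of the
    # LLVM function/specfunction that were generated for it.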
    compiled = Dict()
    for mi in method_instances
        ci = ci_cache_lookup(mi, world, typemax(Cint))
        if ci !== nothing
            llvm_func_idx = Ref{Int32}(-1)
            llvm_specfunc_idx = Ref{Int32}(-1)
            ccall(:jl_get_function_id, Nothing,
                (Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}),
                native_code, ci, llvm_func_idx, llvm_specfunc_idx)
            llvm_func = if llvm_func_idx[] != -1
                llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                      (Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1)
                @assert llvm_func_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_func_ref))
            else
                nothing
            end
            llvm_specfunc = if llvm_specfunc_idx[] != -1
                llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                        (Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1)
                @assert llvm_specfunc_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_specfunc_ref))
            else
                nothing
            end
            compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
        end
    end

    return llvm_mod, compiled
end

############################################################################################

# compiler invocation

child(; kwargs...) = return
function parent()
    child(; a=1f0, b=1.0)
    return
end

# this override introduces a `jl_invoke`
@overlay GLOBAL_METHOD_TABLE @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return

println(codegen(Tuple{typeof(parent)}))
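
These dynamic calls into the Julia runtime are exactly what the GPU back-ends cannot support, which is why this breaks CUDA.jl. To make the regression easier to spot than eyeballing the printed IR, the MWE could end with a check along these lines (just a sketch; the substring comes from the @bad IR above, and on an affected build the assertion fires):

ir = codegen(Tuple{typeof(parent)})
@assert !occursin("jl_f_apply_type", ir) "kwarg call lowered to a dynamic call into the Julia runtime"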

cc @aviatesk, you previously looked at GPU-related overlay issues

Putting this on the milestone as it breaks parts of CUDA.jl. That said, I'm happy to adapt CUDA.jl instead, although dropping the overlay of throw_inexacterror will break other code, so that's not a great solution either.

Labels: compiler:codegen (Generation of LLVM IR and native code), gpu (Affects running Julia on a GPU), regression (Regression in behavior compared to a previous version)