-
-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Description
Backstory: CUDA.jl and other GPU back-ends use method overlays to replace GPU-incompatible functionality, like several exception-throwing functions. While we generally do support exceptions, replacing them at the LLVM level with a custom reporting function, some are inherently GPU-incompatible due to string interpolation or untyped fields (resulting in a GC frame we cannot support). One such example is `InexactError` (from `boot.jl`):
# (Quoted verbatim from Base's boot.jl.) The untyped `T` and `val` fields are
# what make this exception GPU-incompatible: they require boxing, which in turn
# needs a GC frame the GPU back-ends cannot support.
struct InexactError <: Exception
    func::Symbol
    T # Type
    val
    InexactError(f::Symbol, @nospecialize(T), @nospecialize(val)) = (@noinline; new(f, T, val))
end
# Convenience thrower used throughout Base; this is the method GPU back-ends overlay.
throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} = (@noinline; throw(InexactError(f, T, val)))
To support this, we rewrite `throw_inexacterror`:
# GPU-side replacement (via an overlay method table): report the error with a
# device-side print instead of throwing. `@device_override` and `@gpu_print`
# are back-end macros not shown here.
@device_override @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    @gpu_print "Inexact conversion"
This used to work fine, but with specific kernels (that contain a kwarg function call where the kwargs are heterogeneously typed) we now get LLVM IR that performs dynamic function calls:
# Minimal trigger: a kwarg call where the keyword arguments have heterogeneous
# types (Float32 and Float64)...
child(; kwargs...) = return
function parent()
    child(; a=1f0, b=1.0)
    return
end
# ...combined with an overlay of `Core.throw_inexacterror` in a method table:
@overlay method_table @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return
define void @good() #0 {
top:
%0 = call {}*** @julia.get_pgcstack()
%1 = bitcast {}*** %0 to {}**
%current_task = getelementptr inbounds {}*, {}** %1, i64 -12
%2 = bitcast {}** %current_task to i64*
%world_age = getelementptr inbounds i64, i64* %2, i64 13
ret void
}
define void @bad() #0 {
top:
; blah blah blah
%43 = call nonnull {} addrspace(10)* ({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)*, {} addrspace(10)*, ...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @jl_f_apply_type, {} addrspace(10)* null, {} addrspace(10)* %36, {} addrspace(10)* %38, {} addrspace(10)* %28, {} addrspace(10)* %40, {} addrspace(10)* %42)
ret void
}
I bisected this to #44224, which is a bit weird, since that PR was back-ported to 1.8 too. It's possible that there's some interplay with other functionality only on 1.9, but that's unlikely as the PR was merged very shortly after 1.9 branched.
Anyway, the full MWE (reduced from GPUCompiler.jl):
# pieces lifted from GPUCompiler.jl
using LLVM
# helper type to deal with the transition from Context to ThreadSafeContext
export JuliaContext
# Select the context type codegen expects for this Julia version:
# 1.9.0-DEV.516+ moved to ORC thread-safe contexts.
if VERSION >= v"1.9.0-DEV.516"
    const JuliaContextType = ThreadSafeContext
else
    const JuliaContextType = Context
end
# Obtain the Julia context appropriate for this Julia version.
# On 1.9.0-DEV.115+ contexts are directly constructible; older versions must
# recover codegen's global context from an LLVM type it produced.
function JuliaContext()
    if VERSION >= v"1.9.0-DEV.115"
        JuliaContextType()
    else
        # Ask codegen for the boxed `Any` LLVM type, then take its context —
        # this is the context the runtime's codegen uses internally.
        isboxed_ref = Ref{Bool}()
        typ = LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
                             (Any, Ptr{Bool}), Any, isboxed_ref))
        context(typ)
    end
end
# Run `f` with a Julia context, using the scoped constructor where the
# version supports it and a plain call otherwise.
function JuliaContext(f)
    if VERSION >= v"1.9.0-DEV.115"
        return JuliaContextType(f)
    end
    return f(JuliaContext())
end
# Extract the plain LLVM `Context` from whatever context wrapper is in use.
if VERSION >= v"1.9.0-DEV.516"
    # Thread-safe contexts wrap a regular context; unwrap it.
    unwrap_context(ctx::ThreadSafeContext) = context(ctx)
end
# A plain context is already unwrapped.
unwrap_context(ctx::Context) = ctx
# CI cache
using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, OptimizationParams
# In-memory code cache mapping method instances to their inferred code
# instances (several per instance, one per valid world range).
#
# `dict` is an `IdDict` so method instances are keyed by identity (`===`),
# matching how the compiler itself caches code.
struct CodeCache
    dict::IdDict{MethodInstance,Vector{CodeInstance}}

    # FIX: the original constructor built a `Dict` and relied on implicit
    # conversion to the `IdDict` field type; construct the `IdDict` directly.
    CodeCache() = new(IdDict{MethodInstance,Vector{CodeInstance}}())
end
# Record `ci` as a cached code instance for `mi`, appending to any entries
# already stored for that method instance.
function Core.Compiler.setindex!(cache::CodeCache, ci::CodeInstance, mi::MethodInstance)
    entries = get!(() -> CodeInstance[], cache.dict, mi)
    return push!(entries, ci)
end
const GLOBAL_CI_CACHE = CodeCache()
using Core.Compiler: WorldView
# A world view has an entry for `mi` iff a lookup finds a code instance
# valid for the view's world range.
function Core.Compiler.haskey(wvc::WorldView{CodeCache}, mi::MethodInstance)
    found = Core.Compiler.get(wvc, mi, nothing)
    return !isnothing(found)
end
# Return the cached `CodeInstance` for `mi` that covers the entire world
# range of `wvc`, or `default` if none matches.
function Core.Compiler.get(wvc::WorldView{CodeCache}, mi::MethodInstance, default)
    # NOTE(review): `get!` inserts an empty vector for unseen method
    # instances, i.e. this lookup mutates the cache; kept to preserve behavior.
    for ci in get!(wvc.cache.dict, mi, CodeInstance[])
        # The entry must be valid across the whole queried world range.
        if ci.min_world <= wvc.worlds.min_world && wvc.worlds.max_world <= ci.max_world
            # FIX: the original also uncompressed `ci.inferred` into a local
            # `src` that was never used, paying for a `jl_uncompress_ir` ccall
            # on every cache hit; that dead computation has been removed.
            return ci
        end
    end
    return default
end
# Indexing variant of `get`: error with a `KeyError` when no valid code
# instance exists for `mi` in this world view.
function Core.Compiler.getindex(wvc::WorldView{CodeCache}, mi::MethodInstance)
    ci = Core.Compiler.get(wvc, mi, nothing)
    if ci === nothing
        throw(KeyError(mi))
    end
    return ci::CodeInstance
end
# Store `ci` for `mi`, delegating to the underlying cache (the view's world
# bounds are not consulted on insertion; the entry carries its own range).
function Core.Compiler.setindex!(wvc::WorldView{CodeCache}, ci::CodeInstance, mi::MethodInstance)
    # FIX: the original uncompressed `ci.inferred` into an unused local `src`
    # (an expensive `jl_uncompress_ir` ccall) before delegating; that dead
    # computation has been removed.
    return Core.Compiler.setindex!(wvc.cache, ci, mi)
end
# Run type inference on `mi` with `interp` and make sure the resulting source
# is attached to the cached `CodeInstance` for the given world range.
function ci_cache_populate(interp, mi, min_world, max_world)
    # Inference populates the code cache (via `code_cache(interp)`) as a side effect.
    src = Core.Compiler.typeinf_ext_toplevel(interp, mi)
    wvc = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    @assert Core.Compiler.haskey(wvc, mi)
    ci = Core.Compiler.getindex(wvc, mi)
    # Attach the inferred source if inference didn't store it itself.
    if ci !== nothing && ci.inferred === nothing
        @static if VERSION >= v"1.9.0-DEV.1115"
            # `CodeInstance.inferred` became an atomic field in 1.9.0-DEV.1115.
            @atomic ci.inferred = src
        else
            ci.inferred = src
        end
    end
    return ci
end
# Look up a usable cached code instance for `mi` in the given world range.
# Entries without attached inferred source are treated as missing.
function ci_cache_lookup(mi, min_world, max_world)
    view = WorldView(GLOBAL_CI_CACHE, min_world, max_world)
    ci = Core.Compiler.get(view, mi, nothing)
    (ci !== nothing && ci.inferred === nothing) && return nothing
    return ci
end
# interpreter
using Core.Compiler: AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, OptimizationParams
# Custom AbstractInterpreter carrying its own local inference-result cache,
# so GPU inference does not pollute the native cache.
struct GPUInterpreter <: AbstractInterpreter
    local_cache::Vector{InferenceResult}

    GPUInterpreter() = new(InferenceResult[])
end
# Required AbstractInterpreter interface: use default inference/optimization
# parameters, the current world, the interpreter-local inference cache, and
# the global code cache viewed at the current world.
Core.Compiler.InferenceParams(interp::GPUInterpreter) = InferenceParams()
Core.Compiler.OptimizationParams(interp::GPUInterpreter) = OptimizationParams()
Core.Compiler.get_world_counter(interp::GPUInterpreter) = Base.get_world_counter()
Core.Compiler.get_inference_cache(interp::GPUInterpreter) = interp.local_cache
Core.Compiler.code_cache(interp::GPUInterpreter) = WorldView(GLOBAL_CI_CACHE, Base.get_world_counter())
using Base.Experimental: @overlay, @MethodTable
# The overlay method table holding the GPU-specific method replacements.
@MethodTable(GLOBAL_METHOD_TABLE)
using Core.Compiler: OverlayMethodTable
# The `method_table` hook lost its `InferenceState` argument in 1.9.0-DEV.120
# (back-ported to 1.8-beta2); define whichever signature this version expects.
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
    Core.Compiler.method_table(interp::GPUInterpreter) =
        OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
else
    Core.Compiler.method_table(interp::GPUInterpreter, sv::InferenceState) =
        OverlayMethodTable(Base.get_world_counter(), GLOBAL_METHOD_TABLE)
end
# Disable concrete (IR) evaluation due to issues with overlay method tables:
# fall back to the generic implementation, but translate its `false`
# ("semi-concrete eval allowed") answer into `nothing` ("not eligible").
@static if VERSION >= v"1.9.0-DEV.1248"
    function Core.Compiler.concrete_eval_eligible(interp::GPUInterpreter,
        @nospecialize(f), result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
        # `@invoke` calls the AbstractInterpreter fallback, bypassing this override.
        ret = @invoke Core.Compiler.concrete_eval_eligible(interp::AbstractInterpreter,
            f::Any, result::Core.Compiler.MethodCallResult, arginfo::Core.Compiler.ArgInfo)
        ret === false && return nothing
        return ret
    end
end
# Compile the call signature `sig` down to LLVM IR and return it as a string
# (with debug info stripped for readability).
function codegen(sig)
    mi, _ = emit_julia(sig)
    return JuliaContext() do ctx
        mod, _ = emit_llvm(mi; ctx)
        strip_debuginfo!(mod)
        string(mod)
    end
end
# Resolve the call signature `sig` (e.g. `Tuple{typeof(f), ArgTypes...}`) to
# the `MethodInstance` that would be invoked, in the current world.
function emit_julia(sig)
    meth = which(sig)
    # Intersect the requested signature with the method's own signature to get
    # the concrete type `ti` and the type-variable environment `env`.
    (ti, env) = ccall(:jl_type_intersection_with_env, Any,
                      (Any, Any), sig, meth.sig)::Core.SimpleVector
    meth = Base.func_for_method_checked(meth, ti, env)
    # Get (or create) the specialization of `meth` for these argument types.
    method_instance = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
                            (Any, Any, Any, UInt), meth, ti, env, Base.get_world_counter())
    # Second element mirrors GPUCompiler's (mi, meta) return shape; unused here.
    return method_instance, ()
end
# Generate LLVM IR for `method_instance`, returning the module together with
# a named tuple holding the entry function and the compilation dictionary.
function emit_llvm(@nospecialize(method_instance); ctx::JuliaContextType)
    # Make sure LLVM knows about every available target before generating code.
    InitializeAllTargets()
    InitializeAllTargetInfos()
    InitializeAllAsmPrinters()
    InitializeAllAsmParsers()
    InitializeAllTargetMCs()

    mod, compiled = irgen(method_instance; ctx)
    entry_name = compiled[method_instance].specfunc
    entry = functions(mod)[entry_name]
    return mod, (; entry, compiled)
end
# Lower `method_instance` to an LLVM module, returning the module and the
# per-method-instance compilation info dictionary.
function irgen(method_instance::Core.MethodInstance;
               ctx::JuliaContextType)
    mod, compiled = compile_method_instance(method_instance; ctx)
    # Sanity check: the specialized entry function must exist in the generated
    # module (`functions(mod)[...]` throws if it does not). FIX: the original
    # bound this lookup to an unused local `entry`; only the check is kept.
    entry_fn = compiled[method_instance].specfunc
    functions(mod)[entry_fn]
    return mod, compiled
end
# Compile `method_instance` through `jl_create_native` using the
# GPUInterpreter and the global code cache. Returns the raw LLVM module and a
# Dict mapping each compiled MethodInstance to `(; ci, func, specfunc)`.
function compile_method_instance(method_instance::MethodInstance; ctx::JuliaContextType)
    world = Base.get_world_counter()
    interp = GPUInterpreter()
    # Populate the code cache up-front so codegen's lookup callback can find
    # the code instance.
    if ci_cache_lookup(method_instance, world, typemax(Cint)) === nothing
        ci_cache_populate(interp, method_instance, world, typemax(Cint))
    end
    # Record every method instance codegen asks about, so generated functions
    # can be mapped back afterwards.
    method_instances = []
    function lookup_fun(mi, min_world, max_world)
        push!(method_instances, mi)
        ci_cache_lookup(mi, min_world, max_world)
    end
    lookup_cb = @cfunction($lookup_fun, Any, (Any, UInt, UInt))
    params = Base.CodegenParams(; lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))
    # Keep the closure-backed cfunction alive for the duration of codegen.
    GC.@preserve lookup_cb begin
        # The jl_create_native ABI changed several times; pick the right one.
        native_code = if VERSION >= v"1.9.0-DEV.516"
            # ORC era: codegen emits into a thread-safe module we provide.
            mod = LLVM.Module("start"; ctx=unwrap_context(ctx))
            ts_mod = ThreadSafeModule(mod; ctx)
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMOrcThreadSafeModuleRef, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], ts_mod, Ref(params), 1)
        elseif VERSION >= v"1.9.0-DEV.115"
            # Context is passed explicitly, but params by reference.
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, LLVM.API.LLVMContextRef, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], ctx, Ref(params), 1)
        elseif VERSION >= v"1.8.0-DEV.661"
            # Codegen uses the global context; ours must be the same one.
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Ptr{Base.CodegenParams}, Cint),
                  [method_instance], Ref(params), 1)
        else
            # Oldest ABI: params passed by value.
            @assert ctx == JuliaContext()
            ccall(:jl_create_native, Ptr{Cvoid},
                  (Vector{MethodInstance}, Base.CodegenParams, Cint),
                  [method_instance], params, 1)
        end
        @assert native_code != C_NULL
        llvm_mod_ref = if VERSION >= v"1.9.0-DEV.516"
            ccall(:jl_get_llvm_module, LLVM.API.LLVMOrcThreadSafeModuleRef,
                  (Ptr{Cvoid},), native_code)
        else
            ccall(:jl_get_llvm_module, LLVM.API.LLVMModuleRef,
                  (Ptr{Cvoid},), native_code)
        end
        @assert llvm_mod_ref != C_NULL
        if VERSION >= v"1.9.0-DEV.516"
            # Unwrap the thread-safe module to reach the raw LLVM module.
            # NOTE(review): the raw module escapes the callback here, which
            # bypasses the thread-safety protocol; fine for a single-threaded MWE.
            llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
            llvm_mod = nothing
            llvm_ts_mod() do mod
                llvm_mod = mod
            end
        else
            llvm_mod = LLVM.Module(llvm_mod_ref)
        end
    end
    # Map each method instance that actually got compiled to the names of its
    # generated function and specialized function.
    compiled = Dict()
    for mi in method_instances
        ci = ci_cache_lookup(mi, world, typemax(Cint))
        if ci !== nothing
            llvm_func_idx = Ref{Int32}(-1)
            llvm_specfunc_idx = Ref{Int32}(-1)
            ccall(:jl_get_function_id, Nothing,
                  (Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}),
                  native_code, ci, llvm_func_idx, llvm_specfunc_idx)
            # Index -1 means "not generated"; indices appear 1-based (note the
            # -1 adjustment before jl_get_llvm_function).
            llvm_func = if llvm_func_idx[] != -1
                llvm_func_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                      (Ptr{Cvoid}, UInt32), native_code, llvm_func_idx[]-1)
                @assert llvm_func_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_func_ref))
            else
                nothing
            end
            llvm_specfunc = if llvm_specfunc_idx[] != -1
                llvm_specfunc_ref = ccall(:jl_get_llvm_function, LLVM.API.LLVMValueRef,
                                          (Ptr{Cvoid}, UInt32), native_code, llvm_specfunc_idx[]-1)
                @assert llvm_specfunc_ref != C_NULL
                LLVM.name(LLVM.Function(llvm_specfunc_ref))
            else
                nothing
            end
            compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
        end
    end
    return llvm_mod, compiled
end
############################################################################################
# compiler invocation
child(; kwargs...) = return
# Kernel under test: the heterogeneously-typed kwargs (Float32 and Float64)
# are what trigger the dynamic-call regression being reported.
function parent()
    child(; a = 1.0f0, b = 1.0)
    return nothing
end
# Overlay `throw_inexacterror` with a no-op, mirroring what GPU back-ends do;
# this override is what introduces the `jl_invoke` / dynamic call in the IR.
@overlay GLOBAL_METHOD_TABLE @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    return
# Compile `parent()` and print the resulting LLVM IR to inspect for dynamic calls.
println(codegen(Tuple{typeof(parent)}))
cc @aviatesk, you previously looked at GPU-related overlay issues
Putting this on the milestone as this breaks parts of CUDA.jl. That said, I'm also happy to adapt CUDA.jl instead (although dropping the overlay of `throw_inexacterror` will result in other code breaking, so that's not a great solution either).