Description
Julia's GPU back-ends need to be able to create variables with threadgroup- and thread-local semantics. Awaiting something like #47569, we currently do so by emitting LLVM IR that defines a global variable, and accessing that memory as an array using unsafe_wrap:
@inline shmem() = Base.llvmcall(("""
    @shmem = internal global [1 x i8] zeroinitializer, align 32
    define i8* @entry() #0 {
        ret i8* getelementptr inbounds ([1 x i8], [1 x i8]* @shmem, i64 0, i64 0)
    }
    attributes #0 = { alwaysinline }""", "entry"),
    Core.LLVMPtr{Int8,0}, Tuple{})
function kernel1()
    ptr = reinterpret(Ptr{Int8}, shmem())
    arr = unsafe_wrap(Array, ptr, 1)
    @inbounds begin
        arr[] = 1
    end
end
@shmem = internal global [1 x i8] zeroinitializer, align 32
define i64 @julia_kernel_189() #0 {
top:
%0 = call nonnull {}* inttoptr (i64 140467922984160 to {}* ({}*, i64, i64, i32)*)({}* inttoptr (i64 140467582684928 to {}*), i64 ptrtoint ([1 x i8]* @shmem to i64), i64 1, i32 0)
%1 = bitcast {}* %0 to i8**
%2 = load i8*, i8** %1, align 8
store i8 1, i8* %2, align 1
ret i64 1
}
Not particularly clean, but this has been working fine for us. Even if we have multiple calls to shmem(), we just get multiple instances of the IR, and each instance keeps its own global when the modules are merged:
function kernel2()
    ptr1 = reinterpret(Ptr{Int8}, shmem())
    arr1 = unsafe_wrap(Array, ptr1, 1)
    ptr2 = reinterpret(Ptr{Int8}, shmem())
    arr2 = unsafe_wrap(Array, ptr2, 1)
    @inbounds begin
        arr1[] = 1
        arr2[]
    end
end
@shmem = internal global [1 x i8] zeroinitializer, align 32
@shmem.5 = internal global [1 x i8] zeroinitializer, align 32
define i8 @julia_kernel3_618() #0 {
top:
...
%6 = call nonnull {}* inttoptr (i64 140467922984160 to {}* ({}*, i64, i64, i32)*)({}* inttoptr (i64 140467582684928 to {}*), i64 ptrtoint ([1 x i8]* @shmem to i64), i64 1, i32 0)
%7 = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe5, i64 0, i64 2
store {}* %6, {}** %7, align 16
%8 = call nonnull {}* inttoptr (i64 140467922984160 to {}* ({}*, i64, i64, i32)*)({}* inttoptr (i64 140467582684928 to {}*), i64 ptrtoint ([1 x i8]* @shmem.5 to i64), i64 1, i32 0)
...
ret i8 %13
}
This changed in 1.9, however. I bisected it to #44440 (cc @pchintalapudi): we now only get a single shmem array, which obviously breaks a lot of things:
@shmem = internal global [1 x i8] zeroinitializer, align 32
define i8 @julia_kernel3_464() #0 {
top:
...
%6 = call nonnull {}* inttoptr (i64 140077718466992 to {}* ({}*, i64, i64, i32)*)({}* inttoptr (i64 140077392336176 to {}*), i64 ptrtoint ([1 x i8]* @shmem to i64), i64 1, i32 0)
%7 = getelementptr inbounds [3 x {}*], [3 x {}*]* %gcframe5, i64 0, i64 2
store {}* %6, {}** %7, align 16
%8 = call nonnull {}* inttoptr (i64 140077718466992 to {}* ({}*, i64, i64, i32)*)({}* inttoptr (i64 140077392336176 to {}*), i64 ptrtoint ([1 x i8]* @shmem to i64), i64 1, i32 0)
...
}
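For anyone who wants to check this locally, a small helper along the following lines can count how many module-level globals with a given prefix end up in the dumped IR. This is only a sketch: count_module_globals is a hypothetical name, and it assumes that code_llvm with dump_module=true prints each global definition at the start of its own line, as in the dumps above.

```julia
using InteractiveUtils  # provides code_llvm

# Hypothetical helper: count lines of the dumped LLVM module that define
# a global whose name starts with `prefix` (e.g. "@shmem = ...", "@shmem.5 = ...").
function count_module_globals(f, prefix::AbstractString; types=Tuple{})
    # Capture the textual IR instead of printing it to stdout.
    ir = sprint(io -> code_llvm(io, f, types; dump_module=true))
    return count(line -> startswith(line, "@" * prefix), split(ir, '\n'))
end
```

Going by the dumps above, count_module_globals(kernel2, "shmem") would be expected to report two globals before #44440 and one after it.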
Even if I inline the llvmcall into the function, I still only get a single shmem array:
function kernel3()
    ptr1 = reinterpret(Ptr{Int8}, Base.llvmcall(("""
        @shmem = internal global [1 x i8] zeroinitializer, align 32
        define i8* @entry() #0 {
            ret i8* getelementptr inbounds ([1 x i8], [1 x i8]* @shmem, i64 0, i64 0)
        }
        attributes #0 = { alwaysinline }""", "entry"),
        Core.LLVMPtr{Int8,0}, Tuple{}))
    arr1 = unsafe_wrap(Array, ptr1, 1)
    ptr2 = reinterpret(Ptr{Int8}, Base.llvmcall(("""
        @shmem = internal global [1 x i8] zeroinitializer, align 32
        define i8* @entry() #0 {
            ret i8* getelementptr inbounds ([1 x i8], [1 x i8]* @shmem, i64 0, i64 0)
        }
        attributes #0 = { alwaysinline }""", "entry"),
        Core.LLVMPtr{Int8,0}, Tuple{}))
    arr2 = unsafe_wrap(Array, ptr2, 1)
    @inbounds begin
        arr1[] = 1
        arr2[]
    end
end
Putting this on the milestone because it breaks our GPU back-ends. Happy to adapt those back-ends if another approach is better, although I don't want to go back to the old days where we made shmem() a macro so that we could unique the gvar name (which is just a bad UI).
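For reference, the macro-based workaround mentioned above could look roughly like the sketch below. It is not the code the back-ends used; unique_shmem_name, shmem_ir and @unique_shmem are hypothetical names invented here. The idea is just that each macro expansion bakes a different global name into the IR string, so two call sites cannot end up sharing one global. I have not verified this against the 1.9 behaviour, and the typed-pointer IR is copied from the snippets above, so it assumes an LLVM version that still accepts it.

```julia
# Hypothetical helper: derive an LLVM-safe identifier from gensym, which is
# unique within a session but contains '#' characters that LLVM names reject.
unique_shmem_name() = replace(string(gensym("shmem")), r"[^A-Za-z0-9_]" => "_")

# Hypothetical helper: the same module IR as above, parameterised on the global's name.
shmem_ir(name) = """
    @$name = internal global [1 x i8] zeroinitializer, align 32
    define i8* @entry() #0 {
        ret i8* getelementptr inbounds ([1 x i8], [1 x i8]* @$name, i64 0, i64 0)
    }
    attributes #0 = { alwaysinline }"""

# Each expansion splices a freshly named global into a literal IR string,
# so llvmcall still sees constant arguments.
macro unique_shmem()
    ir = shmem_ir(unique_shmem_name())
    return :(Base.llvmcall(($ir, "entry"), Core.LLVMPtr{Int8,0}, Tuple{}))
end
```

A kernel would then write ptr1 = reinterpret(Ptr{Int8}, @unique_shmem()) at each site that needs its own array, which is exactly the extra ceremony for users that makes this a worse UI than a plain shmem() call.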