Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.4.0"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
5 changes: 3 additions & 2 deletions src/CUDA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ include("device/llvm.jl")
include("device/runtime.jl")
include("device/texture.jl")

# array essentials
include("pool.jl")
include("array.jl")

# compiler libraries
include("../lib/cupti/CUPTI.jl")
Expand All @@ -71,8 +73,7 @@ include("compiler/execution.jl")
include("compiler/exceptions.jl")
include("compiler/reflection.jl")

# array abstraction
include("array.jl")
# array implementation
include("gpuarrays.jl")
include("utilities.jl")
include("texture.jl")
Expand Down
7 changes: 0 additions & 7 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,6 @@ function Base.unsafe_convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::DenseCuArr
CuDeviceArray{T,N,AS.Global}(size(a), reinterpret(LLVMPtr{T,AS.Global}, pointer(a)))
end

Adapt.adapt_storage(::Adaptor, xs::CuArray{T,N}) where {T,N} =
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)

# we materialize ReshapedArray/ReinterpretArray/SubArray/... directly as a device array
Adapt.adapt_structure(::Adaptor, xs::DenseCuArray{T,N}) where {T,N} =
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)


## interop with CPU arrays

Expand Down
30 changes: 27 additions & 3 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ end
Base.getindex(r::CuRefValue) = r.x
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = CuRefValue(adapt(to, r[]))

# When converting kernel arguments, a CuArray is handed to the device as a
# CuDeviceArray backed by the same global-memory buffer.
function Adapt.adapt_storage(::Adaptor, arr::CuArray{T,N}) where {T,N}
    return Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, arr)
end

# Dense wrappers (ReshapedArray/ReinterpretArray/SubArray/...) are materialized
# directly as a device array as well, instead of adapting their parent.
function Adapt.adapt_structure(::Adaptor, arr::DenseCuArray{T,N}) where {T,N}
    return Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, arr)
end

"""
cudaconvert(x)

Expand Down Expand Up @@ -193,10 +200,16 @@ end

## host-side kernels

struct HostKernel{F,TT} <: AbstractKernel{F,TT}
mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT}
    ctx::CuContext
    mod::CuModule
    fun::CuFunction

    # device-side RNG state buffer: `nothing` when the kernel does not use the
    # random-number intrinsics, `missing` when it does but no buffer has been
    # allocated yet, and the actual state vector once `init_random_state!` ran.
    random_state::Union{Nothing,Missing,CuVector{UInt32}}

    HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, random_state) where {F,TT} =
        new{F,TT}(ctx, mod, fun, random_state)
end

@doc (@doc AbstractKernel) HostKernel
Expand Down Expand Up @@ -345,10 +358,21 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled)
filter!(!isequal("exception_flag"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun)
random_state = nothing
if "global_random_state" in compiled.external_gvars
random_state = missing
filter!(!isequal("global_random_state"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, random_state)
end

(kernel::HostKernel)(args...; kwargs...) = call(kernel, map(cudaconvert, args)...; kwargs...)
# Launch a compiled kernel: convert arguments to their device-side
# counterparts and dispatch to `call`. Kernels that use device-side RNG get
# their state buffer sized to cover every launched thread first.
function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs...)
    kernel.random_state === nothing ||
        init_random_state!(kernel, prod(threads) * prod(blocks))
    return call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
end


## device-side kernels
Expand Down
1 change: 1 addition & 0 deletions src/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ include("intrinsics/memory_dynamic.jl")
include("intrinsics/atomics.jl")
include("intrinsics/misc.jl")
include("intrinsics/wmma.jl")
include("intrinsics/random.jl")

# functionality from libdevice
#
Expand Down
71 changes: 71 additions & 0 deletions src/device/intrinsics/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
## random number generation

using Random
import RandomNumbers


# helpers

# 6-dimensional identity of the current thread across the whole launch:
# thread x/y/z within the block, followed by block x/y/z within the grid.
# Used to index each thread's private slot in the global random state array.
global_index() = (threadIdx().x, threadIdx().y, threadIdx().z,
                  blockIdx().x, blockIdx().y, blockIdx().z)


# global state

# Device-side RNG handle: wraps the global random state array, which holds one
# UInt32 word of xorshift state per thread, indexed by `global_index()`.
struct ThreadLocalXorshift32 <: RandomNumbers.AbstractRNG{UInt32}
    vals::CuDeviceArray{UInt32, 6, AS.Generic}
end

# Ensure `kernel` owns a random-state buffer with at least `len` words (one per
# launched thread), allocating or growing it as needed, and point the module's
# `global_random_state` global at that buffer so device code can find it.
function init_random_state!(kernel, len)
    state = kernel.random_state
    if state === missing || length(state) < len
        state = CuVector{UInt32}(undef, len)
        kernel.random_state = state
    end

    state_gv = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
    state_gv[] = reinterpret(Ptr{Cvoid}, pointer(state))
end

# Reconstruct a device-side view of the random state array. The host stores a
# raw pointer to the state buffer in the module-level `global_random_state`
# LLVM global (see `init_random_state!`); here we read that pointer back via
# `llvmcall` and wrap it in a CuDeviceArray shaped like the launch
# configuration, so `global_index()` addresses each thread's private slot.
# NOTE(review): the global is declared `weak externally_initialized` so that
# it survives optimization and can be patched after compilation.
@eval @inline function global_random_state()
    ptr = reinterpret(LLVMPtr{UInt32, AS.Generic}, Base.llvmcall(
        $("""@global_random_state = weak externally_initialized global i$(WORD_SIZE) 0
             define i$(WORD_SIZE) @entry() #0 {
                 %ptr = load i$(WORD_SIZE), i$(WORD_SIZE)* @global_random_state, align 8
                 ret i$(WORD_SIZE) %ptr
             }
             attributes #0 = { alwaysinline }
             """, "entry"), Ptr{Cvoid}, Tuple{}))
    dims = (blockDim().x, blockDim().y, blockDim().z, gridDim().x, gridDim().y, gridDim().z)
    CuDeviceArray(dims, ptr)
end

# On the device, the default RNG is the thread-local xorshift generator backed
# by the global state array.
@device_override Random.default_rng() = ThreadLocalXorshift32(global_random_state())

# When no explicit seed is given, derive one from the device clock counter.
@device_override Random.make_seed() = clock(UInt32)

# Seed this thread's generator by storing the truncated seed in its private
# slot of the state array. Only affects the calling thread.
function Random.seed!(rng::ThreadLocalXorshift32, seed::Integer)
    rng.vals[global_index()...] = seed % UInt32
    return
end


# generators

# One step of Marsaglia's xorshift32 generator (shift triple 13/17/5):
# scrambles `state` into the next value of the pseudo-random sequence.
# Note that a zero state maps to zero, so callers must avoid all-zero state.
function xorshift(state::UInt32)::UInt32
    state ⊻= state << 13
    state ⊻= state >> 17
    state ⊻= state << 5
    return state
end

# Draw the next UInt32 for the calling thread and advance its state in-place.
function Random.rand(rng::ThreadLocalXorshift32, ::Type{UInt32})
    idx = global_index()
    # Mix the thread's linear index into its state so that unseeded threads
    # (which all start from state 0) still produce distinct values.
    bias = UInt32(LinearIndices(rng.vals)[idx...])
    next = xorshift(rng.vals[idx...] + bias)
    rng.vals[idx...] = next
    return next
end
54 changes: 54 additions & 0 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1180,3 +1180,57 @@ end
end

end



############################################################################################

@testset "random numbers" begin

# one GPU thread per array element; each thread draws its own random values
n = 256

@testset "basic" begin
    # every thread draws two consecutive values; they should differ
    function kernel(A::CuDeviceArray{T}, B::CuDeviceArray{T}) where {T}
        tid = threadIdx().x
        A[tid] = rand(T)
        B[tid] = rand(T)
        return nothing
    end

    @testset for T in (Int32, UInt32, Int64, UInt64, Int128, UInt128,
                       Float32, Float64)
        a = CUDA.zeros(T, n)
        b = CUDA.zeros(T, n)

        @cuda threads=n kernel(a, b)

        # the first and second draw of each thread must not coincide
        @test all(Array(a) .!= Array(b))

        # for Float64 output, cross-thread collisions are not expected, so
        # every drawn value should be unique across the whole launch
        if T == Float64
            @test allunique(Array(a))
            @test allunique(Array(b))
        end
    end
end

@testset "custom seed" begin
    # seeding with a fixed value must make the generated stream reproducible
    function kernel(A::CuDeviceArray{T}) where {T}
        tid = threadIdx().x
        Random.seed!(1234)
        A[tid] = rand(T)
        return nothing
    end

    @testset for T in (Int32, UInt32, Int64, UInt64, Int128, UInt128,
                       Float32, Float64)
        a = CUDA.zeros(T, n)
        b = CUDA.zeros(T, n)

        @cuda threads=n kernel(a)
        @cuda threads=n kernel(b)

        # two identically-seeded launches produce identical results
        @test Array(a) == Array(b)
    end
end

end