Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,16 @@ rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52Raw{UInt64}}) = ran
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float16}}) =
Float16(rand(r, UInt16) >>> 5) * Float16(0x1.0p-11)

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float32}}) =
Float32(rand(r, UInt32) >>> 8) * Float32(0x1.0p-24)

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01_64}) =
Float64(rand(r, UInt64) >>> 11) * 0x1.0p-53
for FT in (Float16, Float32, Float64)
UT = Base.uinttype(FT)
# Helper function: scale an unsigned integer to a floating point number of the same size
# in the interval [0, 1). This is equivalent to, but more easily extensible than
# Float16(i >>> 5) * Float16(0x1.0p-11)
# Float32(i >>> 8) * Float32(0x1.0p-24)
# Float32(i >>> 11) * Float64(0x1.0p-53)
@eval @inline _uint2float(i::$(UT), ::Type{$(FT)}) =
$(FT)(i >>> $(8 * sizeof(FT) - precision(FT))) * $(FT(2) ^ -precision(FT))

@eval rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{$(FT)}}) =
_uint2float(rand(r, $(UT)), $(FT))
end
21 changes: 13 additions & 8 deletions stdlib/Random/src/XoshiroSimd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module XoshiroSimd
# Getting the xoroshiro RNG to reliably vectorize is somewhat of a hassle without Simd.jl.
import ..Random: rand!
using ..Random: TaskLocalRNG, rand, Xoshiro, CloseOpen01, UnsafeView, SamplerType, SamplerTrivial, getstate, setstate!
using ..Random: TaskLocalRNG, rand, Xoshiro, CloseOpen01, UnsafeView, SamplerType, SamplerTrivial, getstate, setstate!, _uint2float
using Base: BitInteger_types
using Base.Libc: memcpy
using Core.Intrinsics: llvmcall
Expand All @@ -30,7 +30,12 @@ simdThreshold(::Type{Bool}) = 640
Tuple{UInt64, Int64},
x, y)

@inline _bits2float(x::UInt64, ::Type{Float64}) = reinterpret(UInt64, Float64(x >>> 11) * 0x1.0p-53)
# `_bits2float(x::UInt64, T)` takes `x::UInt64` as input, it splits it in `N` parts where
# `N = sizeof(UInt64) / sizeof(T)` (`N = 1` for `Float64`, `N = 2` for `Float32, etc...), it
# truncates each part to the unsigned type of the same size as `T`, scales all of these
# numbers to a value of type `T` in the range [0,1) with `_uint2float`, and then
# recomposes another `UInt64` using all these parts.
@inline _bits2float(x::UInt64, ::Type{Float64}) = reinterpret(UInt64, _uint2float(x, Float64))
@inline function _bits2float(x::UInt64, ::Type{Float32})
#=
# this implementation uses more high bits, but is harder to vectorize
Expand All @@ -40,19 +45,19 @@ simdThreshold(::Type{Bool}) = 640
=#
ui = (x>>>32) % UInt32
li = x % UInt32
u = Float32(ui >>> 8) * Float32(0x1.0p-24)
l = Float32(li >>> 8) * Float32(0x1.0p-24)
u = _uint2float(ui, Float32)
l = _uint2float(ui, Float32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ui -> li...?

(UInt64(reinterpret(UInt32, u)) << 32) | UInt64(reinterpret(UInt32, l))
end
@inline function _bits2float(x::UInt64, ::Type{Float16})
i1 = (x>>>48) % UInt16
i2 = (x>>>32) % UInt16
i3 = (x>>>16) % UInt16
i4 = x % UInt16
f1 = Float16(i1 >>> 5) * Float16(0x1.0p-11)
f2 = Float16(i2 >>> 5) * Float16(0x1.0p-11)
f3 = Float16(i3 >>> 5) * Float16(0x1.0p-11)
f4 = Float16(i4 >>> 5) * Float16(0x1.0p-11)
f1 = _uint2float(i1, Float16)
f2 = _uint2float(i2, Float16)
f3 = _uint2float(i3, Float16)
f4 = _uint2float(i4, Float16)
return (UInt64(reinterpret(UInt16, f1)) << 48) | (UInt64(reinterpret(UInt16, f2)) << 32) | (UInt64(reinterpret(UInt16, f3)) << 16) | UInt64(reinterpret(UInt16, f4))
end

Expand Down