diff --git a/base/floatfuncs.jl b/base/floatfuncs.jl
index 4c2f03d304f81..bae27d642e7c7 100644
--- a/base/floatfuncs.jl
+++ b/base/floatfuncs.jl
@@ -342,88 +342,30 @@ significantly more expensive than `x*y+z`. `fma` is used to improve accuracy in
 algorithms. See [`muladd`](@ref).
 """
 function fma end
-function fma_emulated(a::Float32, b::Float32, c::Float32)::Float32
-    ab = Float64(a) * b
-    res = ab+c
-    reinterpret(UInt64, res)&0x1fff_ffff!=0x1000_0000 && return res
-    # yes error compensation is necessary. It sucks
-    reslo = abs(c)>abs(ab) ? ab-(res - c) : c-(res - ab)
-    res = iszero(reslo) ? res : (signbit(reslo) ? prevfloat(res) : nextfloat(res))
-    return res
-end
-
-""" Splits a Float64 into a hi bit and a low bit where the high bit has 27 trailing 0s and the low bit has 26 trailing 0s"""
-@inline function splitbits(x::Float64)
-    hi = reinterpret(Float64, reinterpret(UInt64, x) & 0xffff_ffff_f800_0000)
-    return hi, x-hi
-end
-
-@inline function twomul(a::Float64, b::Float64)
-    ahi, alo = splitbits(a)
-    bhi, blo = splitbits(b)
-    abhi = a*b
-    blohi, blolo = splitbits(blo)
-    ablo = alo*blohi - (((abhi - ahi*bhi) - alo*bhi) - ahi*blo) + blolo*alo
-    return abhi, ablo
-end
-function fma_emulated(a::Float64, b::Float64,c::Float64)
-    abhi, ablo = twomul(a,b)
-    if !isfinite(abhi+c) || isless(abs(abhi), nextfloat(0x1p-969)) || issubnormal(a) || issubnormal(b)
-        (isfinite(a) && isfinite(b) && isfinite(c)) || return abhi+c
-        (iszero(a) || iszero(b)) && return abhi+c
-        bias = exponent(a) + exponent(b)
-        c_denorm = ldexp(c, -bias)
-        if isfinite(c_denorm)
-            # rescale a and b to [1,2), equivalent to ldexp(a, -exponent(a))
-            issubnormal(a) && (a *= 0x1p52)
-            issubnormal(b) && (b *= 0x1p52)
-            a = reinterpret(Float64, (reinterpret(UInt64, a) & 0x800fffffffffffff) | 0x3ff0000000000000)
-            b = reinterpret(Float64, (reinterpret(UInt64, b) & 0x800fffffffffffff) | 0x3ff0000000000000)
-            c = c_denorm
-            abhi, ablo = twomul(a,b)
-            r = abhi+c
-            s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
-            sumhi = r+s
-            # If result is subnormal, ldexp will cause double rounding because subnormals have fewer mantisa bits.
-            # As such, we need to check whether round to even would lead to double rounding and manually round sumhi to avoid it.
-            if issubnormal(ldexp(sumhi, bias))
-                sumlo = r-sumhi+s
-                bits_lost = -bias-exponent(sumhi)-1022
-                sumhiInt = reinterpret(UInt64, sumhi)
-                if (bits_lost != 1) ⊻ (sumhiInt&1 == 1)
-                    sumhi = nextfloat(sumhi, cmp(sumlo,0))
-                end
-            end
-            return ldexp(sumhi, bias)
-        end
-        isinf(abhi) && signbit(c) == signbit(a*b) && return abhi
-        # fall through
-    end
-    r = abhi+c
-    s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
-    return r+s
-end
+fma_libm(x::Float32, y::Float32, z::Float32) =
+    ccall(("fmaf", libm_name), Float32, (Float32,Float32,Float32), x, y, z)
+fma_libm(x::Float64, y::Float64, z::Float64) =
+    ccall(("fma", libm_name), Float64, (Float64,Float64,Float64), x, y, z)
 fma_llvm(x::Float32, y::Float32, z::Float32) = fma_float(x, y, z)
 fma_llvm(x::Float64, y::Float64, z::Float64) = fma_float(x, y, z)
 # Disable LLVM's fma if it is incorrect, e.g. because LLVM falls back
-# onto a broken system libm; if so, use a software emulated fma
+# onto a broken system libm; if so, use openlibm's fma instead
 # 1.0000305f0 = 1 + 1/2^15
 # 1.0000000009313226 = 1 + 1/2^30
 # If fma_llvm() clobbers the rounding mode, the result of 0.1 + 0.2 will be 0.3
 # instead of the properly-rounded 0.30000000000000004; check after calling fma
-# TODO actually detect fma in hardware and switch on that.
 if (Sys.ARCH !== :i686 && fma_llvm(1.0000305f0, 1.0000305f0, -1.0f0) == 6.103609f-5 &&
     (fma_llvm(1.0000000009313226, 1.0000000009313226, -1.0) == 1.8626451500983188e-9) &&
     0.1 + 0.2 == 0.30000000000000004)
     fma(x::Float32, y::Float32, z::Float32) = fma_llvm(x,y,z)
     fma(x::Float64, y::Float64, z::Float64) = fma_llvm(x,y,z)
 else
-    fma(x::Float32, y::Float32, z::Float32) = fma_emulated(x,y,z)
-    fma(x::Float64, y::Float64, z::Float64) = fma_emulated(x,y,z)
+    fma(x::Float32, y::Float32, z::Float32) = fma_libm(x,y,z)
+    fma(x::Float64, y::Float64, z::Float64) = fma_libm(x,y,z)
 end
 function fma(a::Float16, b::Float16, c::Float16)
-    Float16(muladd(Float32(a), Float32(b), Float32(c))) #don't use fma if the hardware doesn't have it.
+    Float16(fma(Float32(a), Float32(b), Float32(c)))
 end
 
 # This is necessary at least on 32-bit Intel Linux, since fma_llvm may
diff --git a/test/math.jl b/test/math.jl
index da027fa8919f4..cdcc1c5c6a47d 100644
--- a/test/math.jl
+++ b/test/math.jl
@@ -1286,27 +1286,3 @@ end
         @test_throws MethodError f(x)
     end
 end
-
-@testset "fma" begin
-    for func in (fma, Base.fma_emulated)
-        @test func(nextfloat(1.),nextfloat(1.),-1.0) === 4.440892098500626e-16
-        @test func(nextfloat(1f0),nextfloat(1f0),-1f0) === 2.3841858f-7
-        @testset "$T" for T in (Float32, Float64)
-            @test func(floatmax(T), T(2), -floatmax(T)) === floatmax(T)
-            @test func(floatmax(T), T(1), eps(floatmax((T)))) === T(Inf)
-            @test func(T(Inf), T(Inf), T(Inf)) === T(Inf)
-            @test isnan_type(T, func(T(Inf), T(1), -T(Inf)))
-            @test isnan_type(T, func(T(Inf), T(0), -T(0)))
-            @test func(-zero(T), zero(T), -zero(T)) === -zero(T)
-            for _ in 1:2^18
-                a, b, c = reinterpret.(T, rand(Base.uinttype(T), 3))
-                @test isequal(func(a, b, c), fma(a, b, c)) || (a,b,c)
-            end
-        end
-        @test func(floatmax(Float64), nextfloat(1.0), -floatmax(Float64)) === 3.991680619069439e292
-        @test func(floatmax(Float32), nextfloat(1f0), -floatmax(Float32)) === 4.0564817f31
-        @test func(1.6341681540852291e308, -2., floatmax(Float64)) == -1.4706431733081426e308 # case where inv(a)*c*a == Inf
-        @test func(-2., 1.6341681540852291e308, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(b)*c*b == Inf
-        @test func(-1.9369631f13, 2.1513551f-7, -1.7354427f-24) == -4.1670958f6
-    end
-end
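For reference (not part of the patch), a minimal Julia sketch of what the `6.103609f-5` probe in the runtime check above is verifying: with both operands equal to `1 + 1/2^15`, a plain `a*b + c` rounds the product to Float32 first and drops the `2^-30` term, while a correctly rounded `fma` keeps it.

# Minimal sketch (not part of the diff), using the probe value from the check above.
a = 1.0000305f0      # 1 + 1/2^15, per the comment in floatfuncs.jl
b = 1.0000305f0
c = -1.0f0

naive = a * b + c    # product rounded to Float32 first, then added: 6.1035156f-5
fused = fma(a, b, c) # single rounding: 6.103609f-5 when fma is correctly rounded

@show naive fused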