Skip to content

Commit 06ad3f4

Browse files
authored
Make various core math functions easier for the compiler to reason about (#43907)
* ldexp: Break inference loop. We have an inference loop fma_emulated -> ldexp -> ^(::Float64, ::Int) -> fma -> fma_emulated. The arguments to `^` are constant, so constprop will figure it out, but it does require a bunch of extra processing. There is a simpler way to write this using elementary bit operations. Since resolving the inference loop requires constprop, this was breaking #43852. That is fixable, but I think we should also make this change to avoid having an unnecessary inference loop in our basic math functions, which will make future analyses easier. * Make fma_emulated easier for the compiler to reason about. Proving that the `exponent` call in `fma_emulated` cannot throw requires reasoning about the ranges of the floating point values in question, which the compiler is not capable of doing (and is unlikely to ever do automatically). Thus, in order for the compiler to know that `fma_emulated` (and by extension `fma`) is :nothrow in a post-#43852 world, create a separate version of the `exponent` function that assumes its precondition. We could use `@assume_effects` instead, but this version is currently slightly easier on the compiler. * pow: Make integer vs float branch obvious to constprop. The integer branch is nothrow, so if the caller does something like `^(x::Float64, 2.0)`, we'd like to discover that.
1 parent 9769024 commit 06ad3f4

File tree

2 files changed

+44
-10
lines changed

2 files changed

+44
-10
lines changed

base/floatfuncs.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -375,24 +375,31 @@ function fma_emulated(a::Float64, b::Float64,c::Float64)
375375
return aandbfinite ? c : abhi+c
376376
end
377377
(iszero(a) || iszero(b)) && return abhi+c
378-
bias = exponent(a) + exponent(b)
378+
# The checks above satisfy exponent's nothrow precondition
379+
bias = Math._exponent_finite_nonzero(a) + Math._exponent_finite_nonzero(b)
379380
c_denorm = ldexp(c, -bias)
380381
if isfinite(c_denorm)
381382
# rescale a and b to [1,2), equivalent to ldexp(a, -exponent(a))
382383
issubnormal(a) && (a *= 0x1p52)
383384
issubnormal(b) && (b *= 0x1p52)
384-
a = reinterpret(Float64, (reinterpret(UInt64, a) & 0x800fffffffffffff) | 0x3ff0000000000000)
385-
b = reinterpret(Float64, (reinterpret(UInt64, b) & 0x800fffffffffffff) | 0x3ff0000000000000)
385+
a = reinterpret(Float64, (reinterpret(UInt64, a) & ~Base.exponent_mask(Float64)) | Base.exponent_one(Float64))
386+
b = reinterpret(Float64, (reinterpret(UInt64, b) & ~Base.exponent_mask(Float64)) | Base.exponent_one(Float64))
386387
c = c_denorm
387388
abhi, ablo = twomul(a,b)
389+
# abhi <= 4 -> isfinite(r) (α)
388390
r = abhi+c
391+
# s ≈ 0 (β)
389392
s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
393+
# α ⩓ β -> isfinite(sumhi) (γ)
390394
sumhi = r+s
391395
# If result is subnormal, ldexp will cause double rounding because subnormals have fewer mantissa bits.
392396
# As such, we need to check whether round to even would lead to double rounding and manually round sumhi to avoid it.
393397
if issubnormal(ldexp(sumhi, bias))
394398
sumlo = r-sumhi+s
395-
bits_lost = -bias-exponent(sumhi)-1022
399+
# finite: See γ
400+
# non-zero: If sumhi == ±0., then ldexp(sumhi, bias) == ±0,
401+
# so we don't take this branch.
402+
bits_lost = -bias-Math._exponent_finite_nonzero(sumhi)-1022
396403
sumhiInt = reinterpret(UInt64, sumhi)
397404
if (bits_lost != 1) ⊻ (sumhiInt&1 == 1)
398405
sumhi = nextfloat(sumhi, cmp(sumlo,0))

base/math.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export sin, cos, sincos, tan, sinh, cosh, tanh, asin, acos, atan,
1818
import .Base: log, exp, sin, cos, tan, sinh, cosh, tanh, asin,
1919
acos, atan, asinh, acosh, atanh, sqrt, log2, log10,
2020
max, min, minmax, ^, exp2, muladd, rem,
21-
exp10, expm1, log1p
21+
exp10, expm1, log1p, @constprop
2222

2323
using .Base: sign_mask, exponent_mask, exponent_one,
2424
exponent_half, uinttype, significand_mask,
@@ -784,7 +784,7 @@ function ldexp(x::T, e::Integer) where T<:IEEEFloat
784784
xu = reinterpret(Unsigned, x)
785785
xs = xu & ~sign_mask(T)
786786
xs >= exponent_mask(T) && return x # NaN or Inf
787-
k = Int(xs >> significand_bits(T))
787+
k = (xs >> significand_bits(T)) % Int
788788
if k == 0 # x is subnormal
789789
xs == 0 && return x # +-0
790790
m = leading_zeros(xs) - exponent_bits(T)
@@ -817,7 +817,8 @@ function ldexp(x::T, e::Integer) where T<:IEEEFloat
817817
return flipsign(T(0.0), x)
818818
end
819819
k += significand_bits(T)
820-
z = T(2.0)^-significand_bits(T)
820+
# z = T(2.0) ^ (-significand_bits(T))
821+
z = reinterpret(T, rem(exponent_bias(T)-significand_bits(T), uinttype(T)) << significand_bits(T))
821822
xu = (xu & ~exponent_mask(T)) | (rem(k, uinttype(T)) << significand_bits(T))
822823
return z*reinterpret(T, xu)
823824
end
@@ -841,7 +842,7 @@ julia> exponent(16.0)
841842
"""
842843
function exponent(x::T) where T<:IEEEFloat
843844
@noinline throw1(x) = throw(DomainError(x, "Cannot be NaN or Inf."))
844-
@noinline throw2(x) = throw(DomainError(x, "Cannot be subnormal converted to 0."))
845+
@noinline throw2(x) = throw(DomainError(x, "Cannot be ±0.0."))
845846
xs = reinterpret(Unsigned, x) & ~sign_mask(T)
846847
xs >= exponent_mask(T) && throw1(x)
847848
k = Int(xs >> significand_bits(T))
@@ -853,6 +854,21 @@ function exponent(x::T) where T<:IEEEFloat
853854
return k - exponent_bias(T)
854855
end
855856

857+
# Like exponent, but assumes the nothrow precondition. For
858+
# internal use only. Could be written as
859+
# @assume_effects :nothrow exponent()
860+
# but currently this form is easier on the compiler.
861+
function _exponent_finite_nonzero(x::T) where T<:IEEEFloat
862+
# @precond :nothrow !isnan(x) && !isinf(x) && !iszero(x)
863+
xs = reinterpret(Unsigned, x) & ~sign_mask(T)
864+
k = rem(xs >> significand_bits(T), Int)
865+
if k == 0 # x is subnormal
866+
m = leading_zeros(xs) - exponent_bits(T)
867+
k = 1 - m
868+
end
869+
return k - exponent_bias(T)
870+
end
871+
856872
"""
857873
significand(x)
858874
@@ -977,11 +993,17 @@ function modf(x::T) where T<:IEEEFloat
977993
return (rx, ix)
978994
end
979995

980-
function ^(x::Float64, y::Float64)
996+
# @constprop aggressive to help the compiler see the switch between the integer and float
997+
# variants for callers with constant `y`
998+
@constprop :aggressive function ^(x::Float64, y::Float64)
981999
yint = unsafe_trunc(Int, y) # Note, this is actually safe since julia freezes the result
9821000
y == yint && return x^yint
9831001
x<0 && y > -4e18 && throw_exp_domainerror(x) # |y| is small enough that y isn't an integer
9841002
x == 1 && return 1.0
1003+
return pow_body(x, y)
1004+
end
1005+
1006+
@inline function pow_body(x::Float64, y::Float64)
9851007
!isfinite(x) && return x*(y>0 || isnan(x))
9861008
x==0 && return abs(y)*Inf*(!(y>0))
9871009
logxhi,logxlo = Base.Math._log_ext(x)
@@ -990,10 +1012,15 @@ function ^(x::Float64, y::Float64)
9901012
hi = xyhi+xylo
9911013
return Base.Math.exp_impl(hi, xylo-(hi-xyhi), Val(:ℯ))
9921014
end
993-
function ^(x::T, y::T) where T <: Union{Float16, Float32}
1015+
1016+
@constprop :aggressive function ^(x::T, y::T) where T <: Union{Float16, Float32}
9941017
yint = unsafe_trunc(Int64, y) # Note, this is actually safe since julia freezes the result
9951018
y == yint && return x^yint
9961019
x < 0 && y > -4e18 && throw_exp_domainerror(x) # |y| is small enough that y isn't an integer
1020+
return pow_body(x, y)
1021+
end
1022+
1023+
@inline function pow_body(x::T, y::T) where T <: Union{Float16, Float32}
9971024
x == 1 && return one(T)
9981025
!isfinite(x) && return x*(y>0 || isnan(x))
9991026
x==0 && return abs(y)*T(Inf)*(!(y>0))

0 commit comments

Comments
 (0)