Skip to content

Commit b8f7ef1

Browse files
committed
only use accurate powf function
The only optimization the powi intrinsic offers over calling powf is that it is allowed to be inaccurate. We don't need that. When powi would be equally accurate (e.g. tiny constant powers), LLVM already recognizes and optimizes any call to a function named `powf`, producing the same speedup. Fixes #19872.
1 parent 5a1f971 commit b8f7ef1

File tree

9 files changed

+29
-94
lines changed

9 files changed

+29
-94
lines changed

base/fastmath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module FastMath
2323

2424
export @fastmath
2525

26-
import Core.Intrinsics: powi_llvm, sqrt_llvm_fast, neg_float_fast,
26+
import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
2727
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, rem_float_fast,
2828
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
2929

@@ -243,8 +243,8 @@ end
243243

244244
# builtins
245245

246-
pow_fast(x::FloatTypes, y::Integer) = pow_fast(x, Int32(y))
247-
pow_fast(x::FloatTypes, y::Int32) = Base.powi_llvm(x, y)
246+
# Fast-math power of a `Float32` base by an integer exponent, lowered directly to
# the `llvm.powi.f32` intrinsic.
function pow_fast(x::Float32, y::Integer)
    # ccall converts `y` to the declared `Int32` argument type before the call.
    return ccall("llvm.powi.f32", llvmcall, Float32, (Float32, Int32), x, y)
end
247+
# Fast-math power of a `Float64` base by an integer exponent, lowered directly to
# the `llvm.powi.f64` intrinsic.
function pow_fast(x::Float64, y::Integer)
    # ccall converts `y` to the declared `Int32` argument type before the call.
    return ccall("llvm.powi.f64", llvmcall, Float64, (Float64, Int32), x, y)
end
248248

249249
# TODO: Change sqrt_llvm intrinsic to avoid nan checking; add nan
250250
# checking to sqrt in math.jl; remove sqrt_llvm_fast intrinsic

base/inference.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,6 @@ add_tfunc(floor_llvm, 1, 1, math_tfunc)
468468
add_tfunc(trunc_llvm, 1, 1, math_tfunc)
469469
add_tfunc(rint_llvm, 1, 1, math_tfunc)
470470
add_tfunc(sqrt_llvm, 1, 1, math_tfunc)
471-
add_tfunc(powi_llvm, 2, 2, math_tfunc)
472471
add_tfunc(sqrt_llvm_fast, 1, 1, math_tfunc)
473472
## same-type comparisons ##
474473
# Ignores both arguments and always answers `Bool`.
# NOTE(review): presumably registered as the return-type function for the
# comparison intrinsics listed below -- confirm against the add_tfunc calls.
function cmp_tfunc(x::ANY, y::ANY)
    return Bool
end

base/math.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using Base: sign_mask, exponent_mask, exponent_one, exponent_bias,
2424
exponent_half, exponent_max, exponent_raw_max, fpinttype,
2525
significand_mask, significand_bits, exponent_bits
2626

27-
using Core.Intrinsics: sqrt_llvm, powi_llvm
27+
using Core.Intrinsics: sqrt_llvm
2828

2929
const IEEEFloat = Union{Float16,Float32,Float64}
3030
# non-type specific math functions
@@ -286,6 +286,8 @@ exp10(x::Float32) = 10.0f0^x
286286
# Integer arguments have no dedicated implementation: promote to floating point
# with `float` and defer to the corresponding floating-point method.
function exp10(x::Integer)
    xf = float(x)
    return exp10(xf)
end
287287

288288
# utility for converting NaN return to DomainError
289+
# the branch in nan_dom_err prevents its callers from inlining, so be sure to force it
290+
# until the heuristics can be improved
289291
# Pass the computed result `f` through unchanged unless it is NaN while the
# input `x` is not, in which case the NaN signals a domain error and a
# `DomainError` is thrown instead.
@inline function nan_dom_err(f, x)
    # Non-short-circuiting `&`: both NaN tests are always evaluated.
    if isnan(f) & !isnan(x)
        throw(DomainError())
    end
    return f
end
290292

291293
# functions that return NaN on non-NaN argument for domain error
@@ -403,9 +405,9 @@ log1p(x)
403405
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
404406
:lgamma, :log1p)
405407
@eval begin
406-
($f)(x::Float64) = nan_dom_err(ccall(($(string(f)),libm), Float64, (Float64,), x), x)
407-
($f)(x::Float32) = nan_dom_err(ccall(($(string(f,"f")),libm), Float32, (Float32,), x), x)
408-
($f)(x::Real) = ($f)(float(x))
408+
@inline ($f)(x::Float64) = nan_dom_err(ccall(($(string(f)), libm), Float64, (Float64,), x), x)
409+
@inline ($f)(x::Float32) = nan_dom_err(ccall(($(string(f, "f")), libm), Float32, (Float32,), x), x)
410+
@inline ($f)(x::Real) = ($f)(float(x))
409411
end
410412
end
411413

@@ -683,14 +685,11 @@ function modf(x::Float64)
683685
f, _modf_temp[]
684686
end
685687

686-
^(x::Float64, y::Float64) = nan_dom_err(ccall((:pow,libm), Float64, (Float64,Float64), x, y), x+y)
687-
^(x::Float32, y::Float32) = nan_dom_err(ccall((:powf,libm), Float32, (Float32,Float32), x, y), x+y)
688-
689-
^(x::Float64, y::Integer) = x^Int32(y)
690-
^(x::Float64, y::Int32) = powi_llvm(x, y)
691-
^(x::Float32, y::Integer) = x^Int32(y)
692-
^(x::Float32, y::Int32) = powi_llvm(x, y)
693-
^(x::Float16, y::Integer) = Float16(Float32(x)^y)
688+
# Float64 power via the `llvm.pow.f64` intrinsic. A NaN result with a non-NaN
# `x + y` is turned into a `DomainError` by `nan_dom_err`.
@inline function ^(x::Float64, y::Float64)
    raw = ccall("llvm.pow.f64", llvmcall, Float64, (Float64, Float64), x, y)
    return nan_dom_err(raw, x + y)
end
689+
# Float32 power via the `llvm.pow.f32` intrinsic. A NaN result with a non-NaN
# `x + y` is turned into a `DomainError` by `nan_dom_err`.
@inline function ^(x::Float32, y::Float32)
    raw = ccall("llvm.pow.f32", llvmcall, Float32, (Float32, Float32), x, y)
    return nan_dom_err(raw, x + y)
end
690+
# Float64 base with an integer exponent: convert the exponent to `Float64` and
# reuse the floating-point power method. `Float64(y)` is exact only while
# |y| <= 2^53; larger exponents are rounded by the conversion.
@inline function ^(x::Float64, y::Integer)
    yf = Float64(y)
    return x ^ yf
end
691+
# Raise a `Float32` base to an arbitrary `Integer` exponent.
#
# The exponent is widened to `Float64` rather than `Float32`: `Float32` represents
# integers exactly only up to 2^24, so `Float32(y)` would round larger odd
# exponents to an even value and lose the sign of the result for negative bases
# (e.g. `(-1.0f0)^(2^24 + 1)` would come out as `1.0f0`). `Float64(y)` is exact
# for |y| <= 2^53. The power is evaluated in `Float64` and rounded once back to
# `Float32`; an integer-valued exponent never produces a NaN from non-NaN input,
# so no additional domain errors are introduced.
@inline ^(x::Float32, y::Integer) = Float32(Float64(x) ^ Float64(y))
692+
# Raise a `Float16` base to an arbitrary `Integer` exponent.
#
# The computation is carried out in `Float64` instead of `Float32`: `Float32(y)`
# is inexact for |y| > 2^24, which would round large odd exponents to an even
# value and drop the sign of the result for negative bases (e.g.
# `Float16(-1)^(2^24 + 1)`). `Float64(y)` is exact for |y| <= 2^53. The result is
# rounded back to `Float16` at the end.
@inline ^(x::Float16, y::Integer) = Float16(Float64(x) ^ Float64(y))
694693
# `Float16` base with a `Val{p}` exponent: widen the base to `Float32`, apply the
# `Val{p}` power there, and narrow the result back to `Float16`.
function ^{p}(x::Float16, ::Type{Val{p}})
    wide = Float32(x)
    return Float16(wide^Val{p})
end
695694

696695
function angle_restrict_symm(theta)

src/codegen.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,6 @@ static Function *jldlsym_func;
397397
static Function *jlnewbits_func;
398398
static Function *jltypeassert_func;
399399
static Function *jldepwarnpi_func;
400-
#if JL_LLVM_VERSION < 30600
401-
static Function *jlpow_func;
402-
static Function *jlpowf_func;
403-
#endif
404400
//static Function *jlgetnthfield_func;
405401
static Function *jlgetnthfieldchecked_func;
406402
//static Function *jlsetnthfield_func;
@@ -6857,25 +6853,6 @@ static void init_julia_llvm_env(Module *m)
68576853
"jl_gc_diff_total_bytes", m);
68586854
add_named_global(diff_gc_total_bytes_func, *jl_gc_diff_total_bytes);
68596855

6860-
#if JL_LLVM_VERSION < 30600
6861-
Type *powf_type[2] = { T_float32, T_float32 };
6862-
jlpowf_func = Function::Create(FunctionType::get(T_float32, powf_type, false),
6863-
Function::ExternalLinkage,
6864-
"powf", m);
6865-
add_named_global(jlpowf_func, &powf, false);
6866-
6867-
Type *pow_type[2] = { T_float64, T_float64 };
6868-
jlpow_func = Function::Create(FunctionType::get(T_float64, pow_type, false),
6869-
Function::ExternalLinkage,
6870-
"pow", m);
6871-
add_named_global(jlpow_func,
6872-
#ifdef _COMPILER_MICROSOFT_
6873-
static_cast<double (*)(double, double)>(&pow),
6874-
#else
6875-
&pow,
6876-
#endif
6877-
false);
6878-
#endif
68796856
std::vector<Type*> array_owner_args(0);
68806857
array_owner_args.push_back(T_pjlvalue);
68816858
jlarray_data_owner_func =

src/intrinsics.cpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ static void jl_init_intrinsic_functions_codegen(Module *m)
7171
float_func[rint_llvm] = true;
7272
float_func[sqrt_llvm] = true;
7373
float_func[sqrt_llvm_fast] = true;
74-
float_func[powi_llvm] = true;
7574
}
7675

7776
extern "C"
@@ -851,33 +850,6 @@ static jl_cgval_t emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
851850
return mark_julia_type(ans, false, x.typ, ctx);
852851
}
853852

854-
case powi_llvm: {
855-
const jl_cgval_t &x = argv[0];
856-
const jl_cgval_t &y = argv[1];
857-
if (!jl_is_primitivetype(x.typ) || !jl_is_primitivetype(y.typ) || jl_datatype_size(y.typ) != 4)
858-
return emit_runtime_call(f, argv, nargs, ctx);
859-
Type *xt = FLOATT(bitstype_to_llvm(x.typ));
860-
Type *yt = T_int32;
861-
if (!xt)
862-
return emit_runtime_call(f, argv, nargs, ctx);
863-
864-
Value *xv = emit_unbox(xt, x, x.typ);
865-
Value *yv = emit_unbox(yt, y, y.typ);
866-
#if JL_LLVM_VERSION >= 30600
867-
Value *powi = Intrinsic::getDeclaration(jl_Module, Intrinsic::powi, makeArrayRef(xt));
868-
#if JL_LLVM_VERSION >= 30700
869-
Value *ans = builder.CreateCall(powi, {xv, yv});
870-
#else
871-
Value *ans = builder.CreateCall2(powi, xv, yv);
872-
#endif
873-
#else
874-
// issue #6506
875-
Value *ans = builder.CreateCall2(prepare_call(xt == T_float64 ? jlpow_func : jlpowf_func),
876-
xv, builder.CreateSIToFP(yv, xt));
877-
#endif
878-
return mark_julia_type(ans, false, x.typ, ctx);
879-
}
880-
881853
default: {
882854
assert(nargs >= 1 && "invalid nargs for intrinsic call");
883855
const jl_cgval_t &xinfo = argv[0];

src/intrinsics.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@
9191
ADD_I(trunc_llvm, 1) \
9292
ADD_I(rint_llvm, 1) \
9393
ADD_I(sqrt_llvm, 1) \
94-
ADD_I(powi_llvm, 2) \
9594
ALIAS(sqrt_llvm_fast, sqrt_llvm) \
9695
/* pointer access */ \
9796
ADD_I(pointerref, 3) \

src/julia_internal.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,6 @@ JL_DLLEXPORT jl_value_t *jl_floor_llvm(jl_value_t *a);
680680
JL_DLLEXPORT jl_value_t *jl_trunc_llvm(jl_value_t *a);
681681
JL_DLLEXPORT jl_value_t *jl_rint_llvm(jl_value_t *a);
682682
JL_DLLEXPORT jl_value_t *jl_sqrt_llvm(jl_value_t *a);
683-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b);
684683
JL_DLLEXPORT jl_value_t *jl_abs_float(jl_value_t *a);
685684
JL_DLLEXPORT jl_value_t *jl_copysign_float(jl_value_t *a, jl_value_t *b);
686685
JL_DLLEXPORT jl_value_t *jl_flipsign_int(jl_value_t *a, jl_value_t *b);

src/runtime_intrinsics.c

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -925,31 +925,6 @@ un_fintrinsic(trunc_float,trunc_llvm)
925925
un_fintrinsic(rint_float,rint_llvm)
926926
un_fintrinsic(sqrt_float,sqrt_llvm)
927927

928-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b)
929-
{
930-
jl_ptls_t ptls = jl_get_ptls_states();
931-
jl_value_t *ty = jl_typeof(a);
932-
if (!jl_is_primitivetype(ty))
933-
jl_error("powi_llvm: a is not a primitive type");
934-
if (!jl_is_primitivetype(jl_typeof(b)) || jl_datatype_size(jl_typeof(b)) != 4)
935-
jl_error("powi_llvm: b is not a 32-bit primitive type");
936-
int sz = jl_datatype_size(ty);
937-
jl_value_t *newv = jl_gc_alloc(ptls, sz, ty);
938-
void *pa = jl_data_ptr(a), *pr = jl_data_ptr(newv);
939-
switch (sz) {
940-
/* choose the right size c-type operation */
941-
case 4:
942-
*(float*)pr = powf(*(float*)pa, (float)jl_unbox_int32(b));
943-
break;
944-
case 8:
945-
*(double*)pr = pow(*(double*)pa, (double)jl_unbox_int32(b));
946-
break;
947-
default:
948-
jl_error("powi_llvm: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
949-
}
950-
return newv;
951-
}
952-
953928
JL_DLLEXPORT jl_value_t *jl_select_value(jl_value_t *isfalse, jl_value_t *a, jl_value_t *b)
954929
{
955930
JL_TYPECHK(isfalse, bool, isfalse);

test/math.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,21 @@ end
590590
end
591591
end
592592

593+
@testset "issue #19872" begin
594+
f19872a(x) = x ^ 5
595+
f19872b(x) = x ^ (-1024)
596+
@test 0 < f19872b(2.0) < 1e-300
597+
@test issubnormal(2.0 ^ (-1024))
598+
@test issubnormal(f19872b(2.0))
599+
@test !issubnormal(f19872b(0.0))
600+
@test f19872a(2.0) === 32.0
601+
@test !issubnormal(f19872a(2.0))
602+
@test !issubnormal(0.0)
603+
end
604+
605+
@test Base.Math.f32(complex(1.0,1.0)) == complex(Float32(1.),Float32(1.))
606+
@test Base.Math.f16(complex(1.0,1.0)) == complex(Float16(1.),Float16(1.))
607+
593608
# no domain error is thrown for negative values
594609
@test invoke(cbrt, Tuple{AbstractFloat}, -1.0) == -1.0
595610

0 commit comments

Comments
 (0)