Skip to content

Commit 8be7301

Browse files
committed
only use accurate powf function
The powi intrinsic optimization over calling powf is that it is inaccurate. We don't need that. When it is equally accurate (e.g. tiny constant powers), LLVM will already recognize and optimize any call to a function named `powf`, and produce the same speedup. fix #19872
1 parent 91127d3 commit 8be7301

File tree

9 files changed

+32
-74
lines changed

9 files changed

+32
-74
lines changed

base/fastmath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module FastMath
2323

2424
export @fastmath
2525

26-
import Core.Intrinsics: powi_llvm, sqrt_llvm_fast, neg_float_fast,
26+
import Core.Intrinsics: powf_llvm, sqrt_llvm_fast, neg_float_fast,
2727
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, rem_float_fast,
2828
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
2929

@@ -243,8 +243,8 @@ end
243243

244244
# builtins
245245

246-
pow_fast{T<:FloatTypes}(x::T, y::Integer) = pow_fast(x, Int32(y))
247-
pow_fast{T<:FloatTypes}(x::T, y::Int32) = Base.powi_llvm(x, y)
246+
pow_fast{T<:FloatTypes}(x::T, y::Integer) = pow_fast(x, convert(T, y))
247+
pow_fast{T<:FloatTypes}(x::T, y::T) = powf_llvm(x, y)
248248

249249
# TODO: Change sqrt_llvm intrinsic to avoid nan checking; add nan
250250
# checking to sqrt in math.jl; remove sqrt_llvm_fast intrinsic

base/inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ add_tfunc(floor_llvm, 1, 1, math_tfunc)
427427
add_tfunc(trunc_llvm, 1, 1, math_tfunc)
428428
add_tfunc(rint_llvm, 1, 1, math_tfunc)
429429
add_tfunc(sqrt_llvm, 1, 1, math_tfunc)
430-
add_tfunc(powi_llvm, 2, 2, math_tfunc)
430+
add_tfunc(powf_llvm, 2, 2, math_tfunc)
431431
add_tfunc(sqrt_llvm_fast, 1, 1, math_tfunc)
432432
## same-type comparisons ##
433433
cmp_tfunc(x::ANY, y::ANY) = Bool

base/math.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using Base: sign_mask, exponent_mask, exponent_one, exponent_bias,
3232
exponent_half, exponent_max, exponent_raw_max, fpinttype,
3333
significand_mask, significand_bits, exponent_bits
3434

35-
using Core.Intrinsics: sqrt_llvm, powi_llvm
35+
using Core.Intrinsics: sqrt_llvm, powf_llvm
3636

3737
# non-type specific math functions
3838

@@ -680,11 +680,10 @@ end
680680
^(x::Float64, y::Float64) = nan_dom_err(ccall((:pow,libm), Float64, (Float64,Float64), x, y), x+y)
681681
^(x::Float32, y::Float32) = nan_dom_err(ccall((:powf,libm), Float32, (Float32,Float32), x, y), x+y)
682682

683-
^(x::Float64, y::Integer) = x^Int32(y)
684-
^(x::Float64, y::Int32) = powi_llvm(x, y)
685-
^(x::Float32, y::Integer) = x^Int32(y)
686-
^(x::Float32, y::Int32) = powi_llvm(x, y)
687-
^(x::Float16, y::Integer) = Float16(Float32(x)^y)
683+
^(x::Float64, y::Integer) = x ^ Float64(y)
684+
^(x::Float64, y::Float64) = powf_llvm(x, y)
685+
^(x::Float32, y::Integer) = x ^ Float32(y)
686+
^(x::Float32, y::Float32) = powf_llvm(x, y)
688687

689688
function angle_restrict_symm(theta)
690689
const P1 = 4 * 7.8539812564849853515625e-01

src/codegen.cpp

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -394,10 +394,8 @@ static Function *expect_func;
394394
static Function *jldlsym_func;
395395
static Function *jlnewbits_func;
396396
static Function *jltypeassert_func;
397-
#if JL_LLVM_VERSION < 30600
398397
static Function *jlpow_func;
399398
static Function *jlpowf_func;
400-
#endif
401399
//static Function *jlgetnthfield_func;
402400
static Function *jlgetnthfieldchecked_func;
403401
//static Function *jlsetnthfield_func;
@@ -5974,7 +5972,6 @@ static void init_julia_llvm_env(Module *m)
59745972
"jl_gc_diff_total_bytes", m);
59755973
add_named_global(diff_gc_total_bytes_func, *jl_gc_diff_total_bytes);
59765974

5977-
#if JL_LLVM_VERSION < 30600
59785975
Type *powf_type[2] = { T_float32, T_float32 };
59795976
jlpowf_func = Function::Create(FunctionType::get(T_float32, powf_type, false),
59805977
Function::ExternalLinkage,
@@ -5986,13 +5983,9 @@ static void init_julia_llvm_env(Module *m)
59865983
Function::ExternalLinkage,
59875984
"pow", m);
59885985
add_named_global(jlpow_func,
5989-
#ifdef _COMPILER_MICROSOFT_
5990-
static_cast<double (*)(double, double)>(&pow),
5991-
#else
5992-
&pow,
5993-
#endif
5986+
static_cast<double (*)(double, double)>(&::pow),
59945987
false);
5995-
#endif
5988+
59965989
std::vector<Type*> array_owner_args(0);
59975990
array_owner_args.push_back(T_pjlvalue);
59985991
jlarray_data_owner_func =

src/intrinsics.cpp

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ static void jl_init_intrinsic_functions_codegen(Module *m)
7171
float_func[rint_llvm] = true;
7272
float_func[sqrt_llvm] = true;
7373
float_func[sqrt_llvm_fast] = true;
74-
float_func[powi_llvm] = true;
74+
float_func[powf_llvm] = true;
7575
}
7676

7777
extern "C"
@@ -915,33 +915,6 @@ static jl_cgval_t emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
915915
return mark_julia_type(ans, false, x.typ, ctx);
916916
}
917917

918-
case powi_llvm: {
919-
const jl_cgval_t &x = argv[0];
920-
const jl_cgval_t &y = argv[1];
921-
if (!jl_is_bitstype(x.typ) || !jl_is_bitstype(y.typ) || jl_datatype_size(y.typ) != 4)
922-
return emit_runtime_call(f, argv, nargs, ctx);
923-
Type *xt = FLOATT(bitstype_to_llvm(x.typ));
924-
Type *yt = T_int32;
925-
if (!xt)
926-
return emit_runtime_call(f, argv, nargs, ctx);
927-
928-
Value *xv = emit_unbox(xt, x, x.typ);
929-
Value *yv = emit_unbox(yt, y, y.typ);
930-
#if JL_LLVM_VERSION >= 30600
931-
Value *powi = Intrinsic::getDeclaration(jl_Module, Intrinsic::powi, makeArrayRef(xt));
932-
#if JL_LLVM_VERSION >= 30700
933-
Value *ans = builder.CreateCall(powi, {xv, yv});
934-
#else
935-
Value *ans = builder.CreateCall2(powi, xv, yv);
936-
#endif
937-
#else
938-
// issue #6506
939-
Value *ans = builder.CreateCall2(prepare_call(xt == T_float64 ? jlpow_func : jlpowf_func),
940-
xv, builder.CreateSIToFP(yv, xt));
941-
#endif
942-
return mark_julia_type(ans, false, x.typ, ctx);
943-
}
944-
945918
default: {
946919
assert(nargs >= 1 && "invalid nargs for intrinsic call");
947920
const jl_cgval_t &xinfo = argv[0];
@@ -1296,6 +1269,14 @@ static Value *emit_untyped_intrinsic(intrinsic f, Value **argvalues, size_t narg
12961269
Value *sqrtintr = Intrinsic::getDeclaration(jl_Module, Intrinsic::sqrt, makeArrayRef(t));
12971270
return builder.CreateCall(sqrtintr, x);
12981271
}
1272+
case powf_llvm: {
1273+
Function *powf = (t == T_float64 ? jlpow_func : jlpowf_func);
1274+
#if JL_LLVM_VERSION >= 30700
1275+
return builder.CreateCall(prepare_call(powf), {x, y});
1276+
#else
1277+
return builder.CreateCall2(prepare_call(powf), x, y);
1278+
#endif
1279+
}
12991280

13001281
default:
13011282
assert(0 && "invalid intrinsic");

src/intrinsics.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
ADD_I(trunc_llvm, 1) \
9292
ADD_I(rint_llvm, 1) \
9393
ADD_I(sqrt_llvm, 1) \
94-
ADD_I(powi_llvm, 2) \
94+
ADD_I(powf_llvm, 2) \
9595
ALIAS(sqrt_llvm_fast, sqrt_llvm) \
9696
/* pointer access */ \
9797
ADD_I(pointerref, 3) \

src/julia_internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ JL_DLLEXPORT jl_value_t *jl_floor_llvm(jl_value_t *a);
677677
JL_DLLEXPORT jl_value_t *jl_trunc_llvm(jl_value_t *a);
678678
JL_DLLEXPORT jl_value_t *jl_rint_llvm(jl_value_t *a);
679679
JL_DLLEXPORT jl_value_t *jl_sqrt_llvm(jl_value_t *a);
680-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b);
680+
JL_DLLEXPORT jl_value_t *jl_powf_llvm(jl_value_t *a, jl_value_t *b);
681681
JL_DLLEXPORT jl_value_t *jl_abs_float(jl_value_t *a);
682682
JL_DLLEXPORT jl_value_t *jl_copysign_float(jl_value_t *a, jl_value_t *b);
683683
JL_DLLEXPORT jl_value_t *jl_flipsign_int(jl_value_t *a, jl_value_t *b);

src/runtime_intrinsics.c

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,8 @@ bi_iintrinsic_fast(jl_LLVMFlipSign, flipsign, flipsign_int, )
947947
*pr = fp_select(a, sqrt)
948948
#define copysign_float(a, b) \
949949
fp_select2(a, b, copysign)
950+
#define pow_float(a, b) \
951+
fp_select2(a, b, pow)
950952

951953
un_fintrinsic(abs_float,abs_float)
952954
bi_fintrinsic(copysign_float,copysign_float)
@@ -955,31 +957,7 @@ un_fintrinsic(floor_float,floor_llvm)
955957
un_fintrinsic(trunc_float,trunc_llvm)
956958
un_fintrinsic(rint_float,rint_llvm)
957959
un_fintrinsic(sqrt_float,sqrt_llvm)
958-
959-
JL_DLLEXPORT jl_value_t *jl_powi_llvm(jl_value_t *a, jl_value_t *b)
960-
{
961-
jl_ptls_t ptls = jl_get_ptls_states();
962-
jl_value_t *ty = jl_typeof(a);
963-
if (!jl_is_bitstype(ty))
964-
jl_error("powi_llvm: a is not a bitstype");
965-
if (!jl_is_bitstype(jl_typeof(b)) || jl_datatype_size(jl_typeof(b)) != 4)
966-
jl_error("powi_llvm: b is not a 32-bit bitstype");
967-
int sz = jl_datatype_size(ty);
968-
jl_value_t *newv = jl_gc_alloc(ptls, sz, ty);
969-
void *pa = jl_data_ptr(a), *pr = jl_data_ptr(newv);
970-
switch (sz) {
971-
/* choose the right size c-type operation */
972-
case 4:
973-
*(float*)pr = powf(*(float*)pa, (float)jl_unbox_int32(b));
974-
break;
975-
case 8:
976-
*(double*)pr = pow(*(double*)pa, (double)jl_unbox_int32(b));
977-
break;
978-
default:
979-
jl_error("powi_llvm: runtime floating point intrinsics are not implemented for bit sizes other than 32 and 64");
980-
}
981-
return newv;
982-
}
960+
bi_fintrinsic(pow_float,powf_llvm)
983961

984962
JL_DLLEXPORT jl_value_t *jl_select_value(jl_value_t *isfalse, jl_value_t *a, jl_value_t *b)
985963
{

test/math.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,13 @@ end
996996
end
997997
end
998998

999+
@testset "issue #19872" begin
1000+
f19872(x) = x ^ 3
1001+
@test issubnormal(2.0 ^ (-1024))
1002+
@test f19872(2.0) === 8.0
1003+
@test !issubnormal(0.0)
1004+
end
1005+
9991006
@test Base.Math.f32(complex(1.0,1.0)) == complex(Float32(1.),Float32(1.))
10001007
@test Base.Math.f16(complex(1.0,1.0)) == complex(Float16(1.),Float16(1.))
10011008

0 commit comments

Comments
 (0)