diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt index 0e4fbdd2f2f57..dc3120ca913a3 100644 --- a/libc/config/linux/x86_64/entrypoints.txt +++ b/libc/config/linux/x86_64/entrypoints.txt @@ -740,6 +740,7 @@ if(LIBC_TYPES_HAS_FLOAT16) libc.src.math.rintf16 libc.src.math.roundevenf16 libc.src.math.roundf16 + libc.src.math.rsqrtf16 libc.src.math.scalblnf16 libc.src.math.scalbnf16 libc.src.math.setpayloadf16 diff --git a/libc/docs/headers/math/index.rst b/libc/docs/headers/math/index.rst index cf88e6237d1e3..a29e70a7b625b 100644 --- a/libc/docs/headers/math/index.rst +++ b/libc/docs/headers/math/index.rst @@ -333,7 +333,7 @@ Higher Math Functions +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ | rootn | | | | | | 7.12.7.8 | F.10.4.8 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ -| rsqrt | | | | | | 7.12.7.9 | F.10.4.9 | +| rsqrt | | | | |check| | | 7.12.7.9 | F.10.4.9 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ | sin | |check| | |check| | | |check| | | 7.12.4.6 | F.10.1.6 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+ diff --git a/libc/include/math.yaml b/libc/include/math.yaml index b725c33c0bb06..35aee5f54d8cd 100644 --- a/libc/include/math.yaml +++ b/libc/include/math.yaml @@ -2237,6 +2237,13 @@ functions: return_type: long double arguments: - type: long double + - name: rsqrtf16 + standards: + - stdc + return_type: _Float16 + arguments: + - type: _Float16 + guard: LIBC_TYPES_HAS_FLOAT16 - name: 
scalbln standards: - stdc diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt index 5161b2b61aa94..689327b2f1249 100644 --- a/libc/src/math/CMakeLists.txt +++ b/libc/src/math/CMakeLists.txt @@ -467,6 +467,8 @@ add_math_entrypoint_object(roundevenl) add_math_entrypoint_object(roundevenf16) add_math_entrypoint_object(roundevenf128) +add_math_entrypoint_object(rsqrtf16) + add_math_entrypoint_object(scalbln) add_math_entrypoint_object(scalblnf) add_math_entrypoint_object(scalblnl) diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt index 25f583035fbc2..11ffd8d69829f 100644 --- a/libc/src/math/generic/CMakeLists.txt +++ b/libc/src/math/generic/CMakeLists.txt @@ -955,6 +955,25 @@ add_entrypoint_object( libc.src.__support.FPUtil.nearest_integer_operations ) +add_entrypoint_object( + rsqrtf16 + SRCS + rsqrtf16.cpp + HDRS + ../rsqrtf16.h + DEPENDS + libc.hdr.errno_macros + libc.hdr.fenv_macros + libc.src.__support.FPUtil.cast + libc.src.__support.FPUtil.fenv_impl + libc.src.__support.FPUtil.fp_bits + libc.src.__support.FPUtil.fma + libc.src.__support.FPUtil.manipulation_functions + libc.src.__support.FPUtil.polyeval + libc.src.__support.macros.optimization + libc.src.__support.macros.properties.types +) + add_entrypoint_object( lround SRCS diff --git a/libc/src/math/generic/rsqrtf16.cpp b/libc/src/math/generic/rsqrtf16.cpp new file mode 100644 index 0000000000000..6ad9f5f968772 --- /dev/null +++ b/libc/src/math/generic/rsqrtf16.cpp @@ -0,0 +1,120 @@ +//===-- Half-precision rsqrt function -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception. 
+// +//===----------------------------------------------------------------------===// + +#include "src/math/rsqrtf16.h" +#include "hdr/errno_macros.h" +#include "hdr/fenv_macros.h" +#include "src/__support/FPUtil/FEnvImpl.h" +#include "src/__support/FPUtil/FMA.h" +#include "src/__support/FPUtil/FPBits.h" +#include "src/__support/FPUtil/ManipulationFunctions.h" +#include "src/__support/FPUtil/PolyEval.h" +#include "src/__support/FPUtil/cast.h" +#include "src/__support/macros/optimization.h" + +namespace LIBC_NAMESPACE_DECL { + +LLVM_LIBC_FUNCTION(float16, rsqrtf16, (float16 x)) { + using FPBits = fputil::FPBits<float16>; + FPBits xbits(x); + + uint16_t x_u = xbits.uintval(); + uint16_t x_abs = x_u & 0x7fff; + uint16_t x_sign = x_u >> 15; + + // x is NaN + if (LIBC_UNLIKELY(xbits.is_nan())) { + if (xbits.is_signaling_nan()) { + fputil::raise_except_if_required(FE_INVALID); + return FPBits::quiet_nan().get_val(); + } + return x; + } + + // |x| = 0 (NOTE(review): C23 F.10.4.9 gives rsqrt(-0) = -inf; confirm +inf) + if (LIBC_UNLIKELY(x_abs == 0x0)) { + fputil::raise_except_if_required(FE_DIVBYZERO); + fputil::set_errno_if_required(ERANGE); + return FPBits::inf(Sign::POS).get_val(); + } + + // -inf <= x < 0 + if (LIBC_UNLIKELY(x_sign == 1)) { + fputil::raise_except_if_required(FE_INVALID); + fputil::set_errno_if_required(EDOM); + return FPBits::quiet_nan().get_val(); + } + + // x = +inf => rsqrt(x) = 0 + if (LIBC_UNLIKELY(xbits.is_inf())) { + return fputil::cast<float16>(0.0f); + } + + // x is valid, estimate the result + // Range reduction: + // x can be expressed as m*2^e, where e - int exponent and m - mantissa + // rsqrtf16(x) = rsqrtf16(m*2^e) + // rsqrtf16(m*2^e) = 1/sqrt(m) * 1/sqrt(2^e) = 1/sqrt(m) * 1/2^(e/2) + // 1/sqrt(m) * 1/2^(e/2) = 1/sqrt(m) * 2^(-e/2) + + float xf = x; + int exponent; + float mantissa = fputil::frexp(xf, exponent); + + float result; + int exp_floored = -(exponent >> 1); + + if (mantissa == 0.5f) { + // When mantissa is 0.5f, x is an exact power of 2 (including subnormals).
+ // 1/sqrt(0.5f) = sqrt(2.0f) = 0x1.6a09e6p0f + // If exponent is odd (exponent = 2k + 1): + // rsqrt(x) = (1/sqrt(0.5)) * 2^(-(2k+1)/2) = sqrt(2) * 2^(-k-0.5) + // = sqrt(2) * 2^(-k) * (1/sqrt(2)) = 2^(-k) + // exp_floored = -((2k+1)>>1) = -(k) = -k + // So result = ldexp(1.0f, exp_floored) + // If exponent is even (exponent = 2k): + // rsqrt(x) = (1/sqrt(0.5)) * 2^(-2k/2) = sqrt(2) * 2^(-k) + // exp_floored = -((2k)>>1) = -(k) = -k + // So result = ldexp(sqrt(2.0f), exp_floored) + if (exponent & 1) { + result = fputil::ldexp(1.0f, exp_floored); + } else { + constexpr float SQRT_2_F = 0x1.6a09e6p0f; // sqrt(2.0f) + result = fputil::ldexp(SQRT_2_F, exp_floored); + } + } else { + // Degree-5 polynomial generated using Sollya + // P = fpminimax(1/sqrt(x), [|0,1,2,3,4,5|], [|SG...|], [0.5, 1]); + float interm = fputil::polyeval( + mantissa, 0x1.9c81c4p1f, -0x1.e2c57cp2f, 0x1.91e8bp3f, + -0x1.899954p3f, 0x1.9edcp2f, -0x1.6bd93cp0f); + + // Apply one Newton-Raphson iteration to refine the approximation of + // 1/sqrt(mantissa): y_new = y_old * (1.5 - 0.5 * mantissa * y_old^2), using + // fputil::fma for potential precision benefits in the factor calculation + float interm_sq = interm * interm; + float factor = fputil::fma<float>(-0.5f * mantissa, interm_sq, 1.5f); + float interm_refined = interm * factor; + + // Apply a second Newton-Raphson iteration + // y_new = y_old * (1.5 - 0.5 * mantissa * y_old^2) + float interm_refined_sq = interm_refined * interm_refined; + float factor2 = + fputil::fma<float>(-0.5f * mantissa, interm_refined_sq, 1.5f); + float interm_refined2 = interm_refined * factor2; + + result = fputil::ldexp(interm_refined2, exp_floored); + if (exponent & 1) { + constexpr float ONE_OVER_SQRT2 = 0x1.6a09e6p-1f; + result = fputil::fma<float>(result, ONE_OVER_SQRT2, 0.0f); + } + } + + return fputil::cast<float16>(result); +} +} // namespace LIBC_NAMESPACE_DECL diff --git a/libc/src/math/rsqrtf16.h b/libc/src/math/rsqrtf16.h new file mode 100644 index 0000000000000..c88ab5256ce88 --- /dev/null
+++ b/libc/src/math/rsqrtf16.h @@ -0,0 +1,21 @@ +//===-- Implementation header for rsqrtf16 ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_MATH_RSQRTF16_H +#define LLVM_LIBC_SRC_MATH_RSQRTF16_H + +#include "src/__support/macros/config.h" +#include "src/__support/macros/properties/types.h" + +namespace LIBC_NAMESPACE_DECL { + +float16 rsqrtf16(float16 x); + +} // namespace LIBC_NAMESPACE_DECL + +#endif // LLVM_LIBC_SRC_MATH_RSQRTF16_H diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt index ebf9f1c86cf15..8daf59032622c 100644 --- a/libc/test/src/math/CMakeLists.txt +++ b/libc/test/src/math/CMakeLists.txt @@ -1560,6 +1560,17 @@ add_fp_unittest( libc.src.math.sqrtl ) +add_fp_unittest( + rsqrtf16_test + NEED_MPFR + SUITE + libc-math-unittests + SRCS + rsqrtf16_test.cpp + DEPENDS + libc.src.math.rsqrtf16 +) + add_fp_unittest( sqrtf16_test NEED_MPFR diff --git a/libc/test/src/math/rsqrtf16_test.cpp b/libc/test/src/math/rsqrtf16_test.cpp new file mode 100644 index 0000000000000..d2f3fe8f49b92 --- /dev/null +++ b/libc/test/src/math/rsqrtf16_test.cpp @@ -0,0 +1,42 @@ +//===-- Exhaustive test for rsqrtf16 --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/math/rsqrtf16.h" +#include "test/UnitTest/FPMatcher.h" +#include "test/UnitTest/Test.h" +#include "utils/MPFRWrapper/MPFRUtils.h" + +using LlvmLibcRsqrtf16Test = LIBC_NAMESPACE::testing::FPTest<float16>; + +namespace mpfr = LIBC_NAMESPACE::testing::mpfr; + +// Range: [0, Inf] +static constexpr uint16_t POS_START = 0x0000U; +static constexpr uint16_t POS_STOP = 0x7c00U; + +// Range: [-Inf, -0] +static constexpr uint16_t NEG_START = 0x8000U; +static constexpr uint16_t NEG_STOP = 0xfc00U; + +TEST_F(LlvmLibcRsqrtf16Test, PositiveRange) { + for (uint16_t v = POS_START; v <= POS_STOP; ++v) { + float16 x = FPBits(v).get_val(); + + EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Rsqrt, x, + LIBC_NAMESPACE::rsqrtf16(x), 0.5); + } +} + +TEST_F(LlvmLibcRsqrtf16Test, NegativeRange) { + for (uint16_t v = NEG_START; v <= NEG_STOP; ++v) { + float16 x = FPBits(v).get_val(); + + EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Rsqrt, x, + LIBC_NAMESPACE::rsqrtf16(x), 0.5); + } +} diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt index 79b189159e9d8..fe5d8a6b3b864 100644 --- a/libc/test/src/math/smoke/CMakeLists.txt +++ b/libc/test/src/math/smoke/CMakeLists.txt @@ -2958,6 +2958,17 @@ add_fp_unittest( libc.src.math.sqrtl ) +add_fp_unittest( + rsqrtf16_test + SUITE + libc-math-smoke-tests + SRCS + rsqrtf16_test.cpp + DEPENDS + libc.src.errno.errno + libc.src.math.rsqrtf16 +) + add_fp_unittest( sqrtf16_test SUITE diff --git a/libc/test/src/math/smoke/rsqrtf16_test.cpp b/libc/test/src/math/smoke/rsqrtf16_test.cpp new file mode 100644 index 0000000000000..8e69027e67e13 --- /dev/null +++ b/libc/test/src/math/smoke/rsqrtf16_test.cpp @@ -0,0 +1,37 @@ +//===-- Unittests for rsqrtf16 --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0
with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception. +// +//===----------------------------------------------------------------------===// + +#include "src/errno/libc_errno.h" +#include "src/math/rsqrtf16.h" +#include "test/UnitTest/FPMatcher.h" +#include "test/UnitTest/Test.h" + +using LlvmLibcRsqrtf16Test = LIBC_NAMESPACE::testing::FPTest<float16>; +TEST_F(LlvmLibcRsqrtf16Test, SpecialNumbers) { + LIBC_NAMESPACE::libc_errno = 0; + EXPECT_FP_EQ(aNaN, LIBC_NAMESPACE::rsqrtf16(aNaN)); + EXPECT_MATH_ERRNO(0); + + EXPECT_FP_EQ_WITH_EXCEPTION(aNaN, LIBC_NAMESPACE::rsqrtf16(sNaN), FE_INVALID); + EXPECT_MATH_ERRNO(0); + + EXPECT_FP_EQ(inf, LIBC_NAMESPACE::rsqrtf16(0.0f)); + EXPECT_MATH_ERRNO(ERANGE); + + EXPECT_FP_EQ(1.0f, LIBC_NAMESPACE::rsqrtf16(1.0f)); + EXPECT_MATH_ERRNO(0); + + EXPECT_FP_EQ(0.0f, LIBC_NAMESPACE::rsqrtf16(inf)); + EXPECT_MATH_ERRNO(0); + + EXPECT_FP_EQ(aNaN, LIBC_NAMESPACE::rsqrtf16(neg_inf)); + EXPECT_MATH_ERRNO(EDOM); + + EXPECT_FP_EQ(aNaN, LIBC_NAMESPACE::rsqrtf16(-2.0f)); + EXPECT_MATH_ERRNO(EDOM); +} diff --git a/libc/utils/MPFRWrapper/MPCommon.cpp b/libc/utils/MPFRWrapper/MPCommon.cpp index ccd4d2d01a4e2..1a78ca5c24ba2 100644 --- a/libc/utils/MPFRWrapper/MPCommon.cpp +++ b/libc/utils/MPFRWrapper/MPCommon.cpp @@ -366,6 +366,12 @@ MPFRNumber MPFRNumber::rint(mpfr_rnd_t rnd) const { return result; } +MPFRNumber MPFRNumber::rsqrt() const { + MPFRNumber result(*this); + mpfr_rec_sqrt(result.value, value, mpfr_rounding); + return result; +} + MPFRNumber MPFRNumber::mod_2pi() const { MPFRNumber result(0.0, 1280); MPFRNumber _2pi(0.0, 1280); diff --git a/libc/utils/MPFRWrapper/MPCommon.h b/libc/utils/MPFRWrapper/MPCommon.h index 99cb7ec66a2ca..43218ee7662db 100644 --- a/libc/utils/MPFRWrapper/MPCommon.h +++ b/libc/utils/MPFRWrapper/MPCommon.h @@ -216,6 +216,7 @@ class MPFRNumber { bool round_to_long(long &result) const; bool round_to_long(mpfr_rnd_t rnd, long &result)
const; MPFRNumber rint(mpfr_rnd_t rnd) const; + MPFRNumber rsqrt() const; MPFRNumber mod_2pi() const; MPFRNumber mod_pi_over_2() const; MPFRNumber mod_pi_over_4() const; diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp index 8853f96ef8f92..a68d21650f004 100644 --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -86,6 +86,8 @@ unary_operation(Operation op, InputType input, unsigned int precision, return mpfrInput.round(); case Operation::RoundEven: return mpfrInput.roundeven(); + case Operation::Rsqrt: + return mpfrInput.rsqrt(); case Operation::Sin: return mpfrInput.sin(); case Operation::Sinpi: diff --git a/libc/utils/MPFRWrapper/MPFRUtils.h b/libc/utils/MPFRWrapper/MPFRUtils.h index c77a6aa3adeae..532b2c8d9d819 100644 --- a/libc/utils/MPFRWrapper/MPFRUtils.h +++ b/libc/utils/MPFRWrapper/MPFRUtils.h @@ -55,6 +55,7 @@ enum class Operation : int { ModPIOver4, Round, RoundEven, + Rsqrt, Sin, Sinpi, Sinh,