From 1fdf5c47d78ddab64d680d86c5af51591a97852d Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 25 Sep 2024 14:25:21 -0400 Subject: [PATCH 01/13] Fix pow(a,b) overload resolution under llvm19 --- stan/math/fwd/fun/pow.hpp | 2 +- stan/math/prim/fun/pow.hpp | 6 +++--- stan/math/rev/fun/pow.hpp | 24 ++++++++++++++++++++++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/stan/math/fwd/fun/pow.hpp b/stan/math/fwd/fun/pow.hpp index c181df79b62..b85061db8b0 100644 --- a/stan/math/fwd/fun/pow.hpp +++ b/stan/math/fwd/fun/pow.hpp @@ -193,7 +193,7 @@ inline std::complex> pow(const fvar& x, const std::complex& y) { * @return first argument to the power of the second argument */ template > -inline std::complex> pow(T x, const std::complex>& y) { +inline std::complex> pow(const T& x, const std::complex>& y) { return internal::complex_pow(x, y); } diff --git a/stan/math/prim/fun/pow.hpp b/stan/math/prim/fun/pow.hpp index 2ed54148042..ec9ddb8c162 100644 --- a/stan/math/prim/fun/pow.hpp +++ b/stan/math/prim/fun/pow.hpp @@ -59,11 +59,11 @@ inline auto pow(const T1& a, const T2& b) { * second argument. */ template * = nullptr, - require_all_not_matrix_st* = nullptr> + require_all_not_matrix_st* = nullptr, + require_all_st_arithmetic* = nullptr> inline auto pow(const T1& a, const T2& b) { return apply_scalar_binary(a, b, [](const auto& c, const auto& d) { - using std::pow; - return pow(c, d); + return stan::math::pow(c, d); }); } } // namespace math diff --git a/stan/math/rev/fun/pow.hpp b/stan/math/rev/fun/pow.hpp index 8a2383880b9..e89f2e8e875 100644 --- a/stan/math/rev/fun/pow.hpp +++ b/stan/math/rev/fun/pow.hpp @@ -319,7 +319,7 @@ inline std::complex pow(const std::complex& x, const var& y) { * @return first argument to the power of the second argument */ template > -inline std::complex pow(const std::complex& x, T y) { +inline std::complex pow(const std::complex& x, const T& y) { return internal::complex_pow(x, y); } @@ -382,7 +382,7 @@ inline std::complex pow(const var& x, std::complex y) { * @return first argument to the power of the second argument */ template > -inline std::complex pow(T x, const std::complex& y) { +inline std::complex pow(const T& x, const std::complex& y) { return internal::complex_pow(x, y); } @@ -401,6 +401,26 @@ inline std::complex pow(const std::complex& x, int y) { return internal::complex_pow(x, y); } +/** + * Returns the elementwise raising of the first argument to the power of the + * second argument. + * + * @tparam T1 type of first argument + * @tparam T2 type of second argument + * @param a first argument + * @param b second argument + * @return the elementwise raising of the first argument to the power of the + * second argument. + */ +template * = nullptr, + require_all_not_matrix_st* = nullptr, + require_any_not_st_arithmetic* = nullptr> +inline auto pow(const T1& a, const T2& b) { + return apply_scalar_binary(a, b, [](const auto& c, const auto& d) { + return stan::math::pow(c, d); + }); +} + } // namespace math } // namespace stan #endif From d070e181021b604a8afad27611db0bf9e2265bc4 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Wed, 25 Sep 2024 14:32:06 -0400 Subject: [PATCH 02/13] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/prim/fun/pow.hpp | 5 ++--- stan/math/rev/fun/pow.hpp | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/stan/math/prim/fun/pow.hpp b/stan/math/prim/fun/pow.hpp index ec9ddb8c162..72d935e15c0 100644 --- a/stan/math/prim/fun/pow.hpp +++ b/stan/math/prim/fun/pow.hpp @@ -62,9 +62,8 @@ template * = nullptr, require_all_not_matrix_st* = nullptr, require_all_st_arithmetic* = nullptr> inline auto pow(const T1& a, const T2& b) { - return apply_scalar_binary(a, b, [](const auto& c, const auto& d) { - return stan::math::pow(c, d); - }); + return apply_scalar_binary( + a, b, [](const auto& c, const auto& d) { return stan::math::pow(c, d); }); } } // namespace math } // namespace stan diff --git a/stan/math/rev/fun/pow.hpp b/stan/math/rev/fun/pow.hpp index e89f2e8e875..afc808608d2 100644 --- a/stan/math/rev/fun/pow.hpp +++ b/stan/math/rev/fun/pow.hpp @@ -416,9 +416,8 @@ template * = nullptr, require_all_not_matrix_st* = nullptr, require_any_not_st_arithmetic* = nullptr> inline auto pow(const T1& a, const T2& b) { - return apply_scalar_binary(a, b, [](const auto& c, const auto& d) { - return stan::math::pow(c, d); - }); + return apply_scalar_binary( + a, b, [](const auto& c, const auto& d) { return stan::math::pow(c, d); }); } } // namespace math From 09434283ddbe0d0244e9cbfc428c26df193d3c6a Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 25 Sep 2024 14:47:36 -0400 Subject: [PATCH 03/13] Missed a couple --- stan/math/fwd/fun/pow.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stan/math/fwd/fun/pow.hpp b/stan/math/fwd/fun/pow.hpp index b85061db8b0..ce6af077619 100644 --- a/stan/math/fwd/fun/pow.hpp +++ b/stan/math/fwd/fun/pow.hpp @@ -25,7 +25,7 @@ inline fvar pow(const fvar& x1, const fvar& x2) { } template > -inline fvar pow(U x1, const fvar& x2) { +inline fvar pow(const U& x1, const fvar& x2) { using std::log; using std::pow; T u = pow(x1, x2.val_); @@ -33,7 +33,7 @@ inline fvar pow(U x1, const fvar& x2) { } template > -inline fvar pow(const fvar& x1, U x2) { +inline fvar pow(const fvar& x1, const U& x2) { using std::pow; using std::sqrt; if (x2 == -2) { From 07aadfb4add6e6f7b5779328700bffc7df42c99c Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 25 Sep 2024 14:59:54 -0400 Subject: [PATCH 04/13] Patch test resolution as well --- test/unit/math/mix/fun/pow_part1_test.cpp | 33 +++---- test/unit/math/mix/fun/pow_part2_test.cpp | 106 +++++++++++----------- 2 files changed, 67 insertions(+), 72 deletions(-) diff --git a/test/unit/math/mix/fun/pow_part1_test.cpp b/test/unit/math/mix/fun/pow_part1_test.cpp index f97b645a8ec..72100764562 100644 --- a/test/unit/math/mix/fun/pow_part1_test.cpp +++ b/test/unit/math/mix/fun/pow_part1_test.cpp @@ -5,20 +5,19 @@ template void expect_arith_instantiate() { - using stan::math::pow; - auto a1 = pow(T(1.0), 1); - auto b1 = pow(T(1.0), 1.0); - auto c1 = pow(1, T(1.0)); - auto d1 = pow(1.0, T(1.0)); - auto e1 = pow(T(1.0), T(1.0)); - - auto a2 = pow(std::complex(1.0), 1); - auto b2 = pow(std::complex(1.0), 1.0); - auto c2 = pow(1, std::complex(1.0)); - auto d2 = pow(1.0, std::complex(1.0)); - auto e2 = pow(std::complex(1.0), std::complex(1.0)); - auto f2 = pow(std::complex(1.0), std::complex(1.0)); - auto g2 = pow(std::complex(1.0), std::complex(1.0)); + auto a1 = stan::math::pow(T(1.0), 1); + auto b1 = stan::math::pow(T(1.0), 1.0); + auto c1 = stan::math::pow(1, T(1.0)); + auto d1 = stan::math::pow(1.0, T(1.0)); + auto e1 = stan::math::pow(T(1.0), T(1.0)); + + auto a2 = stan::math::pow(std::complex(1.0), 1); + auto b2 = stan::math::pow(std::complex(1.0), 1.0); + auto c2 = stan::math::pow(1, std::complex(1.0)); + auto d2 = stan::math::pow(1.0, std::complex(1.0)); + auto e2 = stan::math::pow(std::complex(1.0), std::complex(1.0)); + auto f2 = stan::math::pow(std::complex(1.0), std::complex(1.0)); + auto g2 = stan::math::pow(std::complex(1.0), std::complex(1.0)); } // this one's been tricky to instantiate, so test all instances @@ -34,10 +33,8 @@ TEST(mathMixScalFun, powInstantiations) { } TEST(mathMixScalFun, pow) { - auto f = [](const auto& x1, const auto& x2) { - using stan::math::pow; - return pow(x1, x2); - }; + auto f + = [](const auto& x1, const auto& x2) { return stan::math::pow(x1, x2); }; stan::test::expect_ad(f, -0.4, 0.5); stan::test::expect_ad(f, 0.5, 0.5); diff --git a/test/unit/math/mix/fun/pow_part2_test.cpp b/test/unit/math/mix/fun/pow_part2_test.cpp index 6fb46d08e2d..9340dad9481 100644 --- a/test/unit/math/mix/fun/pow_part2_test.cpp +++ b/test/unit/math/mix/fun/pow_part2_test.cpp @@ -49,7 +49,6 @@ TEST(mathMixFun, complexPow) { } TEST(mathMixFun, powIntAmbiguityTest) { - using stan::math::pow; // included to check ambiguities using stan::math::var; using std::complex; int i = 2; @@ -58,35 +57,35 @@ TEST(mathMixFun, powIntAmbiguityTest) { complex cd = 2.5; complex cv = 2.5; - auto a1 = pow(i, i); - auto a2 = pow(i, d); - auto a3 = pow(i, v); - auto a4 = pow(i, cd); - auto a5 = pow(i, cv); - - auto b1 = pow(d, i); - auto b2 = pow(d, d); - auto b3 = pow(d, v); - auto b4 = pow(d, cd); - auto b5 = pow(d, cv); - - auto e1 = pow(v, i); - auto e2 = pow(v, d); - auto e3 = pow(v, v); - auto e4 = pow(v, cd); - auto e5 = pow(v, cv); - - auto c1 = pow(cd, i); - auto c2 = pow(cd, d); - auto c3 = pow(cd, v); - auto c4 = pow(cd, cd); - auto c5 = pow(cd, cv); - - auto d1 = pow(cv, i); - auto d2 = pow(cv, d); - auto d3 = pow(cv, v); - auto d4 = pow(cv, cd); - auto d5 = pow(cv, cv); + auto a1 = stan::math::pow(i, i); + auto a2 = stan::math::pow(i, d); + auto a3 = stan::math::pow(i, v); + auto a4 = stan::math::pow(i, cd); + auto a5 = stan::math::pow(i, cv); + + auto b1 = stan::math::pow(d, i); + auto b2 = stan::math::pow(d, d); + auto b3 = stan::math::pow(d, v); + auto b4 = stan::math::pow(d, cd); + auto b5 = stan::math::pow(d, cv); + + auto e1 = stan::math::pow(v, i); + auto e2 = stan::math::pow(v, d); + auto e3 = stan::math::pow(v, v); + auto e4 = stan::math::pow(v, cd); + auto e5 = stan::math::pow(v, cv); + + auto c1 = stan::math::pow(cd, i); + auto c2 = stan::math::pow(cd, d); + auto c3 = stan::math::pow(cd, v); + auto c4 = stan::math::pow(cd, cd); + auto c5 = stan::math::pow(cd, cv); + + auto d1 = stan::math::pow(cv, i); + auto d2 = stan::math::pow(cv, d); + auto d3 = stan::math::pow(cv, v); + auto d4 = stan::math::pow(cv, cd); + auto d5 = stan::math::pow(cv, cv); auto e = a1 + a2 + a3 + a4 + a5 + b1 + b2 + b3 + b4 + b5 + c1 + c2 + c3 + c4 + c5 + d1 + d2 + d3 + d4 + d5 + e1 + e2 + e3 + e4 + e5; @@ -96,7 +95,6 @@ TEST(mathMixFun, powIntAmbiguityTest) { TEST(mathMixFun, powIntAmbiguityTestFvar) { using stan::math::fvar; - using stan::math::pow; // included to check ambiguities using std::complex; int i = 2; double d = 2.5; @@ -104,29 +102,29 @@ TEST(mathMixFun, powIntAmbiguityTestFvar) { complex cd = 2.5; complex> cv = 2.5; - auto a1 = pow(i, i); - auto a2 = pow(i, d); - auto a3 = pow(i, v); - auto a4 = pow(i, cd); - auto a5 = pow(i, cv); - - auto b1 = pow(d, i); - auto b2 = pow(d, d); - auto b3 = pow(d, v); - auto b4 = pow(d, cd); - auto b5 = pow(d, cv); - - auto c1 = pow(cd, i); - auto c2 = pow(cd, d); - auto c3 = pow(cd, v); - auto c4 = pow(cd, cd); - auto c5 = pow(cd, cv); - - auto d1 = pow(cv, i); - auto d2 = pow(cv, d); - auto d3 = pow(cv, v); - auto d4 = pow(cv, cd); - auto d5 = pow(cv, cv); + auto a1 = stan::math::pow(i, i); + auto a2 = stan::math::pow(i, d); + auto a3 = stan::math::pow(i, v); + auto a4 = stan::math::pow(i, cd); + auto a5 = stan::math::pow(i, cv); + + auto b1 = stan::math::pow(d, i); + auto b2 = stan::math::pow(d, d); + auto b3 = stan::math::pow(d, v); + auto b4 = stan::math::pow(d, cd); + auto b5 = stan::math::pow(d, cv); + + auto c1 = stan::math::pow(cd, i); + auto c2 = stan::math::pow(cd, d); + auto c3 = stan::math::pow(cd, v); + auto c4 = stan::math::pow(cd, cd); + auto c5 = stan::math::pow(cd, cv); + + auto d1 = stan::math::pow(cv, i); + auto d2 = stan::math::pow(cv, d); + auto d3 = stan::math::pow(cv, v); + auto d4 = stan::math::pow(cv, cd); + auto d5 = stan::math::pow(cv, cv); auto e = a1 + a2 + a3 + a4 + a5 + b1 + b2 + b3 + b4 + b5 + c1 + c2 + c3 + c4 + c5 + d1 + d2 + d3 + d4 + d5; From f7eea1430019f5144603f3dc189be2bd71e36c15 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 25 Sep 2024 15:01:41 -0400 Subject: [PATCH 05/13] Patch test resolution as well --- test/unit/math/mix/fun/pow_part2_test.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/unit/math/mix/fun/pow_part2_test.cpp b/test/unit/math/mix/fun/pow_part2_test.cpp index 9340dad9481..bc4d9d6160d 100644 --- a/test/unit/math/mix/fun/pow_part2_test.cpp +++ b/test/unit/math/mix/fun/pow_part2_test.cpp @@ -5,8 +5,7 @@ TEST(mathMixFun, complexPow) { auto f = [](const auto& x1, const auto& x2) { - using stan::math::pow; - return pow(x1, x2); + return stan::math::pow(x1, x2); }; stan::test::ad_tolerances tols; tols.hessian_hessian_ = 5e-3; From 9239589aca0acaf54874794eda6d90c5ba888c00 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Wed, 25 Sep 2024 15:03:24 -0400 Subject: [PATCH 06/13] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- test/unit/math/mix/fun/pow_part2_test.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/unit/math/mix/fun/pow_part2_test.cpp b/test/unit/math/mix/fun/pow_part2_test.cpp index bc4d9d6160d..d3805f61cba 100644 --- a/test/unit/math/mix/fun/pow_part2_test.cpp +++ b/test/unit/math/mix/fun/pow_part2_test.cpp @@ -4,9 +4,8 @@ #include TEST(mathMixFun, complexPow) { - auto f = [](const auto& x1, const auto& x2) { - return stan::math::pow(x1, x2); - }; + auto f + = [](const auto& x1, const auto& x2) { return stan::math::pow(x1, x2); }; stan::test::ad_tolerances tols; tols.hessian_hessian_ = 5e-3; tols.hessian_fvar_hessian_ = 5e-3; From 836a5efaed69da39b356d261c15fd10d0ea3b608 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 26 Sep 2024 13:36:00 -0400 Subject: [PATCH 07/13] Removes many of the signatures for pow. Fix expr tests so they can work with python > 3 --- stan/math/fwd/fun/pow.hpp | 234 ++++++++---------------------- stan/math/mix.hpp | 14 +- stan/math/prim/fun/pow.hpp | 26 +++- stan/math/prim/meta/base_type.hpp | 1 - stan/math/prim/meta/is_var.hpp | 2 +- stan/math/rev/fun/pow.hpp | 214 ++++++--------------------- test/code_generator.py | 2 +- 7 files changed, 130 insertions(+), 363 deletions(-) diff --git a/stan/math/fwd/fun/pow.hpp b/stan/math/fwd/fun/pow.hpp index ce6af077619..1d2f1964df2 100644 --- a/stan/math/fwd/fun/pow.hpp +++ b/stan/math/fwd/fun/pow.hpp @@ -3,10 +3,11 @@ #include #include -#include #include #include #include +#include +#include #include #include #include @@ -14,49 +15,54 @@ namespace stan { namespace math { - -template -inline fvar pow(const fvar& x1, const fvar& x2) { - using std::log; - using std::pow; - T pow_x1_x2(pow(x1.val_, x2.val_)); - return fvar(pow_x1_x2, (x2.d_ * log(x1.val_) + x2.val_ * x1.d_ / x1.val_) - * pow_x1_x2); -} - -template > -inline fvar pow(const U& x1, const fvar& x2) { +/* + * + * @tparam T1 Either an `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type. + * @tparam T2 Either a `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type. + * @param x1 Base variable. + * @param x2 Exponent variable. + * @return Base raised to the exponent. + */ +template , base_type_t>* = nullptr, + require_all_stan_scalar_t* = nullptr> +inline auto pow(const T1& x1, const T2& x2) { using std::log; using std::pow; - T u = pow(x1, x2.val_); - return fvar(u, x2.d_ * log(x1) * u); -} - -template > -inline fvar pow(const fvar& x1, const U& x2) { - using std::pow; - using std::sqrt; - if (x2 == -2) { - return inv_square(x1); - } - if (x2 == -1) { - return inv(x1); - } - if (x2 == -0.5) { - return inv_sqrt(x1); - } - if (x2 == 0.5) { - return sqrt(x1); - } - if (x2 == 1.0) { - return x1; + if constexpr (is_complex::value || is_complex::value) { + return internal::complex_pow(x1, x2); + } else if constexpr (is_fvar::value && is_fvar::value) { + auto pow_x1_x2(stan::math::pow(x1.val_, x2.val_)); + return T1(pow_x1_x2, (x2.d_ * stan::math::log(x1.val_) + x2.val_ * x1.d_ / x1.val_) + * pow_x1_x2); + } else if constexpr (is_fvar::value) { + auto u = stan::math::pow(x1, x2.val_); + return T2(u, x2.d_ * stan::math::log(x1) * u); + } else { + using std::sqrt; + if (x2 == -2) { + return stan::math::inv_square(x1); + } + if (x2 == -1) { + return stan::math::inv(x1); + } + if (x2 == -0.5) { + return stan::math::inv_sqrt(x1); + } + if (x2 == 0.5) { + return stan::math::sqrt(x1); + } + if (x2 == 1.0) { + return x1; + } + if (x2 == 2.0) { + return stan::math::square(x1); + } + return T1(stan::math::pow(x1.val_, x2), x1.d_ * x2 * stan::math::pow(x1.val_, x2 - 1)); } - if (x2 == 2.0) { - return square(x1); - } - return fvar(pow(x1.val_, x2), x1.d_ * x2 * pow(x1.val_, x2 - 1)); } + // must uniquely match all pairs of: // { complex>, complex, fvar, T } // with at least one fvar and at least one complex, where T is arithmetic: @@ -70,148 +76,26 @@ inline fvar pow(const fvar& x1, const U& x2) { // 8) fvar, complex // 9) T, complex> -/** - * Return the first argument raised to the power of the second argument. - * - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template -inline std::complex> pow(const std::complex>& x, - const std::complex>& y) { - return internal::complex_pow(x, y); -} -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex> pow(const std::complex>& x, - const std::complex& y) { - return internal::complex_pow(x, y); -} -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template -inline std::complex> pow(const std::complex>& x, - const fvar& y) { - return internal::complex_pow(x, y); -} /** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex> pow(const std::complex>& x, const T& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex> pow(const std::complex& x, - const std::complex>& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex> pow(const std::complex& x, const fvar& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template -inline std::complex> pow(const fvar& x, - const std::complex>& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex> pow(const fvar& x, const std::complex& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam V autodiff value type - * @tparam T real type (`fvar` or arithmetic) - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex> pow(const T& x, const std::complex>& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * Note: this overload is required because gcc still provides the - * C++99 template function `pow(complex, int)`, which introduces - * an ambiguity. + * Returns the elementwise raising of the first argument to the power of the + * second argument. * - * @tparam T autodiff value type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument + * @tparam T1 type of first argument + * @tparam T2 type of second argument + * @param a first argument + * @param b second argument + * @return the elementwise raising of the first argument to the power of the + * second argument. */ -template -inline std::complex> pow(const std::complex>& x, int y) { - return internal::complex_pow(x, y); +template * = nullptr, + require_all_not_matrix_st* = nullptr, + require_any_fvar_t, base_type_t>* = nullptr> +inline auto pow(const T1& a, const T2& b) { + return apply_scalar_binary( + a, b, [](const auto& c, const auto& d) { return stan::math::pow(c, d); }); } } // namespace math diff --git a/stan/math/mix.hpp b/stan/math/mix.hpp index bb7ff7c0b1f..876916443ce 100644 --- a/stan/math/mix.hpp +++ b/stan/math/mix.hpp @@ -5,6 +5,13 @@ #include #include +#include +#include +#include +#include +#include +#include + #include #include #include @@ -17,13 +24,6 @@ #include #endif -#include -#include -#include -#include -#include -#include - #include #endif diff --git a/stan/math/prim/fun/pow.hpp b/stan/math/prim/fun/pow.hpp index 72d935e15c0..c58e0bc5b16 100644 --- a/stan/math/prim/fun/pow.hpp +++ b/stan/math/prim/fun/pow.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -40,13 +41,28 @@ inline complex_return_t complex_pow(const U& x, const V& y) { * argument. */ template , std::is_arithmetic>, - disjunction, std::is_arithmetic>>* = nullptr> -inline auto pow(const T1& a, const T2& b) { + require_arithmetic_t* = nullptr, require_arithmetic_t* = nullptr> +inline auto pow(const std::complex& a, const std::complex& b) { return std::pow(a, b); } +template * = nullptr, require_arithmetic_t* = nullptr> +inline auto pow(const T1& a, const std::complex& b) { + return std::pow(a, b); +} + +template * = nullptr, require_arithmetic_t* = nullptr> +inline auto pow(const std::complex& a, const T2& b) { + return std::pow(a, b); +} + +template * = nullptr, require_arithmetic_t* = nullptr> +inline auto pow(const T1& a, const T2& b) { + return std::pow(a, b); +} /** * Returns the elementwise raising of the first argument to the power of the * second argument. @@ -60,7 +76,7 @@ inline auto pow(const T1& a, const T2& b) { */ template * = nullptr, require_all_not_matrix_st* = nullptr, - require_all_st_arithmetic* = nullptr> + require_all_arithmetic_t, base_type_t>* = nullptr> inline auto pow(const T1& a, const T2& b) { return apply_scalar_binary( a, b, [](const auto& c, const auto& d) { return stan::math::pow(c, d); }); diff --git a/stan/math/prim/meta/base_type.hpp b/stan/math/prim/meta/base_type.hpp index 80cebb44cb0..2a6ae4afef9 100644 --- a/stan/math/prim/meta/base_type.hpp +++ b/stan/math/prim/meta/base_type.hpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include diff --git a/stan/math/prim/meta/is_var.hpp b/stan/math/prim/meta/is_var.hpp index 7f3f9c7afac..3f9cec31f7c 100644 --- a/stan/math/prim/meta/is_var.hpp +++ b/stan/math/prim/meta/is_var.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MATH_PRIM_META_IS_VAR_HPP #define STAN_MATH_PRIM_META_IS_VAR_HPP -#include +#include #include diff --git a/stan/math/rev/fun/pow.hpp b/stan/math/rev/fun/pow.hpp index afc808608d2..a5cc42a6474 100644 --- a/stan/math/rev/fun/pow.hpp +++ b/stan/math/rev/fun/pow.hpp @@ -61,45 +61,51 @@ namespace math { \end{cases} \f] * + * @tparam Scal1 Either a `var`, `arithmetic`, or `complex` type with an inner `var` or `arithmetic` type. + * @tparam Scal2 Either a `var`, `arithmetic`, or `complex` type with an inner `var` or `arithmetic` type. * @param base Base variable. * @param exponent Exponent variable. * @return Base raised to the exponent. */ template * = nullptr, + require_any_var_t, base_type_t>* = nullptr, require_all_stan_scalar_t* = nullptr> -inline var pow(const Scal1& base, const Scal2& exponent) { - if (is_constant::value) { - if (exponent == 0.5) { - return sqrt(base); - } else if (exponent == 1.0) { - return base; - } else if (exponent == 2.0) { - return square(base); - } else if (exponent == -2.0) { - return inv_square(base); - } else if (exponent == -1.0) { - return inv(base); - } else if (exponent == -0.5) { - return inv_sqrt(base); +inline auto pow(const Scal1& base, const Scal2& exponent) { + if constexpr (is_complex::value || is_complex::value) { + return internal::complex_pow(base, exponent); + } else { + if constexpr (is_constant::value) { + if (exponent == 0.5) { + return sqrt(base); + } else if (exponent == 1.0) { + return base; + } else if (exponent == 2.0) { + return square(base); + } else if (exponent == -2.0) { + return inv_square(base); + } else if (exponent == -1.0) { + return inv(base); + } else if (exponent == -0.5) { + return inv_sqrt(base); + } } - } - return make_callback_var( - std::pow(value_of(base), value_of(exponent)), - [base, exponent](auto&& vi) mutable { - if (value_of(base) == 0.0) { - return; // partials zero, avoids 0 & log(0) - } - const double vi_mul = vi.adj() * vi.val(); + return make_callback_var( + std::pow(value_of(base), value_of(exponent)), + [base, exponent](auto&& vi) mutable { + if (value_of(base) == 0.0) { + return; // partials zero, avoids 0 & log(0) + } + const double vi_mul = vi.adj() * vi.val(); - if (!is_constant::value) { - forward_as(base).adj() - += vi_mul * value_of(exponent) / value_of(base); - } - if (!is_constant::value) { - forward_as(exponent).adj() += vi_mul * std::log(value_of(base)); - } - }); + if (!is_constant::value) { + forward_as(base).adj() + += vi_mul * value_of(exponent) / value_of(base); + } + if (!is_constant::value) { + forward_as(exponent).adj() += vi_mul * std::log(value_of(base)); + } + }); + } } /** @@ -164,6 +170,7 @@ inline auto pow(const Mat1& base, const Mat2& exponent) { * @tparam Mat1 An Eigen type deriving from Eigen::EigenBase or * a `var_value` with inner Eigen type as defined above. The `scalar_type` * must be a `var` or Arithmetic. + * @tparam Scal1 An arithmetic type or a `var_value` with inner arithmetic type. * @param base Base variable. * @param exponent Exponent variable. * @return Base raised to the exponent. @@ -225,10 +232,10 @@ inline auto pow(const Mat1& base, const Scal1& exponent) { * \f$\frac{d}{d y} \mbox{pow}(c, y) = c^y \log c \f$. * * - * @tparam Mat An Eigen type deriving from Eigen::EigenBase or + * @tparam Mat1 An Eigen type deriving from Eigen::EigenBase or * a `var_value` with inner Eigen type as defined above. The `scalar_type` * must be a `var`. - * + * @tparam Scal1 An arithmetic type or a `var_value` with inner arithmetic type. * @param base Base scalar. * @param exponent Exponent variable. * @return Base raised to the exponent. @@ -261,145 +268,6 @@ inline auto pow(Scal1 base, const Mat1& exponent) { return ret_type(ret); } -// must uniquely match all pairs of { complex, complex, var, T } -// with at least one var and at least one complex, where T is arithmetic: -// 1) complex, complex -// 2) complex, complex -// 3) complex, var -// 4) complex, T -// 5) complex, complex -// 6) complex, var -// 7) var, complex -// 8) var, complex -// 9) T, complex - -/** - * Return the first argument raised to the power of the second argument. - * - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -inline std::complex pow(const std::complex& x, - const std::complex& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex pow(const std::complex& x, - const std::complex y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -inline std::complex pow(const std::complex& x, const var& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex pow(const std::complex& x, const T& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex pow(std::complex x, const std::complex& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex pow(std::complex x, const var& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -inline std::complex pow(const var& x, const std::complex& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex pow(const var& x, std::complex y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * @tparam T arithmetic type - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -template > -inline std::complex pow(const T& x, const std::complex& y) { - return internal::complex_pow(x, y); -} - -/** - * Return the first argument raised to the power of the second argument. - * - * Note: this overload is required because gcc still provides the - * C++99 template function `pow(complex, int)`, which introduces - * an ambiguity. - * - * @param x first argument - * @param y second argument - * @return first argument to the power of the second argument - */ -inline std::complex pow(const std::complex& x, int y) { - return internal::complex_pow(x, y); -} /** * Returns the elementwise raising of the first argument to the power of the @@ -414,7 +282,7 @@ inline std::complex pow(const std::complex& x, int y) { */ template * = nullptr, require_all_not_matrix_st* = nullptr, - require_any_not_st_arithmetic* = nullptr> + require_any_var_t, base_type_t>* = nullptr> inline auto pow(const T1& a, const T2& b) { return apply_scalar_binary( a, b, [](const auto& c, const auto& d) { return stan::math::pow(c, d); }); diff --git a/test/code_generator.py b/test/code_generator.py index 2671b72429d..895267a7742 100644 --- a/test/code_generator.py +++ b/test/code_generator.py @@ -75,7 +75,7 @@ def build_arguments(self, signature_parser, arg_overloads, size): # The first case here is used for the array initializers in sig_utils.special_arg_values # Everything else uses the second case - if number_nested_arrays > 0 and isinstance(value, collections.Iterable): + if number_nested_arrays > 0 and isinstance(value, collections.abc.Iterable): arg = statement_types.ArrayVariable( overload, "array" + suffix, From a01c01fe6d89f0718edeeda449466d6a6f57a07c Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 26 Sep 2024 13:37:38 -0400 Subject: [PATCH 08/13] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/fwd/fun/pow.hpp | 24 ++++++++++++------------ stan/math/prim/fun/pow.hpp | 16 ++++++++-------- stan/math/rev/fun/pow.hpp | 38 ++++++++++++++++++++------------------ 3 files changed, 40 insertions(+), 38 deletions(-) diff --git a/stan/math/fwd/fun/pow.hpp b/stan/math/fwd/fun/pow.hpp index 1d2f1964df2..ed1c42adf1e 100644 --- a/stan/math/fwd/fun/pow.hpp +++ b/stan/math/fwd/fun/pow.hpp @@ -17,15 +17,17 @@ namespace stan { namespace math { /* * - * @tparam T1 Either an `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type. - * @tparam T2 Either a `fvar`, `arithmetic`, or `complex` type with an inner `fvar` or `arithmetic` type. + * @tparam T1 Either an `fvar`, `arithmetic`, or `complex` type with an inner + * `fvar` or `arithmetic` type. + * @tparam T2 Either a `fvar`, `arithmetic`, or `complex` type with an inner + * `fvar` or `arithmetic` type. * @param x1 Base variable. * @param x2 Exponent variable. * @return Base raised to the exponent. */ -template , base_type_t>* = nullptr, - require_all_stan_scalar_t* = nullptr> +template , base_type_t>* = nullptr, + require_all_stan_scalar_t* = nullptr> inline auto pow(const T1& x1, const T2& x2) { using std::log; using std::pow; @@ -33,8 +35,9 @@ inline auto pow(const T1& x1, const T2& x2) { return internal::complex_pow(x1, x2); } else if constexpr (is_fvar::value && is_fvar::value) { auto pow_x1_x2(stan::math::pow(x1.val_, x2.val_)); - return T1(pow_x1_x2, (x2.d_ * stan::math::log(x1.val_) + x2.val_ * x1.d_ / x1.val_) - * pow_x1_x2); + return T1(pow_x1_x2, + (x2.d_ * stan::math::log(x1.val_) + x2.val_ * x1.d_ / x1.val_) + * pow_x1_x2); } else if constexpr (is_fvar::value) { auto u = stan::math::pow(x1, x2.val_); return T2(u, x2.d_ * stan::math::log(x1) * u); @@ -58,11 +61,11 @@ inline auto pow(const T1& x1, const T2& x2) { if (x2 == 2.0) { return stan::math::square(x1); } - return T1(stan::math::pow(x1.val_, x2), x1.d_ * x2 * stan::math::pow(x1.val_, x2 - 1)); + return T1(stan::math::pow(x1.val_, x2), + x1.d_ * x2 * stan::math::pow(x1.val_, x2 - 1)); } } - // must uniquely match all pairs of: // { complex>, complex, fvar, T } // with at least one fvar and at least one complex, where T is arithmetic: @@ -76,9 +79,6 @@ inline auto pow(const T1& x1, const T2& x2) { // 8) fvar, complex // 9) T, complex> - - - /** * Returns the elementwise raising of the first argument to the power of the * second argument. diff --git a/stan/math/prim/fun/pow.hpp b/stan/math/prim/fun/pow.hpp index c58e0bc5b16..95809bd8830 100644 --- a/stan/math/prim/fun/pow.hpp +++ b/stan/math/prim/fun/pow.hpp @@ -40,26 +40,26 @@ inline complex_return_t complex_pow(const U& x, const V& y) { * @return the first argument raised to the power of the second * argument. */ -template * = nullptr, require_arithmetic_t* = nullptr> +template * = nullptr, + require_arithmetic_t* = nullptr> inline auto pow(const std::complex& a, const std::complex& b) { return std::pow(a, b); } -template * = nullptr, require_arithmetic_t* = nullptr> +template * = nullptr, + require_arithmetic_t* = nullptr> inline auto pow(const T1& a, const std::complex& b) { return std::pow(a, b); } -template * = nullptr, require_arithmetic_t* = nullptr> +template * = nullptr, + require_arithmetic_t* = nullptr> inline auto pow(const std::complex& a, const T2& b) { return std::pow(a, b); } -template * = nullptr, require_arithmetic_t* = nullptr> +template * = nullptr, + require_arithmetic_t* = nullptr> inline auto pow(const T1& a, const T2& b) { return std::pow(a, b); } diff --git a/stan/math/rev/fun/pow.hpp b/stan/math/rev/fun/pow.hpp index a5cc42a6474..669f53908b4 100644 --- a/stan/math/rev/fun/pow.hpp +++ b/stan/math/rev/fun/pow.hpp @@ -61,8 +61,10 @@ namespace math { \end{cases} \f] * - * @tparam Scal1 Either a `var`, `arithmetic`, or `complex` type with an inner `var` or `arithmetic` type. - * @tparam Scal2 Either a `var`, `arithmetic`, or `complex` type with an inner `var` or `arithmetic` type. + * @tparam Scal1 Either a `var`, `arithmetic`, or `complex` type with an inner + `var` or `arithmetic` type. + * @tparam Scal2 Either a `var`, `arithmetic`, or `complex` type with an inner + `var` or `arithmetic` type. * @param base Base variable. * @param exponent Exponent variable. * @return Base raised to the exponent. @@ -89,22 +91,23 @@ inline auto pow(const Scal1& base, const Scal2& exponent) { return inv_sqrt(base); } } - return make_callback_var( - std::pow(value_of(base), value_of(exponent)), - [base, exponent](auto&& vi) mutable { - if (value_of(base) == 0.0) { - return; // partials zero, avoids 0 & log(0) - } - const double vi_mul = vi.adj() * vi.val(); + return make_callback_var(std::pow(value_of(base), value_of(exponent)), + [base, exponent](auto&& vi) mutable { + if (value_of(base) == 0.0) { + return; // partials zero, avoids 0 & log(0) + } + const double vi_mul = vi.adj() * vi.val(); - if (!is_constant::value) { - forward_as(base).adj() - += vi_mul * value_of(exponent) / value_of(base); - } - if (!is_constant::value) { - forward_as(exponent).adj() += vi_mul * std::log(value_of(base)); - } - }); + if (!is_constant::value) { + forward_as(base).adj() + += vi_mul * value_of(exponent) + / value_of(base); + } + if (!is_constant::value) { + forward_as(exponent).adj() + += vi_mul * std::log(value_of(base)); + } + }); } } @@ -268,7 +271,6 @@ inline auto pow(Scal1 base, const Mat1& exponent) { return ret_type(ret); } - /** * Returns the elementwise raising of the first argument to the power of the * second argument. From e0514cd109a91bc5142735c16c41f75c8a000b8b Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 26 Sep 2024 13:59:14 -0400 Subject: [PATCH 09/13] fix headers --- stan/math/fwd/fun/pow.hpp | 17 +---------------- stan/math/prim/meta/base_type.hpp | 1 + stan/math/prim/meta/is_var.hpp | 2 +- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/stan/math/fwd/fun/pow.hpp b/stan/math/fwd/fun/pow.hpp index 1d2f1964df2..8a0a17ab9dd 100644 --- a/stan/math/fwd/fun/pow.hpp +++ b/stan/math/fwd/fun/pow.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -63,22 +64,6 @@ inline auto pow(const T1& x1, const T2& x2) { } -// must uniquely match all pairs of: -// { complex>, complex, fvar, T } -// with at least one fvar and at least one complex, where T is arithmetic: -// 1) complex>, complex> -// 2) complex>, complex -// 3) complex>, fvar -// 4) complex>, T -// 5) complex, complex> -// 6) complex, fvar -// 7) fvar, complex> -// 8) fvar, complex -// 9) T, complex> - - - - /** * Returns the elementwise raising of the first argument to the power of the * second argument. diff --git a/stan/math/prim/meta/base_type.hpp b/stan/math/prim/meta/base_type.hpp index 2a6ae4afef9..80cebb44cb0 100644 --- a/stan/math/prim/meta/base_type.hpp +++ b/stan/math/prim/meta/base_type.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/meta/is_var.hpp b/stan/math/prim/meta/is_var.hpp index 3f9cec31f7c..7f3f9c7afac 100644 --- a/stan/math/prim/meta/is_var.hpp +++ b/stan/math/prim/meta/is_var.hpp @@ -1,7 +1,7 @@ #ifndef STAN_MATH_PRIM_META_IS_VAR_HPP #define STAN_MATH_PRIM_META_IS_VAR_HPP -#include +#include #include From 2d1dbbdbdff0fe69eefdf2119b1b2f4551142cd4 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 27 Sep 2024 10:04:25 -0400 Subject: [PATCH 10/13] Resolve merge conflict --- stan/math/fwd/fun/pow.hpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/stan/math/fwd/fun/pow.hpp b/stan/math/fwd/fun/pow.hpp index 631f2dfd37a..c2056c9f5cf 100644 --- a/stan/math/fwd/fun/pow.hpp +++ b/stan/math/fwd/fun/pow.hpp @@ -67,22 +67,6 @@ inline auto pow(const T1& x1, const T2& x2) { } } -<<<<<<< HEAD -======= -// must uniquely match all pairs of: -// { complex>, complex, fvar, T } -// with at least one fvar and at least one complex, where T is arithmetic: -// 1) complex>, complex> -// 2) complex>, complex -// 3) complex>, fvar -// 4) complex>, T -// 5) complex, complex> -// 6) complex, fvar -// 7) fvar, complex> -// 8) fvar, complex -// 9) T, complex> ->>>>>>> origin/fix/pow-overload-resolution - /** * Returns the elementwise raising of the first argument to the power of the * second argument. From b31f4ee15b6a51effc2863fa5cd84cf9597e35a9 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Fri, 27 Sep 2024 10:36:23 -0400 Subject: [PATCH 11/13] update to 17 --- make/compiler_flags | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/make/compiler_flags b/make/compiler_flags index d6f41aec41b..26e15a0e989 100644 --- a/make/compiler_flags +++ b/make/compiler_flags @@ -119,33 +119,9 @@ INC_GTEST ?= -I $(GTEST)/include -I $(GTEST) ## setup precompiler options CPPFLAGS_BOOST ?= -DBOOST_DISABLE_ASSERTS CPPFLAGS_SUNDIALS ?= -DNO_FPRINTF_OUTPUT $(CPPFLAGS_OPTIM_SUNDIALS) $(CXXFLAGS_FLTO_SUNDIALS) -#CPPFLAGS_GTEST ?= -STAN_HAS_CXX17 ?= false -ifeq ($(CXX_TYPE), gcc) - GCC_GE_73 := $(shell [ $(CXX_MAJOR) -gt 7 -o \( $(CXX_MAJOR) -eq 7 -a $(CXX_MINOR) -ge 1 \) ] && echo true) - ifeq ($(GCC_GE_73),true) - STAN_HAS_CXX17 := true - endif -else ifeq ($(CXX_TYPE), clang) - CLANG_GE_5 := $(shell [ $(CXX_MAJOR) -gt 5 -o \( $(CXX_MAJOR) -eq 5 -a $(CXX_MINOR) -ge 0 \) ] && echo true) - ifeq ($(CLANG_GE_5),true) - STAN_HAS_CXX17 := true - endif -else ifeq ($(CXX_TYPE), mingw32-gcc) - MINGW_GE_50 := $(shell [ $(CXX_MAJOR) -gt 5 -o \( $(CXX_MAJOR) -eq 5 -a $(CXX_MINOR) -ge 0 \) ] && echo true) - ifeq ($(MINGW_GE_50),true) - STAN_HAS_CXX17 := true - endif -endif -ifeq ($(STAN_HAS_CXX17), true) - CXXFLAGS_LANG ?= -std=c++17 - CXXFLAGS_STANDARD ?= c++17 -else - $(warning "Stan cannot detect if your compiler has the C++17 standard. If it does, please set STAN_HAS_CXX17=true in your make/local file. C++17 support is mandatory in the next release of Stan. Defaulting to C++14") - CXXFLAGS_LANG ?= -std=c++1y - CXXFLAGS_STANDARD ?= c++1y -endif +CXXFLAGS_LANG ?= -std=c++17 +CXXFLAGS_STANDARD ?= c++17 #CXXFLAGS_BOOST ?= CXXFLAGS_SUNDIALS ?= -pipe $(CXXFLAGS_OPTIM_SUNDIALS) $(CPPFLAGS_FLTO_SUNDIALS) #CXXFLAGS_GTEST From d328e0b18d6f24ee0ddd91c3d4e8d1f851739210 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Fri, 27 Sep 2024 16:31:17 -0400 Subject: [PATCH 12/13] add accumulator for sum --- stan/math/fwd/fun/accumulator.hpp | 91 +++++++++++++++++++++++++++++++ stan/math/fwd/fun/sum.hpp | 15 ++--- stan/math/prim/fun/sum.hpp | 2 +- 3 files changed, 100 insertions(+), 8 deletions(-) diff --git a/stan/math/fwd/fun/accumulator.hpp b/stan/math/fwd/fun/accumulator.hpp index e1bbaea5d3c..251b469d770 100644 --- a/stan/math/fwd/fun/accumulator.hpp +++ b/stan/math/fwd/fun/accumulator.hpp @@ -6,7 +6,98 @@ #include #include #include + #include #include +namespace stan { +namespace math { +template +class accumulator; +/** + * Class to accumulate values and eventually return their sum. If + * no values are ever added, the return value is 0. + * + * This class is useful for speeding up autodiff of long sums + * because it uses the sum() operation (either from + * stan::math or one defined by argument-dependent lookup. + * + * @tparam T Type of scalar added + */ +template +class accumulator> { + private: + std::vector buf_; + + public: + /** + * Add the specified arithmetic type value to the buffer after + * static casting it to the class type T. + * + *

See the std library doc for std::is_arithmetic + * for information on what counts as an arithmetic type. + * + * @tparam S Type of argument + * @param x Value to add + */ + template > + inline void add(S x) { + buf_.push_back(x); + } + + /** + * Add each entry in the specified matrix, vector, or row vector + * of values to the buffer. + * + * @tparam S type of the matrix + * @param m Matrix of values to add + */ + template * = nullptr> + inline void add(const S& m) { + buf_.push_back(stan::math::sum(m)); + } + + /** + * Recursively add each entry in the specified standard vector + * to the buffer. This will allow vectors of primitives, + * autodiff variables to be added; if the vector entries + * are collections, their elements are recursively added. + * + * @tparam S Type of value to recursively add. + * @param xs Vector of entries to add + */ + template + inline void add(const std::vector& xs) { + for (size_t i = 0; i < xs.size(); ++i) { + this->add(xs[i]); + } + } + +#ifdef STAN_OPENCL + + /** + * Sum each entry and then push to the buffer. + * @tparam S A Type inheriting from `matrix_cl_base` + * @param xs An OpenCL matrix + */ + template * = nullptr> + inline void add(const S& xs) { + buf_.push_back(stan::math::sum(xs)); + } + +#endif + + /** + * Return the sum of the accumulated values. + * + * @return Sum of accumulated values. + */ + inline T sum() const { return stan::math::sum(buf_); } +}; + +} // namespace math +} // namespace stan + + #endif diff --git a/stan/math/fwd/fun/sum.hpp b/stan/math/fwd/fun/sum.hpp index 2ae7887c1ca..36eef6ad687 100644 --- a/stan/math/fwd/fun/sum.hpp +++ b/stan/math/fwd/fun/sum.hpp @@ -1,10 +1,11 @@ #ifndef STAN_MATH_FWD_FUN_SUM_HPP #define STAN_MATH_FWD_FUN_SUM_HPP +#include +#include #include #include #include -#include #include namespace stan { @@ -18,18 +19,18 @@ namespace math { * @param m Vector. * @return Sum of vector entries. */ -template -inline fvar sum(const std::vector>& m) { +template * = nullptr> +inline auto sum(const std::vector& m) { if (m.size() == 0) { - return 0.0; + return T(0.0); } - std::vector vals(m.size()); - std::vector tans(m.size()); + std::vector> vals(m.size()); + std::vector> tans(m.size()); for (size_t i = 0; i < m.size(); ++i) { vals[i] = m[i].val(); tans[i] = m[i].d(); } - return fvar(sum(vals), sum(tans)); + return T(sum(vals), sum(tans)); } /** diff --git a/stan/math/prim/fun/sum.hpp b/stan/math/prim/fun/sum.hpp index f1256c375e4..0440997fda0 100644 --- a/stan/math/prim/fun/sum.hpp +++ b/stan/math/prim/fun/sum.hpp @@ -29,7 +29,7 @@ inline T sum(T&& m) { * @param m Standard vector to sum. * @return Sum of elements. */ -template * = nullptr> +template * = nullptr> inline T sum(const std::vector& m) { return std::accumulate(m.begin(), m.end(), T{0}); } From 35459b2eb43087bc00eee43b0fba24e032c0ec84 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Fri, 27 Sep 2024 16:32:25 -0400 Subject: [PATCH 13/13] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/fwd/fun/accumulator.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/stan/math/fwd/fun/accumulator.hpp b/stan/math/fwd/fun/accumulator.hpp index 251b469d770..a5ceb9f6dbc 100644 --- a/stan/math/fwd/fun/accumulator.hpp +++ b/stan/math/fwd/fun/accumulator.hpp @@ -99,5 +99,4 @@ class accumulator> { } // namespace math } // namespace stan - #endif