diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index 8e3ab1d8fd30b..6bcc3ec739821 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -52,15 +52,28 @@ template class Norm2Accumulator { const Constant &array, const Constant &maxAbs, Rounding rounding) : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}; void operator()(Scalar &element, const ConstantSubscripts &at) { - // Kahan summation of scaled elements + // Kahan summation of scaled elements: + // Naively, + // NORM2(A(:)) = SQRT(SUM(A(:)**2)) + // For any T > 0, we have mathematically + // SQRT(SUM(A(:)**2)) + // = SQRT(T**2 * (SUM(A(:)**2) / T**2)) + // = SQRT(T**2 * SUM(A(:)**2 / T**2)) + // = SQRT(T**2 * SUM((A(:)/T)**2)) + // = SQRT(T**2) * SQRT(SUM((A(:)/T)**2)) + // = T * SQRT(SUM((A(:)/T)**2)) + // By letting T = MAXVAL(ABS(A)), we ensure that + // ALL(ABS(A(:)/T) <= 1), so ALL((A(:)/T)**2 <= 1), and the SUM will + // not overflow unless absolutely necessary. auto scale{maxAbs_.At(maxAbsAt_)}; if (scale.IsZero()) { - // If maxAbs is zero, so are all elements, and result + // Maximum value is zero, and so will the result be. + // Avoid division by zero below. element = scale; } else { auto item{array_.At(at)}; auto scaled{item.Divide(scale).value}; - auto square{item.Multiply(scaled).value}; + auto square{scaled.Multiply(scaled).value}; auto next{square.Add(correction_, rounding_)}; overflow_ |= next.flags.test(RealFlag::Overflow); auto sum{element.Add(next.value, rounding_)}; @@ -73,13 +86,16 @@ template class Norm2Accumulator { } bool overflow() const { return overflow_; } void Done(Scalar &result) { + // result+correction == SUM((data(:)/maxAbs)**2) + // result = maxAbs * SQRT(result+correction) auto corrected{result.Add(correction_, rounding_)}; overflow_ |= corrected.flags.test(RealFlag::Overflow); correction_ = Scalar{}; - auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))}; + auto root{corrected.value.SQRT().value}; + auto product{root.Multiply(maxAbs_.At(maxAbsAt_))}; maxAbs_.IncrementSubscripts(maxAbsAt_); - overflow_ |= rescaled.flags.test(RealFlag::Overflow); - result = rescaled.value.SQRT().value; + overflow_ |= product.flags.test(RealFlag::Overflow); + result = product.value; } private: diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index cff7f54c60d91..0dd55124e6a51 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -228,7 +228,7 @@ template class MaxvalMinvalAccumulator { test.Rewrite(context_, std::move(test)))}; CHECK(folded.has_value()); if (folded->IsTrue()) { - element = array_.At(at); + element = aAt; } } void Done(Scalar &) const {} diff --git a/flang/test/Evaluate/fold-norm2.f90 b/flang/test/Evaluate/fold-norm2.f90 index 30d5289b5a6e3..370532bafaa13 100644 --- a/flang/test/Evaluate/fold-norm2.f90 +++ b/flang/test/Evaluate/fold-norm2.f90 @@ -17,13 +17,20 @@ module m real(dp), parameter :: a(3,4) = & reshape([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape(a)) real(dp), parameter :: nAll = norm2(a) - real(dp), parameter :: check_nAll = sqrt(sum(a * a)) + real(dp), parameter :: check_nAll = 11._dp * sqrt(sum((a/11._dp)**2)) logical, parameter :: test_all = nAll == check_nAll real(dp), parameter :: norms1(4) = norm2(a, dim=1) - real(dp), parameter :: check_norms1(4) = sqrt(sum(a * a, dim=1)) + real(dp), parameter :: check_norms1(4) = [ & + 2.236067977499789805051477742381393909454345703125_8, & + 7.07106781186547550532850436866283416748046875_8, & + 1.2206555615733702069292121450416743755340576171875e1_8, & + 1.7378147196982769884243680280633270740509033203125e1_8 ] logical, parameter :: test_norms1 = all(norms1 == check_norms1) real(dp), parameter :: norms2(3) = norm2(a, dim=2) - real(dp), parameter :: check_norms2(3) = sqrt(sum(a * a, dim=2)) + real(dp), parameter :: check_norms2(3) = [ & + 1.1224972160321822656214862945489585399627685546875e1_8, & + 1.28840987267251261272349438513629138469696044921875e1_8, & + 1.4628738838327791427218471653759479522705078125e1_8 ] logical, parameter :: test_norms2 = all(norms2 == check_norms2) logical, parameter :: test_normZ = norm2([0.,0.,0.]) == 0. end