diff --git a/ext/bigdecimal/bigdecimal.c b/ext/bigdecimal/bigdecimal.c index 486aee8..cced09d 100644 --- a/ext/bigdecimal/bigdecimal.c +++ b/ext/bigdecimal/bigdecimal.c @@ -2384,31 +2384,6 @@ BigDecimal_abs(VALUE self) return VpCheckGetValue(c); } -/* call-seq: - * sqrt(n) - * - * Returns the square root of the value. - * - * Result has at least n significant digits. - */ -static VALUE -BigDecimal_sqrt(VALUE self, VALUE nFig) -{ - ENTER(5); - Real *c, *a; - size_t mx, n; - - GUARD_OBJ(a, GetVpValue(self, 1)); - mx = a->Prec * (VpBaseFig() + 1); - - n = check_int_precision(nFig); - n += VpDblFig() + VpBaseFig(); - if (mx <= n) mx = n; - GUARD_OBJ(c, NewZeroWrapLimited(1, mx)); - VpSqrt(c, a); - return VpCheckGetValue(c); -} - /* Return the integer part of the number, as a BigDecimal. */ static VALUE @@ -4578,7 +4553,6 @@ Init_bigdecimal(void) rb_define_method(rb_cBigDecimal, "dup", BigDecimal_clone, 0); rb_define_method(rb_cBigDecimal, "to_f", BigDecimal_to_f, 0); rb_define_method(rb_cBigDecimal, "abs", BigDecimal_abs, 0); - rb_define_method(rb_cBigDecimal, "sqrt", BigDecimal_sqrt, 1); rb_define_method(rb_cBigDecimal, "fix", BigDecimal_fix, 0); rb_define_method(rb_cBigDecimal, "round", BigDecimal_round, -1); rb_define_method(rb_cBigDecimal, "frac", BigDecimal_frac, 0); @@ -4655,8 +4629,6 @@ static int gfCheckVal = 1; /* Value checking flag in VpNmlz() */ static Real *VpConstOne; /* constant 1.0 */ static Real *VpConstPt5; /* constant 0.5 */ -#define maxnr 100UL /* Maximum iterations for calculating sqrt. */ - /* used in VpSqrt() */ /* ETC */ #define MemCmp(x,y,z) memcmp(x,y,z) @@ -7060,74 +7032,6 @@ VpVtoD(double *d, SIGNED_VALUE *e, Real *m) return f; } -/* - * m <- d - */ -VP_EXPORT void -VpDtoV(Real *m, double d) -{ - size_t ind_m, mm; - SIGNED_VALUE ne; - DECDIG i; - double val, val2; - - if (isnan(d)) { - VpSetNaN(m); - goto Exit; - } - if (isinf(d)) { - if (d > 0.0) VpSetPosInf(m); - else VpSetNegInf(m); - goto Exit; - } - - if (d == 0.0) { - VpSetZero(m, 1); - goto Exit; - } - val = (d > 0.) ? d : -d; - ne = 0; - if (val >= 1.0) { - while (val >= 1.0) { - val /= (double)BASE; - ++ne; - } - } - else { - val2 = 1.0 / (double)BASE; - while (val < val2) { - val *= (double)BASE; - --ne; - } - } - /* Now val = 0.xxxxx*BASE**ne */ - - mm = m->MaxPrec; - memset(m->frac, 0, mm * sizeof(DECDIG)); - for (ind_m = 0; val > 0.0 && ind_m < mm; ind_m++) { - val *= (double)BASE; - i = (DECDIG)val; - val -= (double)i; - m->frac[ind_m] = i; - } - if (ind_m >= mm) ind_m = mm - 1; - VpSetSign(m, (d > 0.0) ? 1 : -1); - m->Prec = ind_m + 1; - m->exponent = ne; - - VpInternalRound(m, 0, (m->Prec > 0) ? m->frac[m->Prec-1] : 0, - (DECDIG)(val*(double)BASE)); - -Exit: -#ifdef BIGDECIMAL_DEBUG - if (gfDebug) { - printf("VpDtoV d=%30.30e\n", d); - VPrint(stdout, " m=%\n", m); - } -#endif /* BIGDECIMAL_DEBUG */ - return; -} - /* * m <- ival */ @@ -7192,112 +7096,6 @@ VpItoV(Real *m, SIGNED_VALUE ival) } #endif -/* - * y = SQRT(x), y*y - x =>0 - */ -VP_EXPORT int -VpSqrt(Real *y, Real *x) -{ - Real *f = NULL; - Real *r = NULL; - size_t y_prec; - SIGNED_VALUE n, e; - ssize_t nr; - double val; - - /* Zero or +Infinity ? */ - if (VpIsZero(x) || VpIsPosInf(x)) { - VpAsgn(y,x,1); - goto Exit; - } - - /* Negative ? */ - if (BIGDECIMAL_NEGATIVE_P(x)) { - VpSetNaN(y); - return VpException(VP_EXCEPTION_OP, "sqrt of negative value", 0); - } - - /* NaN ? */ - if (VpIsNaN(x)) { - VpSetNaN(y); - return VpException(VP_EXCEPTION_OP, "sqrt of 'NaN'(Not a Number)", 0); - } - - /* One ? */ - if (VpIsOne(x)) { - VpSetOne(y); - goto Exit; - } - - n = (SIGNED_VALUE)y->MaxPrec; - if (x->MaxPrec > (size_t)n) n = (ssize_t)x->MaxPrec; - - /* allocate temporally variables */ - /* TODO: reconsider MaxPrec of f and r */ - f = NewOneNolimit(1, y->MaxPrec * (BASE_FIG + 2)); - r = NewOneNolimit(1, (n + n) * (BASE_FIG + 2)); - - nr = 0; - y_prec = y->MaxPrec; - - VpVtoD(&val, &e, x); /* val <- x */ - e /= (SIGNED_VALUE)BASE_FIG; - n = e / 2; - if (e - n * 2 != 0) { - val /= BASE; - n = (e + 1) / 2; - } - VpDtoV(y, sqrt(val)); /* y <- sqrt(val) */ - y->exponent += n; - n = (SIGNED_VALUE)roomof(BIGDECIMAL_DOUBLE_FIGURES, BASE_FIG); - y->MaxPrec = Min((size_t)n , y_prec); - f->MaxPrec = y->MaxPrec + 1; - n = (SIGNED_VALUE)(y_prec * BASE_FIG); - if (n > (SIGNED_VALUE)maxnr) n = (SIGNED_VALUE)maxnr; - - /* - * Perform: y_{n+1} = (y_n - x/y_n) / 2 - */ - do { - y->MaxPrec *= 2; - if (y->MaxPrec > y_prec) y->MaxPrec = y_prec; - f->MaxPrec = y->MaxPrec; - VpDivd(f, r, x, y); /* f = x/y */ - VpAddSub(r, f, y, -1); /* r = f - y */ - VpMult(f, VpConstPt5, r); /* f = 0.5*r */ - if (VpIsZero(f)) - goto converge; - VpAddSub(r, f, y, 1); /* r = y + f */ - VpAsgn(y, r, 1); /* y = r */ - } while (++nr < n); - -#ifdef BIGDECIMAL_DEBUG - if (gfDebug) { - printf("ERROR(VpSqrt): did not converge within %ld iterations.\n", nr); - } -#endif /* BIGDECIMAL_DEBUG */ - y->MaxPrec = y_prec; - -converge: - VpChangeSign(y, 1); -#ifdef BIGDECIMAL_DEBUG - if (gfDebug) { - VpMult(r, y, y); - VpAddSub(f, x, r, -1); - printf("VpSqrt: iterations = %"PRIdSIZE"\n", nr); - VPrint(stdout, " y =% \n", y); - VPrint(stdout, " x =% \n", x); - VPrint(stdout, " x-y*y = % \n", f); - } -#endif /* BIGDECIMAL_DEBUG */ - y->MaxPrec = y_prec; - -Exit: - rbd_free_struct(f); - rbd_free_struct(r); - return 1; -} - /* * Round relatively from the decimal point. * f: rounding mode diff --git a/ext/bigdecimal/bigdecimal.h b/ext/bigdecimal/bigdecimal.h index 54fed81..7e6be80 100644 --- a/ext/bigdecimal/bigdecimal.h +++ b/ext/bigdecimal/bigdecimal.h @@ -237,11 +237,9 @@ VP_EXPORT void VpToString(Real *a, char *buf, size_t bufsize, size_t fFmt, int f VP_EXPORT void VpToFString(Real *a, char *buf, size_t bufsize, size_t fFmt, int fPlus); VP_EXPORT int VpCtoV(Real *a, const char *int_chr, size_t ni, const char *frac, size_t nf, const char *exp_chr, size_t ne); VP_EXPORT int VpVtoD(double *d, SIGNED_VALUE *e, Real *m); -VP_EXPORT void VpDtoV(Real *m,double d); #if 0 /* unused */ VP_EXPORT void VpItoV(Real *m,S_INT ival); #endif -VP_EXPORT int VpSqrt(Real *y,Real *x); VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il); VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf); VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf); diff --git a/lib/bigdecimal.rb b/lib/bigdecimal.rb index 82b3e1b..959e9b8 100644 --- a/lib/bigdecimal.rb +++ b/lib/bigdecimal.rb @@ -3,3 +3,36 @@ else require 'bigdecimal.so' end + +class BigDecimal + + # Returns the square root of the value. + # + # Result has at least prec significant digits. + # + def sqrt(prec) + if infinite? == 1 + exception_mode = BigDecimal.mode(BigDecimal::EXCEPTION_ALL) + raise FloatDomainError, "Computation results in 'Infinity'" if exception_mode.anybits?(BigDecimal::EXCEPTION_INFINITY) + return INFINITY + end + raise ArgumentError, 'negative precision' if prec < 0 + raise FloatDomainError, 'sqrt of negative value' if self < 0 + raise FloatDomainError, "sqrt of 'NaN'(Not a Number)" if nan? + + n_digits = n_significant_digits + prec = [prec, n_digits].max + + if n_digits < prec / 2 + # Fast path for sqrt(16e100) => 4e50 + ex = (n_digits - exponent + 1) / 2 + n = (self * BigDecimal("1e#{2 * ex}")).to_i + sqrt = Integer.sqrt(n) + return BigDecimal(sqrt) * BigDecimal("1e#{-ex}") if sqrt * sqrt == n + end + + ex = prec + BigDecimal.double_fig - exponent / 2 + sqrt = Integer.sqrt(self * BigDecimal("1e#{2 * ex}")) + BigDecimal(sqrt) * BigDecimal("1e#{-ex}") + end +end diff --git a/test/bigdecimal/test_bigdecimal.rb b/test/bigdecimal/test_bigdecimal.rb index c7cbe0f..7d0ea11 100644 --- a/test/bigdecimal/test_bigdecimal.rb +++ b/test/bigdecimal/test_bigdecimal.rb @@ -1213,8 +1213,6 @@ def test_sqrt_bigdecimal assert_equal(true, (x.sqrt(300) - y).abs < BigDecimal("1E#{e-300}")) x = BigDecimal("-" + (2**100).to_s) assert_raise_with_message(FloatDomainError, "sqrt of negative value") { x.sqrt(1) } - x = BigDecimal((2**200).to_s) - assert_equal(2**100, x.sqrt(1)) BigDecimal.mode(BigDecimal::EXCEPTION_OVERFLOW, false) BigDecimal.mode(BigDecimal::EXCEPTION_NaN, false) @@ -1225,6 +1223,12 @@ def test_sqrt_bigdecimal assert_equal(0, BigDecimal("-0").sqrt(1)) assert_equal(1, BigDecimal("1").sqrt(1)) assert_positive_infinite(BigDecimal("Infinity").sqrt(1)) + + assert_equal(BigDecimal('11.1'), BigDecimal('123.21').sqrt(100)) + assert_equal(BigDecimal('11e20'), BigDecimal('121e40').sqrt(100)) + assert_in_epsilon(Math.sqrt(121e41), BigDecimal('121e41').sqrt(100)) + assert_in_epsilon(Math.sqrt(121.5e40), BigDecimal('121.5e40').sqrt(100)) + assert_in_epsilon(Math.sqrt(2e100), BigDecimal('2e100').sqrt(10)) end def test_sqrt_5266 @@ -1241,6 +1245,20 @@ def test_sqrt_5266 x.sqrt(109).to_s(109).split(' ')[0]) end + def test_sqrt_minimum_precision + x = BigDecimal((2**200).to_s) + assert_equal(2**100, x.sqrt(1)) + + x = BigDecimal('1' * 60 + '.' + '1' * 40) + assert_in_delta(BigDecimal('3' * 30 + '.' + '3' * 70), x.sqrt(1), BigDecimal('1e-70')) + + x = BigDecimal('1' * 40 + '.' + '1' * 60) + assert_in_delta(BigDecimal('3' * 20 + '.' + '3' * 80), x.sqrt(1), BigDecimal('1e-80')) + + x = BigDecimal('0.' + '0' * 50 + '1' * 100) + assert_in_delta(BigDecimal('0.' + '0' * 25 + '3' * 100), x.sqrt(1), BigDecimal('1e-125')) + end + def test_fix x = BigDecimal("1.1") assert_equal(1, x.fix)