Skip to content

Rewrite BigDecimal#sqrt in ruby using Integer.sqrt #323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 0 additions & 202 deletions ext/bigdecimal/bigdecimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -2414,31 +2414,6 @@ BigDecimal_abs(VALUE self)
return VpCheckGetValue(c);
}

/* call-seq:
* sqrt(n)
*
* Returns the square root of the value.
*
* Result has at least n significant digits.
*/
static VALUE
BigDecimal_sqrt(VALUE self, VALUE nFig)
{
ENTER(5);
Real *c, *a;
size_t mx, n;

GUARD_OBJ(a, GetVpValue(self, 1));
mx = a->Prec * (VpBaseFig() + 1);

n = check_int_precision(nFig);
n += VpDblFig() + VpBaseFig();
if (mx <= n) mx = n;
GUARD_OBJ(c, NewZeroWrapLimited(1, mx));
VpSqrt(c, a);
return VpCheckGetValue(c);
}

/* Return the integer part of the number, as a BigDecimal.
*/
static VALUE
Expand Down Expand Up @@ -4607,7 +4582,6 @@ Init_bigdecimal(void)
rb_define_method(rb_cBigDecimal, "dup", BigDecimal_clone, 0);
rb_define_method(rb_cBigDecimal, "to_f", BigDecimal_to_f, 0);
rb_define_method(rb_cBigDecimal, "abs", BigDecimal_abs, 0);
rb_define_method(rb_cBigDecimal, "sqrt", BigDecimal_sqrt, 1);
rb_define_method(rb_cBigDecimal, "fix", BigDecimal_fix, 0);
rb_define_method(rb_cBigDecimal, "round", BigDecimal_round, -1);
rb_define_method(rb_cBigDecimal, "frac", BigDecimal_frac, 0);
Expand Down Expand Up @@ -4684,8 +4658,6 @@ static int gfCheckVal = 1; /* Value checking flag in VpNmlz() */

static Real *VpConstOne; /* constant 1.0 */
static Real *VpConstPt5; /* constant 0.5 */
#define maxnr 100UL /* Maximum iterations for calculating sqrt. */
/* used in VpSqrt() */

/* ETC */
#define MemCmp(x,y,z) memcmp(x,y,z)
Expand Down Expand Up @@ -7089,74 +7061,6 @@ VpVtoD(double *d, SIGNED_VALUE *e, Real *m)
return f;
}

/*
* m <- d
*/
VP_EXPORT void
VpDtoV(Real *m, double d)
{
size_t ind_m, mm;
SIGNED_VALUE ne;
DECDIG i;
double val, val2;

if (isnan(d)) {
VpSetNaN(m);
goto Exit;
}
if (isinf(d)) {
if (d > 0.0) VpSetPosInf(m);
else VpSetNegInf(m);
goto Exit;
}

if (d == 0.0) {
VpSetZero(m, 1);
goto Exit;
}
val = (d > 0.) ? d : -d;
ne = 0;
if (val >= 1.0) {
while (val >= 1.0) {
val /= (double)BASE;
++ne;
}
}
else {
val2 = 1.0 / (double)BASE;
while (val < val2) {
val *= (double)BASE;
--ne;
}
}
/* Now val = 0.xxxxx*BASE**ne */

mm = m->MaxPrec;
memset(m->frac, 0, mm * sizeof(DECDIG));
for (ind_m = 0; val > 0.0 && ind_m < mm; ind_m++) {
val *= (double)BASE;
i = (DECDIG)val;
val -= (double)i;
m->frac[ind_m] = i;
}
if (ind_m >= mm) ind_m = mm - 1;
VpSetSign(m, (d > 0.0) ? 1 : -1);
m->Prec = ind_m + 1;
m->exponent = ne;

VpInternalRound(m, 0, (m->Prec > 0) ? m->frac[m->Prec-1] : 0,
(DECDIG)(val*(double)BASE));

Exit:
#ifdef BIGDECIMAL_DEBUG
if (gfDebug) {
printf("VpDtoV d=%30.30e\n", d);
VPrint(stdout, " m=%\n", m);
}
#endif /* BIGDECIMAL_DEBUG */
return;
}

/*
* m <- ival
*/
Expand Down Expand Up @@ -7221,112 +7125,6 @@ VpItoV(Real *m, SIGNED_VALUE ival)
}
#endif

/*
* y = SQRT(x), y*y - x =>0
*/
VP_EXPORT int
VpSqrt(Real *y, Real *x)
{
Real *f = NULL;
Real *r = NULL;
size_t y_prec;
SIGNED_VALUE n, e;
ssize_t nr;
double val;

/* Zero or +Infinity ? */
if (VpIsZero(x) || VpIsPosInf(x)) {
VpAsgn(y,x,1);
goto Exit;
}

/* Negative ? */
if (BIGDECIMAL_NEGATIVE_P(x)) {
VpSetNaN(y);
return VpException(VP_EXCEPTION_OP, "sqrt of negative value", 0);
}

/* NaN ? */
if (VpIsNaN(x)) {
VpSetNaN(y);
return VpException(VP_EXCEPTION_OP, "sqrt of 'NaN'(Not a Number)", 0);
}

/* One ? */
if (VpIsOne(x)) {
VpSetOne(y);
goto Exit;
}

n = (SIGNED_VALUE)y->MaxPrec;
if (x->MaxPrec > (size_t)n) n = (ssize_t)x->MaxPrec;

/* allocate temporally variables */
/* TODO: reconsider MaxPrec of f and r */
f = NewOneNolimit(1, y->MaxPrec * (BASE_FIG + 2));
r = NewOneNolimit(1, (n + n) * (BASE_FIG + 2));

nr = 0;
y_prec = y->MaxPrec;

VpVtoD(&val, &e, x); /* val <- x */
e /= (SIGNED_VALUE)BASE_FIG;
n = e / 2;
if (e - n * 2 != 0) {
val /= BASE;
n = (e + 1) / 2;
}
VpDtoV(y, sqrt(val)); /* y <- sqrt(val) */
y->exponent += n;
n = (SIGNED_VALUE)roomof(BIGDECIMAL_DOUBLE_FIGURES, BASE_FIG);
y->MaxPrec = Min((size_t)n , y_prec);
f->MaxPrec = y->MaxPrec + 1;
n = (SIGNED_VALUE)(y_prec * BASE_FIG);
if (n > (SIGNED_VALUE)maxnr) n = (SIGNED_VALUE)maxnr;

/*
* Perform: y_{n+1} = (y_n - x/y_n) / 2
*/
do {
y->MaxPrec *= 2;
if (y->MaxPrec > y_prec) y->MaxPrec = y_prec;
f->MaxPrec = y->MaxPrec;
VpDivd(f, r, x, y); /* f = x/y */
VpAddSub(r, f, y, -1); /* r = f - y */
VpMult(f, VpConstPt5, r); /* f = 0.5*r */
if (VpIsZero(f))
goto converge;
VpAddSub(r, f, y, 1); /* r = y + f */
VpAsgn(y, r, 1); /* y = r */
} while (++nr < n);

#ifdef BIGDECIMAL_DEBUG
if (gfDebug) {
printf("ERROR(VpSqrt): did not converge within %ld iterations.\n", nr);
}
#endif /* BIGDECIMAL_DEBUG */
y->MaxPrec = y_prec;

converge:
VpChangeSign(y, 1);
#ifdef BIGDECIMAL_DEBUG
if (gfDebug) {
VpMult(r, y, y);
VpAddSub(f, x, r, -1);
printf("VpSqrt: iterations = %"PRIdSIZE"\n", nr);
VPrint(stdout, " y =% \n", y);
VPrint(stdout, " x =% \n", x);
VPrint(stdout, " x-y*y = % \n", f);
}
#endif /* BIGDECIMAL_DEBUG */
y->MaxPrec = y_prec;

Exit:
rbd_free_struct(f);
rbd_free_struct(r);
return 1;
}

/*
* Round relatively from the decimal point.
* f: rounding mode
Expand Down
2 changes: 0 additions & 2 deletions ext/bigdecimal/bigdecimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,9 @@ VP_EXPORT void VpToString(Real *a, char *buf, size_t bufsize, size_t fFmt, int f
VP_EXPORT void VpToFString(Real *a, char *buf, size_t bufsize, size_t fFmt, int fPlus);
VP_EXPORT int VpCtoV(Real *a, const char *int_chr, size_t ni, const char *frac, size_t nf, const char *exp_chr, size_t ne);
VP_EXPORT int VpVtoD(double *d, SIGNED_VALUE *e, Real *m);
VP_EXPORT void VpDtoV(Real *m,double d);
#if 0 /* unused */
VP_EXPORT void VpItoV(Real *m,S_INT ival);
#endif
VP_EXPORT int VpSqrt(Real *y,Real *x);
VP_EXPORT int VpActiveRound(Real *y, Real *x, unsigned short f, ssize_t il);
VP_EXPORT int VpMidRound(Real *y, unsigned short f, ssize_t nf);
VP_EXPORT int VpLeftRound(Real *y, unsigned short f, ssize_t nf);
Expand Down
33 changes: 33 additions & 0 deletions lib/bigdecimal.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,36 @@
else
require 'bigdecimal.so'
end

class BigDecimal

# Returns the square root of the value.
#
# Result has at least prec significant digits.
#
def sqrt(prec)
if infinite? == 1
exception_mode = BigDecimal.mode(BigDecimal::EXCEPTION_ALL)
raise FloatDomainError, "Computation results in 'Infinity'" if exception_mode.anybits?(BigDecimal::EXCEPTION_INFINITY)
return INFINITY
end
raise ArgumentError, 'negative precision' if prec < 0
raise FloatDomainError, 'sqrt of negative value' if self < 0
raise FloatDomainError, "sqrt of 'NaN'(Not a Number)" if nan?

n_digits = n_significant_digits
prec = [prec, n_digits].max

if n_digits < prec / 2
# Fast path for sqrt(16e100) => 4e50
ex = (n_digits - exponent + 1) / 2
n = (self * BigDecimal("1e#{2 * ex}")).to_i
sqrt = Integer.sqrt(n)
return BigDecimal(sqrt) * BigDecimal("1e#{-ex}") if sqrt * sqrt == n
end

ex = prec + BigDecimal.double_fig - exponent / 2
sqrt = Integer.sqrt(self * BigDecimal("1e#{2 * ex}"))
BigDecimal(sqrt) * BigDecimal("1e#{-ex}")
end
end
22 changes: 20 additions & 2 deletions test/bigdecimal/test_bigdecimal.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1172,8 +1172,6 @@
assert_equal(true, (x.sqrt(300) - y).abs < BigDecimal("1E#{e-300}"))
x = BigDecimal("-" + (2**100).to_s)
assert_raise_with_message(FloatDomainError, "sqrt of negative value") { x.sqrt(1) }
x = BigDecimal((2**200).to_s)
assert_equal(2**100, x.sqrt(1))

BigDecimal.mode(BigDecimal::EXCEPTION_OVERFLOW, false)
BigDecimal.mode(BigDecimal::EXCEPTION_NaN, false)
Expand All @@ -1184,11 +1182,17 @@
assert_equal(0, BigDecimal("-0").sqrt(1))
assert_equal(1, BigDecimal("1").sqrt(1))
assert_positive_infinite(BigDecimal("Infinity").sqrt(1))

assert_equal(BigDecimal('11.1'), BigDecimal('123.21').sqrt(100))
assert_equal(BigDecimal('11e20'), BigDecimal('121e40').sqrt(100))
assert_in_epsilon(Math.sqrt(121e41), BigDecimal('121e41').sqrt(100))
assert_in_epsilon(Math.sqrt(121.5e40), BigDecimal('121.5e40').sqrt(100))
assert_in_epsilon(Math.sqrt(2e100), BigDecimal('2e100').sqrt(10))
end

def test_sqrt_5266
x = BigDecimal('2' + '0'*100)
assert_equal('0.14142135623730950488016887242096980785696718753769480731',

Check failure on line 1195 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / ubuntu-latest truffleruby

Failure

<"0.14142135623730950488016887242096980785696718753769480731">("UTF-8") expected but was <"0.14142135623730950533076675614619546927108765441564332116">("US-ASCII").

Check failure on line 1195 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / ubuntu-latest truffleruby-head

Failure

<"0.14142135623730950488016887242096980785696718753769480731">("UTF-8") expected but was <"0.14142135623730950533076675614619546927108765441564332116">("US-ASCII").

Check failure on line 1195 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-latest truffleruby

Failure

<"0.14142135623730950488016887242096980785696718753769480731">("UTF-8") expected but was <"0.14142135623730950533076675614619546927108765441564332116">("US-ASCII").

Check failure on line 1195 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-latest truffleruby-head

Failure

<"0.14142135623730950488016887242096980785696718753769480731">("UTF-8") expected but was <"0.14142135623730950533076675614619546927108765441564332116">("US-ASCII").

Check failure on line 1195 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-14 truffleruby

Failure

<"0.14142135623730950488016887242096980785696718753769480731">("UTF-8") expected but was <"0.14142135623730950533076675614619546927108765441564332116">("US-ASCII").

Check failure on line 1195 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-14 truffleruby-head

Failure

<"0.14142135623730950488016887242096980785696718753769480731">("UTF-8") expected but was <"0.14142135623730950533076675614619546927108765441564332116">("US-ASCII").
x.sqrt(56).to_s(56).split(' ')[0])
assert_equal('0.1414213562373095048801688724209698078569671875376948073',
x.sqrt(55).to_s(55).split(' ')[0])
Expand All @@ -1200,6 +1204,20 @@
x.sqrt(109).to_s(109).split(' ')[0])
end

def test_sqrt_minimum_precision
x = BigDecimal((2**200).to_s)
assert_equal(2**100, x.sqrt(1))

Check failure on line 1209 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / ubuntu-latest truffleruby

Failure

<1267650600228229401496703205376> expected but was <0.126765060022822945707791245897957913861639059079182405399280302397405406429184e31>.

Check failure on line 1209 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / ubuntu-latest truffleruby-head

Failure

<1267650600228229401496703205376> expected but was <0.126765060022822945707791245897957913861639059079182405399280302397405406429184e31>.

Check failure on line 1209 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-latest truffleruby

Failure

<1267650600228229401496703205376> expected but was <0.126765060022822945707791245897957913861639059079182405399280302397405406429184e31>.

Check failure on line 1209 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-latest truffleruby-head

Failure

<1267650600228229401496703205376> expected but was <0.126765060022822945707791245897957913861639059079182405399280302397405406429184e31>.

Check failure on line 1209 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-14 truffleruby

Failure

<1267650600228229401496703205376> expected but was <0.126765060022822945707791245897957913861639059079182405399280302397405406429184e31>.

Check failure on line 1209 in test/bigdecimal/test_bigdecimal.rb

View workflow job for this annotation

GitHub Actions / macos-14 truffleruby-head

Failure

<1267650600228229401496703205376> expected but was <0.126765060022822945707791245897957913861639059079182405399280302397405406429184e31>.

x = BigDecimal('1' * 60 + '.' + '1' * 40)
assert_in_delta(BigDecimal('3' * 30 + '.' + '3' * 70), x.sqrt(1), BigDecimal('1e-70'))

x = BigDecimal('1' * 40 + '.' + '1' * 60)
assert_in_delta(BigDecimal('3' * 20 + '.' + '3' * 80), x.sqrt(1), BigDecimal('1e-80'))

x = BigDecimal('0.' + '0' * 50 + '1' * 100)
assert_in_delta(BigDecimal('0.' + '0' * 25 + '3' * 100), x.sqrt(1), BigDecimal('1e-125'))
end

def test_fix
x = BigDecimal("1.1")
assert_equal(1, x.fix)
Expand Down
Loading