diff --git a/Doc/library/imath.rst b/Doc/library/imath.rst new file mode 100644 index 00000000000000..d3dfce1befc1d7 --- /dev/null +++ b/Doc/library/imath.rst @@ -0,0 +1,77 @@ +:mod:`imath` --- Mathematical functions for integer numbers +=========================================================== + +.. module:: imath + :synopsis: Mathematical functions for integer numbers. + +.. versionadded:: 3.8 + +**Source code:** :source:`Lib/imath.py` + +-------------- + +This module provides access to the mathematical functions for integer arguments. +These functions accept integers and objects that implement the +:meth:`__index__` method which is used to convert the object to an integer +number. They cannot be used with floating-point numbers or complex +numbers. + +The following functions are provided by this module. All return values are +integers. + + +.. function:: comb(n, k) + + Return the number of ways to choose *k* items from *n* items without repetition + and without order. + + Also called the binomial coefficient. It is mathematically equal to the expression + ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the + polynomial expansion of the expression ``(1 + x) ** n``. + + Raises :exc:`TypeError` if the arguments not integers. + Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*. + + +.. function:: gcd(a, b) + + Return the greatest common divisor of the integers *a* and *b*. If either + *a* or *b* is nonzero, then the value of ``gcd(a, b)`` is the largest + positive integer that divides both *a* and *b*. ``gcd(0, 0)`` returns + ``0``. + + +.. function:: ilog2(n) + + Return the integer base 2 logarithm of the positive integer *n*. This is the + floor of the exact base 2 logarithm root of *n*, or equivalently the + greatest integer *k* such that + 2\ :sup:`k` |nbsp| ≤ |nbsp| *n* |nbsp| < |nbsp| 2\ :sup:`k+1`. + + It is equivalent to ``n.bit_length() - 1`` for positive *n*. + + +.. function:: isqrt(n) + + Return the integer square root of the nonnegative integer *n*. This is the + floor of the exact square root of *n*, or equivalently the greatest integer + *a* such that *a*\ ² |nbsp| ≤ |nbsp| *n*. + + For some applications, it may be more convenient to have the least integer + *a* such that *n* |nbsp| ≤ |nbsp| *a*\ ², or in other words the ceiling of + the exact square root of *n*. For positive *n*, this can be computed using + ``a = 1 + isqrt(n - 1)``. + + +.. function:: perm(n, k) + + Return the number of ways to choose *k* items from *n* items + without repetition and with order. + + It is mathematically equal to the expression ``n! / (n - k)!``. + + Raises :exc:`TypeError` if the arguments not integers. + Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*. + +.. |nbsp| unicode:: 0xA0 + :trim: diff --git a/Doc/library/math.rst b/Doc/library/math.rst index c5a77f1fab9fd6..4798f3074ea6f1 100644 --- a/Doc/library/math.rst +++ b/Doc/library/math.rst @@ -36,21 +36,6 @@ Number-theoretic and representation functions :class:`~numbers.Integral` value. -.. function:: comb(n, k) - - Return the number of ways to choose *k* items from *n* items without repetition - and without order. - - Also called the binomial coefficient. It is mathematically equal to the expression - ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the - polynomial expansion of the expression ``(1 + x) ** n``. - - Raises :exc:`TypeError` if the arguments not integers. - Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*. - - .. versionadded:: 3.8 - - .. function:: copysign(x, y) Return a float with the magnitude (absolute value) of *x* but the sign of @@ -65,8 +50,8 @@ Number-theoretic and representation functions .. function:: factorial(x) - Return *x* factorial as an integer. Raises :exc:`ValueError` if *x* is not integral or - is negative. + Similar to :func:`imath.factorial`, but accepts also floating-point numbers + with integer value (like ``3.0``). .. function:: floor(x) @@ -122,10 +107,7 @@ Number-theoretic and representation functions .. function:: gcd(a, b) - Return the greatest common divisor of the integers *a* and *b*. If either - *a* or *b* is nonzero, then the value of ``gcd(a, b)`` is the largest - positive integer that divides both *a* and *b*. ``gcd(0, 0)`` returns - ``0``. + An alias of :func:`imath.gcd`. .. versionadded:: 3.5 @@ -181,20 +163,6 @@ Number-theoretic and representation functions Return ``True`` if *x* is a NaN (not a number), and ``False`` otherwise. -.. function:: isqrt(n) - - Return the integer square root of the nonnegative integer *n*. This is the - floor of the exact square root of *n*, or equivalently the greatest integer - *a* such that *a*\ ² |nbsp| ≤ |nbsp| *n*. - - For some applications, it may be more convenient to have the least integer - *a* such that *n* |nbsp| ≤ |nbsp| *a*\ ², or in other words the ceiling of - the exact square root of *n*. For positive *n*, this can be computed using - ``a = 1 + isqrt(n - 1)``. - - .. versionadded:: 3.8 - - .. function:: ldexp(x, i) Return ``x * (2**i)``. This is essentially the inverse of function @@ -207,19 +175,6 @@ Number-theoretic and representation functions of *x* and are floats. -.. function:: perm(n, k) - - Return the number of ways to choose *k* items from *n* items - without repetition and with order. - - It is mathematically equal to the expression ``n! / (n - k)!``. - - Raises :exc:`TypeError` if the arguments not integers. - Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*. - - .. versionadded:: 3.8 - - .. function:: prod(iterable, *, start=1) Calculate the product of all the elements in the input *iterable*. @@ -580,6 +535,3 @@ Constants Module :mod:`cmath` Complex number versions of many of these functions. - -.. |nbsp| unicode:: 0xA0 - :trim: diff --git a/Lib/test/test_imath.py b/Lib/test/test_imath.py new file mode 100644 index 00000000000000..376da0cdaf4a4a --- /dev/null +++ b/Lib/test/test_imath.py @@ -0,0 +1,344 @@ +from decimal import Decimal +from fractions import Fraction +import unittest +from test import support + +py_imath = support.import_fresh_module('imath', blocked=['_imath']) +c_imath = support.import_fresh_module('imath', fresh=['_imath']) + +class IntSubclass(int): + pass + +# Class providing an __index__ method. +class MyIndexable(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + +# Here's a pure Python version of the math.factorial algorithm, for +# documentation and comparison purposes. +# +# Formula: +# +# factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n)) +# +# where +# +# factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j +# +# The outer product above is an infinite product, but once i >= n.bit_length, +# (n >> i) < 1 and the corresponding term of the product is empty. So only the +# finitely many terms for 0 <= i < n.bit_length() contribute anything. +# +# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner +# product in the formula above starts at 1 for i == n.bit_length(); for each i +# < n.bit_length() we get the inner product for i from that for i + 1 by +# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms, +# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2). + +def count_set_bits(n): + """Number of '1' bits in binary expansion of a nonnnegative integer.""" + return 1 + count_set_bits(n & n - 1) if n else 0 + +def partial_product(start, stop): + """Product of integers in range(start, stop, 2), computed recursively. + start and stop should both be odd, with start <= stop. + + """ + numfactors = (stop - start) >> 1 + if not numfactors: + return 1 + elif numfactors == 1: + return start + else: + mid = (start + numfactors) | 1 + return partial_product(start, mid) * partial_product(mid, stop) + +def py_factorial(n): + """Factorial of nonnegative integer n, via "Binary Split Factorial Formula" + described at http://www.luschny.de/math/factorial/binarysplitfact.html + + """ + inner = outer = 1 + for i in reversed(range(n.bit_length())): + inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1) + outer *= inner + return outer << (n - count_set_bits(n)) + + +class IMathTests: + + def assertIntEqual(self, actual, expected): + self.assertEqual(actual, expected) + self.assertIs(type(actual), int) + + def testFactorial(self): + factorial = self.module.factorial + self.assertEqual(factorial(0), 1) + total = 1 + for i in range(1, 1000): + total *= i + self.assertEqual(factorial(i), total) + self.assertEqual(factorial(i), py_factorial(i)) + + self.assertIntEqual(factorial(False), 1) + self.assertIntEqual(factorial(True), 1) + for i in range(3): + expected = factorial(i) + self.assertIntEqual(factorial(IntSubclass(i)), expected) + self.assertIntEqual(factorial(MyIndexable(i)), expected) + + self.assertRaises(ValueError, factorial, -1) + self.assertRaises(ValueError, factorial, -10**1000) + + self.assertRaises(TypeError, factorial, 5.0) + self.assertRaises(TypeError, factorial, -1.0) + self.assertRaises(TypeError, factorial, -1.0e100) + self.assertRaises(TypeError, factorial, Decimal(5.0)) + self.assertRaises(TypeError, factorial, Fraction(5, 1)) + self.assertRaises(TypeError, factorial, "5") + + if self.module is c_imath: + self.assertRaises((OverflowError, MemoryError), factorial, 10**100) + + def testGcd(self): + gcd = self.module.gcd + self.assertEqual(gcd(0, 0), 0) + self.assertEqual(gcd(1, 0), 1) + self.assertEqual(gcd(-1, 0), 1) + self.assertEqual(gcd(0, 1), 1) + self.assertEqual(gcd(0, -1), 1) + self.assertEqual(gcd(7, 1), 1) + self.assertEqual(gcd(7, -1), 1) + self.assertEqual(gcd(-23, 15), 1) + self.assertEqual(gcd(120, 84), 12) + self.assertEqual(gcd(84, -120), 12) + self.assertEqual(gcd(1216342683557601535506311712, + 436522681849110124616458784), 32) + c = 652560 + x = 434610456570399902378880679233098819019853229470286994367836600566 + y = 1064502245825115327754847244914921553977 + a = x * c + b = y * c + self.assertEqual(gcd(a, b), c) + self.assertEqual(gcd(b, a), c) + self.assertEqual(gcd(-a, b), c) + self.assertEqual(gcd(b, -a), c) + self.assertEqual(gcd(a, -b), c) + self.assertEqual(gcd(-b, a), c) + self.assertEqual(gcd(-a, -b), c) + self.assertEqual(gcd(-b, -a), c) + c = 576559230871654959816130551884856912003141446781646602790216406874 + a = x * c + b = y * c + self.assertEqual(gcd(a, b), c) + self.assertEqual(gcd(b, a), c) + self.assertEqual(gcd(-a, b), c) + self.assertEqual(gcd(b, -a), c) + self.assertEqual(gcd(a, -b), c) + self.assertEqual(gcd(-b, a), c) + self.assertEqual(gcd(-a, -b), c) + self.assertEqual(gcd(-b, -a), c) + + self.assertRaises(TypeError, gcd, 120.0, 84) + self.assertRaises(TypeError, gcd, 120, 84.0) + self.assertIntEqual(gcd(IntSubclass(120), IntSubclass(84)), 12) + self.assertIntEqual(gcd(MyIndexable(120), MyIndexable(84)), 12) + + def testIlog(self): + ilog2 = self.module.ilog2 + for value in range(1, 1000): + k = ilog2(value) + self.assertLessEqual(2**k, value) + self.assertLess(value, 2**(k+1)) + self.assertRaises(ValueError, ilog2, 0) + self.assertRaises(ValueError, ilog2, -1) + self.assertRaises(ValueError, ilog2, -2**1000) + + self.assertIntEqual(ilog2(True), 0) + self.assertIntEqual(ilog2(IntSubclass(5)), 2) + self.assertIntEqual(ilog2(MyIndexable(5)), 2) + + self.assertRaises(TypeError, ilog2, 5.0) + self.assertRaises(TypeError, ilog2, Decimal('5')) + self.assertRaises(TypeError, ilog2, Fraction(5, 1)) + self.assertRaises(TypeError, ilog2, '5') + + def testIsqrt(self): + isqrt = self.module.isqrt + # Test a variety of inputs, large and small. + test_values = ( + list(range(1000)) + + list(range(10**6 - 1000, 10**6 + 1000)) + + [2**e + i for e in range(60, 200) for i in range(-40, 40)] + + [3**9999, 10**5001] + ) + + for value in test_values: + with self.subTest(value=value): + s = isqrt(value) + self.assertIs(type(s), int) + self.assertLessEqual(s*s, value) + self.assertLess(value, (s+1)*(s+1)) + + # Negative values + with self.assertRaises(ValueError): + isqrt(-1) + + # Integer-like things + self.assertIntEqual(isqrt(False), 0) + self.assertIntEqual(isqrt(True), 1) + self.assertIntEqual(isqrt(MyIndexable(1729)), 41) + + with self.assertRaises(ValueError): + isqrt(MyIndexable(-3)) + + # Non-integer-like things + for value in 4.0, "4", Decimal("4.0"), Fraction(4, 1), 4j, -4.0: + with self.subTest(value=value): + with self.assertRaises(TypeError): + isqrt(value) + + def testPerm(self): + perm = self.module.perm + factorial = self.module.factorial + # Test if factorial defintion is satisfied + for n in range(100): + for k in range(n + 1): + self.assertEqual(perm(n, k), + factorial(n) // factorial(n - k)) + + # Test for Pascal's identity + for n in range(1, 100): + for k in range(1, n): + self.assertEqual(perm(n, k), perm(n - 1, k - 1) * k + perm(n - 1, k)) + + # Test corner cases + for n in range(1, 100): + self.assertEqual(perm(n, 0), 1) + self.assertEqual(perm(n, 1), n) + self.assertEqual(perm(n, n), factorial(n)) + + # Raises TypeError if any argument is non-integer or argument count is + # not 2 + self.assertRaises(TypeError, perm, 10, 1.0) + self.assertRaises(TypeError, perm, 10, Decimal(1.0)) + self.assertRaises(TypeError, perm, 10, Fraction(1, 1)) + self.assertRaises(TypeError, perm, 10, "1") + self.assertRaises(TypeError, perm, 10.0, 1) + self.assertRaises(TypeError, perm, Decimal(10.0), 1) + self.assertRaises(TypeError, perm, Fraction(10, 1), 1) + self.assertRaises(TypeError, perm, "10", 1) + + self.assertRaises(TypeError, perm, 10) + self.assertRaises(TypeError, perm, 10, 1, 3) + self.assertRaises(TypeError, perm) + + # Raises Value error if not k or n are negative numbers + self.assertRaises(ValueError, perm, -1, 1) + self.assertRaises(ValueError, perm, -2**1000, 1) + self.assertRaises(ValueError, perm, 1, -1) + self.assertRaises(ValueError, perm, 1, -2**1000) + + # Raises value error if k is greater than n + self.assertRaises(ValueError, perm, 1, 2) + self.assertRaises(ValueError, perm, 1, 2**1000) + + n = 2**1000 + self.assertEqual(perm(n, 0), 1) + self.assertEqual(perm(n, 1), n) + self.assertEqual(perm(n, 2), n * (n-1)) + if self.module is c_imath: + self.assertRaises((OverflowError, MemoryError), perm, n, n) + + for n, k in (True, True), (True, False), (False, False): + self.assertIntEqual(perm(n, k), 1) + self.assertEqual(perm(IntSubclass(5), IntSubclass(2)), 20) + self.assertEqual(perm(MyIndexable(5), MyIndexable(2)), 20) + for k in range(3): + self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int) + + def testComb(self): + comb = self.module.comb + factorial = self.module.factorial + # Test if factorial defintion is satisfied + for n in range(100): + for k in range(n + 1): + self.assertEqual(comb(n, k), factorial(n) + // (factorial(k) * factorial(n - k))) + + # Test for Pascal's identity + for n in range(1, 100): + for k in range(1, n): + self.assertEqual(comb(n, k), comb(n - 1, k - 1) + comb(n - 1, k)) + + # Test corner cases + for n in range(100): + self.assertEqual(comb(n, 0), 1) + self.assertEqual(comb(n, n), 1) + + for n in range(1, 100): + self.assertEqual(comb(n, 1), n) + self.assertEqual(comb(n, n - 1), n) + + # Test Symmetry + for n in range(100): + for k in range(n // 2): + self.assertEqual(comb(n, k), comb(n, n - k)) + + # Raises TypeError if any argument is non-integer or argument count is + # not 2 + self.assertRaises(TypeError, comb, 10, 1.0) + self.assertRaises(TypeError, comb, 10, Decimal(1.0)) + self.assertRaises(TypeError, comb, 10, Fraction(1, 1)) + self.assertRaises(TypeError, comb, 10, "1") + self.assertRaises(TypeError, comb, 10.0, 1) + self.assertRaises(TypeError, comb, Fraction(10, 1), 1) + self.assertRaises(TypeError, comb, "10", 1) + + self.assertRaises(TypeError, comb, 10) + self.assertRaises(TypeError, comb, 10, 1, 3) + self.assertRaises(TypeError, comb) + + # Raises Value error if not k or n are negative numbers + self.assertRaises(ValueError, comb, -1, 1) + self.assertRaises(ValueError, comb, -2**1000, 1) + self.assertRaises(ValueError, comb, 1, -1) + self.assertRaises(ValueError, comb, 1, -2**1000) + + # Raises value error if k is greater than n + self.assertRaises(ValueError, comb, 1, 2) + self.assertRaises(ValueError, comb, 1, 2**1000) + + n = 2**1000 + self.assertEqual(comb(n, 0), 1) + self.assertEqual(comb(n, 1), n) + self.assertEqual(comb(n, 2), n * (n-1) // 2) + self.assertEqual(comb(n, n), 1) + self.assertEqual(comb(n, n-1), n) + self.assertEqual(comb(n, n-2), n * (n-1) // 2) + if self.module is c_imath: + self.assertRaises((OverflowError, MemoryError), comb, n, n//2) + + for n, k in (True, True), (True, False), (False, False): + self.assertIntEqual(comb(n, k), 1) + self.assertIntEqual(comb(IntSubclass(5), IntSubclass(2)), 10) + self.assertIntEqual(comb(MyIndexable(5), MyIndexable(2)), 10) + for k in range(3): + self.assertIs(type(comb(IntSubclass(5), IntSubclass(k))), int) + self.assertIs(type(comb(MyIndexable(5), MyIndexable(k))), int) + + +class PyIMathTests(IMathTests, unittest.TestCase): + module = py_imath + +@unittest.skipUnless(c_imath, 'requires _imath') +class CIMathTests(IMathTests, unittest.TestCase): + module = c_imath + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 96e0cf2fe67197..b0532d0bbac925 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -3,6 +3,7 @@ from test.support import run_unittest, verbose, requires_IEEE_754 from test import support +from .test_imath import py_factorial import unittest import itertools import decimal @@ -77,56 +78,6 @@ def ulp(x): else: return x_next - x -# Here's a pure Python version of the math.factorial algorithm, for -# documentation and comparison purposes. -# -# Formula: -# -# factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n)) -# -# where -# -# factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j -# -# The outer product above is an infinite product, but once i >= n.bit_length, -# (n >> i) < 1 and the corresponding term of the product is empty. So only the -# finitely many terms for 0 <= i < n.bit_length() contribute anything. -# -# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner -# product in the formula above starts at 1 for i == n.bit_length(); for each i -# < n.bit_length() we get the inner product for i from that for i + 1 by -# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms, -# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2). - -def count_set_bits(n): - """Number of '1' bits in binary expansion of a nonnnegative integer.""" - return 1 + count_set_bits(n & n - 1) if n else 0 - -def partial_product(start, stop): - """Product of integers in range(start, stop, 2), computed recursively. - start and stop should both be odd, with start <= stop. - - """ - numfactors = (stop - start) >> 1 - if not numfactors: - return 1 - elif numfactors == 1: - return start - else: - mid = (start + numfactors) | 1 - return partial_product(start, mid) * partial_product(mid, stop) - -def py_factorial(n): - """Factorial of nonnegative integer n, via "Binary Split Factorial Formula" - described at http://www.luschny.de/math/factorial/binarysplitfact.html - - """ - inner = outer = 1 - for i in reversed(range(n.bit_length())): - inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1) - outer *= inner - return outer << (n - count_set_bits(n)) - def ulp_abs_check(expected, got, ulp_tol, abs_tol): """Given finite floats `expected` and `got`, check that they're approximately equal to within the given number of ulps or the @@ -240,9 +191,6 @@ def result_check(expected, got, ulp_tol=5, abs_tol=0.0): else: return None -class IntSubclass(int): - pass - # Class providing an __index__ method. class MyIndexable(object): def __init__(self, value): @@ -915,59 +863,6 @@ class T(tuple): self.assertEqual(math.dist(p, q), 5*scale) self.assertEqual(math.dist(q, p), 5*scale) - def testIsqrt(self): - # Test a variety of inputs, large and small. - test_values = ( - list(range(1000)) - + list(range(10**6 - 1000, 10**6 + 1000)) - + [2**e + i for e in range(60, 200) for i in range(-40, 40)] - + [3**9999, 10**5001] - ) - - for value in test_values: - with self.subTest(value=value): - s = math.isqrt(value) - self.assertIs(type(s), int) - self.assertLessEqual(s*s, value) - self.assertLess(value, (s+1)*(s+1)) - - # Negative values - with self.assertRaises(ValueError): - math.isqrt(-1) - - # Integer-like things - s = math.isqrt(True) - self.assertIs(type(s), int) - self.assertEqual(s, 1) - - s = math.isqrt(False) - self.assertIs(type(s), int) - self.assertEqual(s, 0) - - class IntegerLike(object): - def __init__(self, value): - self.value = value - - def __index__(self): - return self.value - - s = math.isqrt(IntegerLike(1729)) - self.assertIs(type(s), int) - self.assertEqual(s, 41) - - with self.assertRaises(ValueError): - math.isqrt(IntegerLike(-3)) - - # Non-integer-like things - bad_values = [ - 3.5, "a string", decimal.Decimal("3.5"), 3.5j, - 100.0, -4.0, - ] - for value in bad_values: - with self.subTest(value=value): - with self.assertRaises(TypeError): - math.isqrt(value) - def testLdexp(self): self.assertRaises(TypeError, math.ldexp) self.ftest('ldexp(0,1)', math.ldexp(0,1), 0) @@ -1865,133 +1760,6 @@ def test_fractions(self): self.assertAllClose(fraction_examples, rel_tol=1e-8) self.assertAllNotClose(fraction_examples, rel_tol=1e-9) - def testPerm(self): - perm = math.perm - factorial = math.factorial - # Test if factorial defintion is satisfied - for n in range(100): - for k in range(n + 1): - self.assertEqual(perm(n, k), - factorial(n) // factorial(n - k)) - - # Test for Pascal's identity - for n in range(1, 100): - for k in range(1, n): - self.assertEqual(perm(n, k), perm(n - 1, k - 1) * k + perm(n - 1, k)) - - # Test corner cases - for n in range(1, 100): - self.assertEqual(perm(n, 0), 1) - self.assertEqual(perm(n, 1), n) - self.assertEqual(perm(n, n), factorial(n)) - - # Raises TypeError if any argument is non-integer or argument count is - # not 2 - self.assertRaises(TypeError, perm, 10, 1.0) - self.assertRaises(TypeError, perm, 10, decimal.Decimal(1.0)) - self.assertRaises(TypeError, perm, 10, "1") - self.assertRaises(TypeError, perm, 10.0, 1) - self.assertRaises(TypeError, perm, decimal.Decimal(10.0), 1) - self.assertRaises(TypeError, perm, "10", 1) - - self.assertRaises(TypeError, perm, 10) - self.assertRaises(TypeError, perm, 10, 1, 3) - self.assertRaises(TypeError, perm) - - # Raises Value error if not k or n are negative numbers - self.assertRaises(ValueError, perm, -1, 1) - self.assertRaises(ValueError, perm, -2**1000, 1) - self.assertRaises(ValueError, perm, 1, -1) - self.assertRaises(ValueError, perm, 1, -2**1000) - - # Raises value error if k is greater than n - self.assertRaises(ValueError, perm, 1, 2) - self.assertRaises(ValueError, perm, 1, 2**1000) - - n = 2**1000 - self.assertEqual(perm(n, 0), 1) - self.assertEqual(perm(n, 1), n) - self.assertEqual(perm(n, 2), n * (n-1)) - self.assertRaises((OverflowError, MemoryError), perm, n, n) - - for n, k in (True, True), (True, False), (False, False): - self.assertEqual(perm(n, k), 1) - self.assertIs(type(perm(n, k)), int) - self.assertEqual(perm(IntSubclass(5), IntSubclass(2)), 20) - self.assertEqual(perm(MyIndexable(5), MyIndexable(2)), 20) - for k in range(3): - self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int) - self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int) - - def testComb(self): - comb = math.comb - factorial = math.factorial - # Test if factorial defintion is satisfied - for n in range(100): - for k in range(n + 1): - self.assertEqual(comb(n, k), factorial(n) - // (factorial(k) * factorial(n - k))) - - # Test for Pascal's identity - for n in range(1, 100): - for k in range(1, n): - self.assertEqual(comb(n, k), comb(n - 1, k - 1) + comb(n - 1, k)) - - # Test corner cases - for n in range(100): - self.assertEqual(comb(n, 0), 1) - self.assertEqual(comb(n, n), 1) - - for n in range(1, 100): - self.assertEqual(comb(n, 1), n) - self.assertEqual(comb(n, n - 1), n) - - # Test Symmetry - for n in range(100): - for k in range(n // 2): - self.assertEqual(comb(n, k), comb(n, n - k)) - - # Raises TypeError if any argument is non-integer or argument count is - # not 2 - self.assertRaises(TypeError, comb, 10, 1.0) - self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0)) - self.assertRaises(TypeError, comb, 10, "1") - self.assertRaises(TypeError, comb, 10.0, 1) - self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1) - self.assertRaises(TypeError, comb, "10", 1) - - self.assertRaises(TypeError, comb, 10) - self.assertRaises(TypeError, comb, 10, 1, 3) - self.assertRaises(TypeError, comb) - - # Raises Value error if not k or n are negative numbers - self.assertRaises(ValueError, comb, -1, 1) - self.assertRaises(ValueError, comb, -2**1000, 1) - self.assertRaises(ValueError, comb, 1, -1) - self.assertRaises(ValueError, comb, 1, -2**1000) - - # Raises value error if k is greater than n - self.assertRaises(ValueError, comb, 1, 2) - self.assertRaises(ValueError, comb, 1, 2**1000) - - n = 2**1000 - self.assertEqual(comb(n, 0), 1) - self.assertEqual(comb(n, 1), n) - self.assertEqual(comb(n, 2), n * (n-1) // 2) - self.assertEqual(comb(n, n), 1) - self.assertEqual(comb(n, n-1), n) - self.assertEqual(comb(n, n-2), n * (n-1) // 2) - self.assertRaises((OverflowError, MemoryError), comb, n, n//2) - - for n, k in (True, True), (True, False), (False, False): - self.assertEqual(comb(n, k), 1) - self.assertIs(type(comb(n, k)), int) - self.assertEqual(comb(IntSubclass(5), IntSubclass(2)), 10) - self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10) - for k in range(3): - self.assertIs(type(comb(IntSubclass(5), IntSubclass(k))), int) - self.assertIs(type(comb(MyIndexable(5), MyIndexable(k))), int) - def test_main(): from doctest import DocFileSuite diff --git a/Misc/NEWS.d/next/Library/2019-06-02-13-56-16.bpo-37132.axawSH.rst b/Misc/NEWS.d/next/Library/2019-06-02-13-56-16.bpo-37132.axawSH.rst new file mode 100644 index 00000000000000..e41edf69ef1cd0 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-06-02-13-56-16.bpo-37132.axawSH.rst @@ -0,0 +1 @@ +Add the :mod:`imath` module. diff --git a/Modules/Setup b/Modules/Setup index e729ab883f410b..316f5c43500da9 100644 --- a/Modules/Setup +++ b/Modules/Setup @@ -169,6 +169,7 @@ _symtable symtablemodule.c #array arraymodule.c # array objects #cmath cmathmodule.c _math.c # -lm # complex math library functions #math mathmodule.c _math.c # -lm # math library functions, e.g. sin() +#_imath _imathmodule.c # integer math library functions, e.g. isqrt() #_contextvars _contextvarsmodule.c # Context Variables #_struct _struct.c # binary structure packing/unpacking #_weakref _weakref.c # basic weak reference support diff --git a/Modules/_imathmodule.c b/Modules/_imathmodule.c new file mode 100644 index 00000000000000..13f872ca149866 --- /dev/null +++ b/Modules/_imathmodule.c @@ -0,0 +1,943 @@ +/* imath module -- integer-related mathematical functions */ + +#include "Python.h" + +#include "clinic/_imathmodule.c.h" + +/*[clinic input] +module imath +[clinic start generated code]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=0b5ec353335010dd]*/ + + +/*[clinic input] +imath.gcd + + x: object + y: object + / + +Greatest common divisor of x and y. +[clinic start generated code]*/ + +static PyObject * +imath_gcd_impl(PyObject *module, PyObject *x, PyObject *y) +/*[clinic end generated code: output=14eee3e4a3bd7a1d input=11612898ad79c57c]*/ +{ + PyObject *result; + + x = PyNumber_Index(x); + if (x == NULL) + return NULL; + y = PyNumber_Index(y); + if (y == NULL) { + Py_DECREF(x); + return NULL; + } + result = _PyLong_GCD(x, y); + Py_DECREF(x); + Py_DECREF(y); + return result; +} + + +/*[clinic input] +imath.ilog2 + + n: object + / + +Return the integer part of the base 2 logarithm of the input. +[clinic start generated code]*/ + +static PyObject * +imath_ilog2(PyObject *module, PyObject *n) +/*[clinic end generated code: output=6ab48d1a7f5160c2 input=e2d8e8631ec5c29b]*/ +{ + size_t bits; + + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; + } + + if (Py_SIZE(n) <= 0) { + PyErr_SetString( + PyExc_ValueError, + "ilog() argument must be positive"); + Py_DECREF(n); + return NULL; + } + + bits = _PyLong_NumBits(n); + Py_DECREF(n); + if (bits == (size_t)(-1)) { + return NULL; + } + return PyLong_FromSize_t(bits - 1); +} + + +/* Integer square root + +Given a nonnegative integer `n`, we want to compute the largest integer +`a` for which `a * a <= n`, or equivalently the integer part of the exact +square root of `n`. + +We use an adaptive-precision pure-integer version of Newton's iteration. Given +a positive integer `n`, the algorithm produces at each iteration an integer +approximation `a` to the square root of `n >> s` for some even integer `s`, +with `s` decreasing as the iterations progress. On the final iteration, `s` is +zero and we have an approximation to the square root of `n` itself. + +At every step, the approximation `a` is strictly within 1.0 of the true square +root, so we have + + (a - 1)**2 < (n >> s) < (a + 1)**2 + +After the final iteration, a check-and-correct step is needed to determine +whether `a` or `a - 1` gives the desired integer square root of `n`. + +The algorithm is remarkable in its simplicity. There's no need for a +per-iteration check-and-correct step, and termination is straightforward: the +number of iterations is known in advance (it's exactly `floor(log2(log2(n)))` +for `n > 1`). The only tricky part of the correctness proof is in establishing +that the bound `(a - 1)**2 < (n >> s) < (a + 1)**2` is maintained from one +iteration to the next. A sketch of the proof of this is given below. + +In addition to the proof sketch, a formal, computer-verified proof +of correctness (using Lean) of an equivalent recursive algorithm can be found +here: + + https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean + + +Here's Python code equivalent to the C implementation below: + + def isqrt(n): + """ + Return the integer part of the square root of the input. + """ + n = operator.index(n) + + if n < 0: + raise ValueError("isqrt() argument must be nonnegative") + if n == 0: + return 0 + + c = (n.bit_length() - 1) // 2 + a = 1 + d = 0 + for s in reversed(range(c.bit_length())): + e = d + d = c >> s + a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a + assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2 + + return a - (a*a > n) + + +Sketch of proof of correctness +------------------------------ + +The delicate part of the correctness proof is showing that the loop invariant +is preserved from one iteration to the next. That is, just before the line + + a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a + +is executed in the above code, we know that + + (1) (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2. + +(since `e` is always the value of `d` from the previous iteration). We must +prove that after that line is executed, we have + + (a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2 + +To faciliate the proof, we make some changes of notation. Write `m` for +`n >> 2*(c-d)`, and write `b` for the new value of `a`, so + + b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a + +or equivalently: + + (2) b = (a << d - e - 1) + (m >> d - e + 1) // a + +Then we can rewrite (1) as: + + (3) (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2 + +and we must show that (b - 1)**2 < m < (b + 1)**2. + +From this point on, we switch to mathematical notation, so `/` means exact +division rather than integer division and `^` is used for exponentiation. We +use the `√` symbol for the exact square root. In (3), we can remove the +implicit floor operation to give: + + (4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2 + +Taking square roots throughout (4), scaling by `2^(d-e)`, and rearranging gives + + (5) 0 <= | 2^(d-e)a - √m | < 2^(d-e) + +Squaring and dividing through by `2^(d-e+1) a` gives + + (6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a + +We'll show below that `2^(d-e-1) <= a`. Given that, we can replace the +right-hand side of (6) with `1`, and now replacing the central +term `m / (2^(d-e+1) a)` with its floor in (6) gives + + (7) -1 < 2^(d-e-1) a + m // 2^(d-e+1) a - √m < 1 + +Or equivalently, from (2): + + (7) -1 < b - √m < 1 + +and rearranging gives that `(b-1)^2 < m < (b+1)^2`, which is what we needed +to prove. + +We're not quite done: we still have to prove the inequality `2^(d - e - 1) <= +a` that was used to get line (7) above. From the definition of `c`, we have +`4^c <= n`, which implies + + (8) 4^d <= m + +also, since `e == d >> 1`, `d` is at most `2e + 1`, from which it follows +that `2d - 2e - 1 <= d` and hence that + + (9) 4^(2d - 2e - 1) <= m + +Dividing both sides by `4^(d - e)` gives + + (10) 4^(d - e - 1) <= m / 4^(d - e) + +But we know from (4) that `m / 4^(d-e) < (a + 1)^2`, hence + + (11) 4^(d - e - 1) < (a + 1)^2 + +Now taking square roots of both sides and observing that both `2^(d-e-1)` and +`a` are integers gives `2^(d - e - 1) <= a`, which is what we needed. This +completes the proof sketch. + +*/ + + +/* Approximate square root of a large 64-bit integer. + + Given `n` satisfying `2**62 <= n < 2**64`, return `a` + satisfying `(a - 1)**2 < n < (a + 1)**2`. */ + +static uint64_t +_approximate_isqrt(uint64_t n) +{ + uint32_t u = 1U + (n >> 62); + u = (u << 1) + (n >> 59) / u; + u = (u << 3) + (n >> 53) / u; + u = (u << 7) + (n >> 41) / u; + return (u << 15) + (n >> 17) / u; +} + +/*[clinic input] +imath.isqrt + + n: object + / + +Return the integer part of the square root of the input. +[clinic start generated code]*/ + +static PyObject * +imath_isqrt(PyObject *module, PyObject *n) +/*[clinic end generated code: output=5ad16a80dd47c888 input=64f5b3e586986fc9]*/ +{ + int a_too_large, c_bit_length; + size_t c, d; + uint64_t m, u; + PyObject *a = NULL, *b; + + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; + } + + if (_PyLong_Sign(n) < 0) { + PyErr_SetString( + PyExc_ValueError, + "isqrt() argument must be nonnegative"); + goto error; + } + if (_PyLong_Sign(n) == 0) { + Py_DECREF(n); + return PyLong_FromLong(0); + } + + /* c = (n.bit_length() - 1) // 2 */ + c = _PyLong_NumBits(n); + if (c == (size_t)(-1)) { + goto error; + } + c = (c - 1U) / 2U; + + /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a + fast, almost branch-free algorithm. In the final correction, we use `u*u + - 1 >= m` instead of the simpler `u*u > m` in order to get the correct + result in the corner case where `u=2**32`. */ + if (c <= 31U) { + m = (uint64_t)PyLong_AsUnsignedLongLong(n); + Py_DECREF(n); + if (m == (uint64_t)(-1) && PyErr_Occurred()) { + return NULL; + } + u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c); + u -= u * u - 1U >= m; + return PyLong_FromUnsignedLongLong((unsigned long long)u); + } + + /* Slow path: n >= 2**64. We perform the first five iterations in C integer + arithmetic, then switch to using Python long integers. */ + + /* From n >= 2**64 it follows that c.bit_length() >= 6. */ + c_bit_length = 6; + while ((c >> c_bit_length) > 0U) { + ++c_bit_length; + } + + /* Initialise d and a. */ + d = c >> (c_bit_length - 5); + b = _PyLong_Rshift(n, 2U*c - 62U); + if (b == NULL) { + goto error; + } + m = (uint64_t)PyLong_AsUnsignedLongLong(b); + Py_DECREF(b); + if (m == (uint64_t)(-1) && PyErr_Occurred()) { + goto error; + } + u = _approximate_isqrt(m) >> (31U - d); + a = PyLong_FromUnsignedLongLong((unsigned long long)u); + if (a == NULL) { + goto error; + } + + for (int s = c_bit_length - 6; s >= 0; --s) { + PyObject *q; + size_t e = d; + + d = c >> s; + + /* q = (n >> 2*c - e - d + 1) // a */ + q = _PyLong_Rshift(n, 2U*c - d - e + 1U); + if (q == NULL) { + goto error; + } + Py_SETREF(q, PyNumber_FloorDivide(q, a)); + if (q == NULL) { + goto error; + } + + /* a = (a << d - 1 - e) + q */ + Py_SETREF(a, _PyLong_Lshift(a, d - 1U - e)); + if (a == NULL) { + Py_DECREF(q); + goto error; + } + Py_SETREF(a, PyNumber_Add(a, q)); + Py_DECREF(q); + if (a == NULL) { + goto error; + } + } + + /* The correct result is either a or a - 1. Figure out which, and + decrement a if necessary. */ + + /* a_too_large = n < a * a */ + b = PyNumber_Multiply(a, a); + if (b == NULL) { + goto error; + } + a_too_large = PyObject_RichCompareBool(n, b, Py_LT); + Py_DECREF(b); + if (a_too_large == -1) { + goto error; + } + + if (a_too_large) { + Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One)); + } + Py_DECREF(n); + return a; + + error: + Py_XDECREF(a); + Py_DECREF(n); + return NULL; +} + + +/* Return the smallest integer k such that n < 2**k, or 0 if n == 0. + * Equivalent to floor(lg(x))+1. Also equivalent to: bitwidth_of_type - + * count_leading_zero_bits(x) + */ + +/* XXX: This routine does more or less the same thing as + * bits_in_digit() in Objects/longobject.c. Someday it would be nice to + * consolidate them. On BSD, there's a library function called fls() + * that we could use, and GCC provides __builtin_clz(). + */ + +static unsigned long +bit_length(unsigned long n) +{ + unsigned long len = 0; + while (n != 0) { + ++len; + n >>= 1; + } + return len; +} + +static unsigned long +count_set_bits(unsigned long n) +{ + unsigned long count = 0; + while (n != 0) { + ++count; + n &= n - 1; /* clear least significant bit */ + } + return count; +} + + +/* Divide-and-conquer factorial algorithm + * + * Based on the formula and pseudo-code provided at: + * http://www.luschny.de/math/factorial/binarysplitfact.html + * + * Faster algorithms exist, but they're more complicated and depend on + * a fast prime factorization algorithm. + * + * Notes on the algorithm + * ---------------------- + * + * factorial(n) is written in the form 2**k * m, with m odd. k and m are + * computed separately, and then combined using a left shift. + * + * The function factorial_odd_part computes the odd part m (i.e., the greatest + * odd divisor) of factorial(n), using the formula: + * + * factorial_odd_part(n) = + * + * product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j + * + * Example: factorial_odd_part(20) = + * + * (1) * + * (1) * + * (1 * 3 * 5) * + * (1 * 3 * 5 * 7 * 9) + * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19) + * + * Here i goes from large to small: the first term corresponds to i=4 (any + * larger i gives an empty product), and the last term corresponds to i=0. + * Each term can be computed from the last by multiplying by the extra odd + * numbers required: e.g., to get from the penultimate term to the last one, + * we multiply by (11 * 13 * 15 * 17 * 19). + * + * To see a hint of why this formula works, here are the same numbers as above + * but with the even parts (i.e., the appropriate powers of 2) included. For + * each subterm in the product for i, we multiply that subterm by 2**i: + * + * factorial(20) = + * + * (16) * + * (8) * + * (4 * 12 * 20) * + * (2 * 6 * 10 * 14 * 18) * + * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19) + * + * The factorial_partial_product function computes the product of all odd j in + * range(start, stop) for given start and stop. It's used to compute the + * partial products like (11 * 13 * 15 * 17 * 19) in the example above. It + * operates recursively, repeatedly splitting the range into two roughly equal + * pieces until the subranges are small enough to be computed using only C + * integer arithmetic. + * + * The two-valuation k (i.e., the exponent of the largest power of 2 dividing + * the factorial) is computed independently in the main math_factorial + * function. By standard results, its value is: + * + * two_valuation = n//2 + n//4 + n//8 + .... + * + * It can be shown (e.g., by complete induction on n) that two_valuation is + * equal to n - count_set_bits(n), where count_set_bits(n) gives the number of + * '1'-bits in the binary expansion of n. + */ + +/* factorial_partial_product: Compute product(range(start, stop, 2)) using + * divide and conquer. Assumes start and stop are odd and stop > start. + * max_bits must be >= bit_length(stop - 2). */ + +static PyObject * +factorial_partial_product(unsigned long start, unsigned long stop, + unsigned long max_bits) +{ + unsigned long midpoint, num_operands; + PyObject *left = NULL, *right = NULL, *result = NULL; + + /* If the return value will fit an unsigned long, then we can + * multiply in a tight, fast loop where each multiply is O(1). + * Compute an upper bound on the number of bits required to store + * the answer. + * + * Storing some integer z requires floor(lg(z))+1 bits, which is + * conveniently the value returned by bit_length(z). The + * product x*y will require at most + * bit_length(x) + bit_length(y) bits to store, based + * on the idea that lg product = lg x + lg y. + * + * We know that stop - 2 is the largest number to be multiplied. From + * there, we have: bit_length(answer) <= num_operands * + * bit_length(stop - 2) + */ + + num_operands = (stop - start) / 2; + /* The "num_operands <= 8 * SIZEOF_LONG" check guards against the + * unlikely case of an overflow in num_operands * max_bits. */ + if (num_operands <= 8 * SIZEOF_LONG && + num_operands * max_bits <= 8 * SIZEOF_LONG) { + unsigned long j, total; + for (total = start, j = start + 2; j < stop; j += 2) + total *= j; + return PyLong_FromUnsignedLong(total); + } + + /* find midpoint of range(start, stop), rounded up to next odd number. */ + midpoint = (start + num_operands) | 1; + left = factorial_partial_product(start, midpoint, + bit_length(midpoint - 2)); + if (left == NULL) + goto error; + right = factorial_partial_product(midpoint, stop, max_bits); + if (right == NULL) + goto error; + result = PyNumber_Multiply(left, right); + + error: + Py_XDECREF(left); + Py_XDECREF(right); + return result; +} + +/* factorial_odd_part: compute the odd part of factorial(n). */ + +static PyObject * +factorial_odd_part(unsigned long n) +{ + long i; + unsigned long v, lower, upper; + PyObject *partial, *tmp, *inner, *outer; + + inner = PyLong_FromLong(1); + if (inner == NULL) + return NULL; + outer = inner; + Py_INCREF(outer); + + upper = 3; + for (i = bit_length(n) - 2; i >= 0; i--) { + v = n >> i; + if (v <= 2) + continue; + lower = upper; + /* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */ + upper = (v + 1) | 1; + /* Here inner is the product of all odd integers j in the range (0, + n/2**(i+1)]. The factorial_partial_product call below gives the + product of all odd integers j in the range (n/2**(i+1), n/2**i]. */ + partial = factorial_partial_product(lower, upper, bit_length(upper-2)); + /* inner *= partial */ + if (partial == NULL) + goto error; + tmp = PyNumber_Multiply(inner, partial); + Py_DECREF(partial); + if (tmp == NULL) + goto error; + Py_DECREF(inner); + inner = tmp; + /* Now inner is the product of all odd integers j in the range (0, + n/2**i], giving the inner product in the formula above. */ + + /* outer *= inner; */ + tmp = PyNumber_Multiply(outer, inner); + if (tmp == NULL) + goto error; + Py_DECREF(outer); + outer = tmp; + } + Py_DECREF(inner); + return outer; + + error: + Py_DECREF(outer); + Py_DECREF(inner); + return NULL; +} + + +/* Lookup table for small factorial values */ + +static const unsigned long SmallFactorials[] = { + 1, 1, 2, 6, 24, 120, 720, 5040, 40320, + 362880, 3628800, 39916800, 479001600, +#if SIZEOF_LONG >= 8 + 6227020800, 87178291200, 1307674368000, + 20922789888000, 355687428096000, 6402373705728000, + 121645100408832000, 2432902008176640000 +#endif +}; + +/*[clinic input] +imath.factorial + + x as arg: object + / + +Find x!. + +Raise a TypeError if x is not an integer. +Raise a ValueError if x is negative integer. +[clinic start generated code]*/ + +static PyObject * +imath_factorial(PyObject *module, PyObject *arg) +/*[clinic end generated code: output=73f1879dcbd64aea input=d5f41d496efcaf51]*/ +{ + long x, two_valuation; + int overflow; + PyObject *result, *odd_part, *pyint_form; + + pyint_form = PyNumber_Index(arg); + if (pyint_form == NULL) { + return NULL; + } + x = PyLong_AsLongAndOverflow(pyint_form, &overflow); + Py_DECREF(pyint_form); + if (x == -1 && PyErr_Occurred()) { + return NULL; + } + else if (overflow == 1) { + PyErr_Format(PyExc_OverflowError, + "factorial() argument should not exceed %ld", + LONG_MAX); + return NULL; + } + else if (overflow == -1 || x < 0) { + PyErr_SetString(PyExc_ValueError, + "factorial() not defined for negative values"); + return NULL; + } + + /* use lookup table if x is small */ + if (x < (long)Py_ARRAY_LENGTH(SmallFactorials)) + return PyLong_FromUnsignedLong(SmallFactorials[x]); + + /* else express in the form odd_part * 2**two_valuation, and compute as + odd_part << two_valuation. */ + odd_part = factorial_odd_part(x); + if (odd_part == NULL) + return NULL; + two_valuation = x - count_set_bits(x); + result = _PyLong_Lshift(odd_part, two_valuation); + Py_DECREF(odd_part); + return result; +} + + +/*[clinic input] +imath.perm + + n: object + k: object + / + +Number of ways to choose k items from n items without repetition and with order. + +It is mathematically equal to the expression n! / (n - k)!. + +Raises TypeError if the arguments are not integers. +Raises ValueError if the arguments are negative or if k > n. +[clinic start generated code]*/ + +static PyObject * +imath_perm_impl(PyObject *module, PyObject *n, PyObject *k) +/*[clinic end generated code: output=4b4f3aa22c47911e input=24aa606dab86900c]*/ +{ + PyObject *result = NULL, *factor = NULL; + int overflow, cmp; + long long i, factors; + + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; + } + if (!PyLong_CheckExact(n)) { + Py_SETREF(n, _PyLong_Copy((PyLongObject *)n)); + if (n == NULL) { + return NULL; + } + } + k = PyNumber_Index(k); + if (k == NULL) { + Py_DECREF(n); + return NULL; + } + if (!PyLong_CheckExact(k)) { + Py_SETREF(k, _PyLong_Copy((PyLongObject *)k)); + if (k == NULL) { + Py_DECREF(n); + return NULL; + } + } + + if (Py_SIZE(n) < 0) { + PyErr_SetString(PyExc_ValueError, + "n must be a non-negative integer"); + goto error; + } + cmp = PyObject_RichCompareBool(n, k, Py_LT); + if (cmp != 0) { + if (cmp > 0) { + PyErr_SetString(PyExc_ValueError, + "k must be an integer less than or equal to n"); + } + goto error; + } + + factors = PyLong_AsLongLongAndOverflow(k, &overflow); + if (overflow > 0) { + PyErr_Format(PyExc_OverflowError, + "k must not exceed %lld", + LLONG_MAX); + goto error; + } + else if (overflow < 0 || factors < 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "k must be a non-negative integer"); + } + goto error; + } + + if (factors == 0) { + result = PyLong_FromLong(1); + goto done; + } + + result = n; + Py_INCREF(result); + if (factors == 1) { + goto done; + } + + factor = n; + Py_INCREF(factor); + for (i = 1; i < factors; ++i) { + Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One)); + if (factor == NULL) { + goto error; + } + Py_SETREF(result, PyNumber_Multiply(result, factor)); + if (result == NULL) { + goto error; + } + } + Py_DECREF(factor); + +done: + Py_DECREF(n); + Py_DECREF(k); + return result; + +error: + Py_XDECREF(factor); + Py_XDECREF(result); + Py_DECREF(n); + Py_DECREF(k); + return NULL; +} + + +/*[clinic input] +imath.comb + + n: object + k: object + / + +Number of ways to choose k items from n items without repetition and without order. + +Also called the binomial coefficient. It is mathematically equal to the expression +n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in +polynomial expansion of the expression (1 + x)**n. + +Raises TypeError if the arguments are not integers. +Raises ValueError if the arguments are negative or if k > n. + +[clinic start generated code]*/ + +static PyObject * +imath_comb_impl(PyObject *module, PyObject *n, PyObject *k) +/*[clinic end generated code: output=87cda746ba0145ef input=f6376b6622fdc123]*/ +{ + PyObject *result = NULL, *factor = NULL, *temp; + int overflow, cmp; + long long i, factors; + + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; + } + if (!PyLong_CheckExact(n)) { + Py_SETREF(n, _PyLong_Copy((PyLongObject *)n)); + if (n == NULL) { + return NULL; + } + } + k = PyNumber_Index(k); + if (k == NULL) { + Py_DECREF(n); + return NULL; + } + if (!PyLong_CheckExact(k)) { + Py_SETREF(k, _PyLong_Copy((PyLongObject *)k)); + if (k == NULL) { + Py_DECREF(n); + return NULL; + } + } + + if (Py_SIZE(n) < 0) { + PyErr_SetString(PyExc_ValueError, + "n must be a non-negative integer"); + goto error; + } + /* k = min(k, n - k) */ + temp = PyNumber_Subtract(n, k); + if (temp == NULL) { + goto error; + } + if (Py_SIZE(temp) < 0) { + Py_DECREF(temp); + PyErr_SetString(PyExc_ValueError, + "k must be an integer less than or equal to n"); + goto error; + } + cmp = PyObject_RichCompareBool(temp, k, Py_LT); + if (cmp > 0) { + Py_SETREF(k, temp); + } + else { + Py_DECREF(temp); + if (cmp < 0) { + goto error; + } + } + + factors = PyLong_AsLongLongAndOverflow(k, &overflow); + if (overflow > 0) { + PyErr_Format(PyExc_OverflowError, + "min(n - k, k) must not exceed %lld", + LLONG_MAX); + goto error; + } + else if (overflow < 0 || factors < 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "k must be a non-negative integer"); + } + goto error; + } + + if (factors == 0) { + result = PyLong_FromLong(1); + goto done; + } + + result = n; + Py_INCREF(result); + if (factors == 1) { + goto done; + } + + factor = n; + Py_INCREF(factor); + for (i = 1; i < factors; ++i) { + Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One)); + if (factor == NULL) { + goto error; + } + Py_SETREF(result, PyNumber_Multiply(result, factor)); + if (result == NULL) { + goto error; + } + + temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1); + if (temp == NULL) { + goto error; + } + Py_SETREF(result, PyNumber_FloorDivide(result, temp)); + Py_DECREF(temp); + if (result == NULL) { + goto error; + } + } + Py_DECREF(factor); + +done: + Py_DECREF(n); + Py_DECREF(k); + return result; + +error: + Py_XDECREF(factor); + Py_XDECREF(result); + Py_DECREF(n); + Py_DECREF(k); + return NULL; +} + + +static PyMethodDef imath_methods[] = { + IMATH_COMB_METHODDEF + IMATH_FACTORIAL_METHODDEF + IMATH_GCD_METHODDEF + IMATH_ILOG2_METHODDEF + IMATH_ISQRT_METHODDEF + IMATH_PERM_METHODDEF + {NULL, NULL} /* sentinel */ +}; + + +PyDoc_STRVAR(module_doc, +"This module provides access to integer related mathematical functions."); + + +static struct PyModuleDef imathmodule = { + PyModuleDef_HEAD_INIT, + "imath", + module_doc, + -1, + imath_methods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC +PyInit__imath(void) +{ + return PyModule_Create(&imathmodule); +} diff --git a/Modules/clinic/_imathmodule.c.h b/Modules/clinic/_imathmodule.c.h new file mode 100644 index 00000000000000..54f75c6367f5ce --- /dev/null +++ b/Modules/clinic/_imathmodule.c.h @@ -0,0 +1,136 @@ +/*[clinic input] +preserve +[clinic start generated code]*/ + +PyDoc_STRVAR(imath_gcd__doc__, +"gcd($module, x, y, /)\n" +"--\n" +"\n" +"Greatest common divisor of x and y."); + +#define IMATH_GCD_METHODDEF \ + {"gcd", (PyCFunction)(void(*)(void))imath_gcd, METH_FASTCALL, imath_gcd__doc__}, + +static PyObject * +imath_gcd_impl(PyObject *module, PyObject *x, PyObject *y); + +static PyObject * +imath_gcd(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *x; + PyObject *y; + + if (!_PyArg_CheckPositional("gcd", nargs, 2, 2)) { + goto exit; + } + x = args[0]; + y = args[1]; + return_value = imath_gcd_impl(module, x, y); + +exit: + return return_value; +} + +PyDoc_STRVAR(imath_ilog2__doc__, +"ilog2($module, n, /)\n" +"--\n" +"\n" +"Return the integer part of the base 2 logarithm of the input."); + +#define IMATH_ILOG2_METHODDEF \ + {"ilog2", (PyCFunction)imath_ilog2, METH_O, imath_ilog2__doc__}, + +PyDoc_STRVAR(imath_isqrt__doc__, +"isqrt($module, n, /)\n" +"--\n" +"\n" +"Return the integer part of the square root of the input."); + +#define IMATH_ISQRT_METHODDEF \ + {"isqrt", (PyCFunction)imath_isqrt, METH_O, imath_isqrt__doc__}, + +PyDoc_STRVAR(imath_factorial__doc__, +"factorial($module, x, /)\n" +"--\n" +"\n" +"Find x!.\n" +"\n" +"Raise a TypeError if x is not an integer.\n" +"Raise a ValueError if x is negative integer."); + +#define IMATH_FACTORIAL_METHODDEF \ + {"factorial", (PyCFunction)imath_factorial, METH_O, imath_factorial__doc__}, + +PyDoc_STRVAR(imath_perm__doc__, +"perm($module, n, k, /)\n" +"--\n" +"\n" +"Number of ways to choose k items from n items without repetition and with order.\n" +"\n" +"It is mathematically equal to the expression n! / (n - k)!.\n" +"\n" +"Raises TypeError if the arguments are not integers.\n" +"Raises ValueError if the arguments are negative or if k > n."); + +#define IMATH_PERM_METHODDEF \ + {"perm", (PyCFunction)(void(*)(void))imath_perm, METH_FASTCALL, imath_perm__doc__}, + +static PyObject * +imath_perm_impl(PyObject *module, PyObject *n, PyObject *k); + +static PyObject * +imath_perm(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *n; + PyObject *k; + + if (!_PyArg_CheckPositional("perm", nargs, 2, 2)) { + goto exit; + } + n = args[0]; + k = args[1]; + return_value = imath_perm_impl(module, n, k); + +exit: + return return_value; +} + +PyDoc_STRVAR(imath_comb__doc__, +"comb($module, n, k, /)\n" +"--\n" +"\n" +"Number of ways to choose k items from n items without repetition and without order.\n" +"\n" +"Also called the binomial coefficient. It is mathematically equal to the expression\n" +"n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n" +"polynomial expansion of the expression (1 + x)**n.\n" +"\n" +"Raises TypeError if the arguments are not integers.\n" +"Raises ValueError if the arguments are negative or if k > n."); + +#define IMATH_COMB_METHODDEF \ + {"comb", (PyCFunction)(void(*)(void))imath_comb, METH_FASTCALL, imath_comb__doc__}, + +static PyObject * +imath_comb_impl(PyObject *module, PyObject *n, PyObject *k); + +static PyObject * +imath_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *n; + PyObject *k; + + if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) { + goto exit; + } + n = args[0]; + k = args[1]; + return_value = imath_comb_impl(module, n, k); + +exit: + return return_value; +} +/*[clinic end generated code: output=47736df1c2f249c7 input=a9049054013a1b77]*/ diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index 0efe5cc409ceb1..388a6400db5f6a 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -6,7 +6,9 @@ PyDoc_STRVAR(math_gcd__doc__, "gcd($module, x, y, /)\n" "--\n" "\n" -"greatest common divisor of x and y"); +"greatest common divisor of x and y\n" +"\n" +"See also imath.gcd()."); #define MATH_GCD_METHODDEF \ {"gcd", (PyCFunction)(void(*)(void))math_gcd, METH_FASTCALL, math_gcd__doc__}, @@ -65,22 +67,15 @@ PyDoc_STRVAR(math_fsum__doc__, #define MATH_FSUM_METHODDEF \ {"fsum", (PyCFunction)math_fsum, METH_O, math_fsum__doc__}, -PyDoc_STRVAR(math_isqrt__doc__, -"isqrt($module, n, /)\n" -"--\n" -"\n" -"Return the integer part of the square root of the input."); - -#define MATH_ISQRT_METHODDEF \ - {"isqrt", (PyCFunction)math_isqrt, METH_O, math_isqrt__doc__}, - PyDoc_STRVAR(math_factorial__doc__, "factorial($module, x, /)\n" "--\n" "\n" "Find x!.\n" "\n" -"Raise a ValueError if x is negative or non-integral."); +"Raise a ValueError if x is negative or non-integral.\n" +"\n" +"See also imath.factorial()."); #define MATH_FACTORIAL_METHODDEF \ {"factorial", (PyCFunction)math_factorial, METH_O, math_factorial__doc__}, @@ -637,76 +632,4 @@ math_prod(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *k exit: return return_value; } - -PyDoc_STRVAR(math_perm__doc__, -"perm($module, n, k, /)\n" -"--\n" -"\n" -"Number of ways to choose k items from n items without repetition and with order.\n" -"\n" -"It is mathematically equal to the expression n! / (n - k)!.\n" -"\n" -"Raises TypeError if the arguments are not integers.\n" -"Raises ValueError if the arguments are negative or if k > n."); - -#define MATH_PERM_METHODDEF \ - {"perm", (PyCFunction)(void(*)(void))math_perm, METH_FASTCALL, math_perm__doc__}, - -static PyObject * -math_perm_impl(PyObject *module, PyObject *n, PyObject *k); - -static PyObject * -math_perm(PyObject *module, PyObject *const *args, Py_ssize_t nargs) -{ - PyObject *return_value = NULL; - PyObject *n; - PyObject *k; - - if (!_PyArg_CheckPositional("perm", nargs, 2, 2)) { - goto exit; - } - n = args[0]; - k = args[1]; - return_value = math_perm_impl(module, n, k); - -exit: - return return_value; -} - -PyDoc_STRVAR(math_comb__doc__, -"comb($module, n, k, /)\n" -"--\n" -"\n" -"Number of ways to choose k items from n items without repetition and without order.\n" -"\n" -"Also called the binomial coefficient. It is mathematically equal to the expression\n" -"n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n" -"polynomial expansion of the expression (1 + x)**n.\n" -"\n" -"Raises TypeError if the arguments are not integers.\n" -"Raises ValueError if the arguments are negative or if k > n."); - -#define MATH_COMB_METHODDEF \ - {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__}, - -static PyObject * -math_comb_impl(PyObject *module, PyObject *n, PyObject *k); - -static PyObject * -math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs) -{ - PyObject *return_value = NULL; - PyObject *n; - PyObject *k; - - if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) { - goto exit; - } - n = args[0]; - k = args[1]; - return_value = math_comb_impl(module, n, k); - -exit: - return return_value; -} -/*[clinic end generated code: output=a82b0e705b6d0ec0 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=e9a8e18bfc89e9b1 input=a9049054013a1b77]*/ diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 6e1099321c5495..b2fdddabe29835 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -76,6 +76,8 @@ static const double logpi = 1.144729885849400174143427351353058711647; static const double sqrtpi = 1.772453850905516027298167483341145182798; #endif /* !defined(HAVE_ERF) || !defined(HAVE_ERFC) */ +static PyObject *imath_factorial = NULL; + /* Version of PyFloat_AsDouble() with in-line fast paths for exact floats and integers. Gives a substantial @@ -833,11 +835,13 @@ math.gcd / greatest common divisor of x and y + +See also imath.gcd(). [clinic start generated code]*/ static PyObject * math_gcd_impl(PyObject *module, PyObject *a, PyObject *b) -/*[clinic end generated code: output=7b2e0c151bd7a5d8 input=c2691e57fb2a98fa]*/ +/*[clinic end generated code: output=7b2e0c151bd7a5d8 input=10d2c441d48df152]*/ { PyObject *g; @@ -1443,524 +1447,6 @@ math_fsum(PyObject *module, PyObject *seq) #undef NUM_PARTIALS -/* Return the smallest integer k such that n < 2**k, or 0 if n == 0. - * Equivalent to floor(lg(x))+1. Also equivalent to: bitwidth_of_type - - * count_leading_zero_bits(x) - */ - -/* XXX: This routine does more or less the same thing as - * bits_in_digit() in Objects/longobject.c. Someday it would be nice to - * consolidate them. On BSD, there's a library function called fls() - * that we could use, and GCC provides __builtin_clz(). - */ - -static unsigned long -bit_length(unsigned long n) -{ - unsigned long len = 0; - while (n != 0) { - ++len; - n >>= 1; - } - return len; -} - -static unsigned long -count_set_bits(unsigned long n) -{ - unsigned long count = 0; - while (n != 0) { - ++count; - n &= n - 1; /* clear least significant bit */ - } - return count; -} - -/* Integer square root - -Given a nonnegative integer `n`, we want to compute the largest integer -`a` for which `a * a <= n`, or equivalently the integer part of the exact -square root of `n`. - -We use an adaptive-precision pure-integer version of Newton's iteration. Given -a positive integer `n`, the algorithm produces at each iteration an integer -approximation `a` to the square root of `n >> s` for some even integer `s`, -with `s` decreasing as the iterations progress. On the final iteration, `s` is -zero and we have an approximation to the square root of `n` itself. - -At every step, the approximation `a` is strictly within 1.0 of the true square -root, so we have - - (a - 1)**2 < (n >> s) < (a + 1)**2 - -After the final iteration, a check-and-correct step is needed to determine -whether `a` or `a - 1` gives the desired integer square root of `n`. - -The algorithm is remarkable in its simplicity. There's no need for a -per-iteration check-and-correct step, and termination is straightforward: the -number of iterations is known in advance (it's exactly `floor(log2(log2(n)))` -for `n > 1`). The only tricky part of the correctness proof is in establishing -that the bound `(a - 1)**2 < (n >> s) < (a + 1)**2` is maintained from one -iteration to the next. A sketch of the proof of this is given below. - -In addition to the proof sketch, a formal, computer-verified proof -of correctness (using Lean) of an equivalent recursive algorithm can be found -here: - - https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean - - -Here's Python code equivalent to the C implementation below: - - def isqrt(n): - """ - Return the integer part of the square root of the input. - """ - n = operator.index(n) - - if n < 0: - raise ValueError("isqrt() argument must be nonnegative") - if n == 0: - return 0 - - c = (n.bit_length() - 1) // 2 - a = 1 - d = 0 - for s in reversed(range(c.bit_length())): - e = d - d = c >> s - a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a - assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2 - - return a - (a*a > n) - - -Sketch of proof of correctness ------------------------------- - -The delicate part of the correctness proof is showing that the loop invariant -is preserved from one iteration to the next. That is, just before the line - - a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a - -is executed in the above code, we know that - - (1) (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2. - -(since `e` is always the value of `d` from the previous iteration). We must -prove that after that line is executed, we have - - (a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2 - -To faciliate the proof, we make some changes of notation. Write `m` for -`n >> 2*(c-d)`, and write `b` for the new value of `a`, so - - b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a - -or equivalently: - - (2) b = (a << d - e - 1) + (m >> d - e + 1) // a - -Then we can rewrite (1) as: - - (3) (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2 - -and we must show that (b - 1)**2 < m < (b + 1)**2. - -From this point on, we switch to mathematical notation, so `/` means exact -division rather than integer division and `^` is used for exponentiation. We -use the `√` symbol for the exact square root. In (3), we can remove the -implicit floor operation to give: - - (4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2 - -Taking square roots throughout (4), scaling by `2^(d-e)`, and rearranging gives - - (5) 0 <= | 2^(d-e)a - √m | < 2^(d-e) - -Squaring and dividing through by `2^(d-e+1) a` gives - - (6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a - -We'll show below that `2^(d-e-1) <= a`. Given that, we can replace the -right-hand side of (6) with `1`, and now replacing the central -term `m / (2^(d-e+1) a)` with its floor in (6) gives - - (7) -1 < 2^(d-e-1) a + m // 2^(d-e+1) a - √m < 1 - -Or equivalently, from (2): - - (7) -1 < b - √m < 1 - -and rearranging gives that `(b-1)^2 < m < (b+1)^2`, which is what we needed -to prove. - -We're not quite done: we still have to prove the inequality `2^(d - e - 1) <= -a` that was used to get line (7) above. From the definition of `c`, we have -`4^c <= n`, which implies - - (8) 4^d <= m - -also, since `e == d >> 1`, `d` is at most `2e + 1`, from which it follows -that `2d - 2e - 1 <= d` and hence that - - (9) 4^(2d - 2e - 1) <= m - -Dividing both sides by `4^(d - e)` gives - - (10) 4^(d - e - 1) <= m / 4^(d - e) - -But we know from (4) that `m / 4^(d-e) < (a + 1)^2`, hence - - (11) 4^(d - e - 1) < (a + 1)^2 - -Now taking square roots of both sides and observing that both `2^(d-e-1)` and -`a` are integers gives `2^(d - e - 1) <= a`, which is what we needed. This -completes the proof sketch. - -*/ - - -/* Approximate square root of a large 64-bit integer. - - Given `n` satisfying `2**62 <= n < 2**64`, return `a` - satisfying `(a - 1)**2 < n < (a + 1)**2`. */ - -static uint64_t -_approximate_isqrt(uint64_t n) -{ - uint32_t u = 1U + (n >> 62); - u = (u << 1) + (n >> 59) / u; - u = (u << 3) + (n >> 53) / u; - u = (u << 7) + (n >> 41) / u; - return (u << 15) + (n >> 17) / u; -} - -/*[clinic input] -math.isqrt - - n: object - / - -Return the integer part of the square root of the input. -[clinic start generated code]*/ - -static PyObject * -math_isqrt(PyObject *module, PyObject *n) -/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/ -{ - int a_too_large, c_bit_length; - size_t c, d; - uint64_t m, u; - PyObject *a = NULL, *b; - - n = PyNumber_Index(n); - if (n == NULL) { - return NULL; - } - - if (_PyLong_Sign(n) < 0) { - PyErr_SetString( - PyExc_ValueError, - "isqrt() argument must be nonnegative"); - goto error; - } - if (_PyLong_Sign(n) == 0) { - Py_DECREF(n); - return PyLong_FromLong(0); - } - - /* c = (n.bit_length() - 1) // 2 */ - c = _PyLong_NumBits(n); - if (c == (size_t)(-1)) { - goto error; - } - c = (c - 1U) / 2U; - - /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a - fast, almost branch-free algorithm. In the final correction, we use `u*u - - 1 >= m` instead of the simpler `u*u > m` in order to get the correct - result in the corner case where `u=2**32`. */ - if (c <= 31U) { - m = (uint64_t)PyLong_AsUnsignedLongLong(n); - Py_DECREF(n); - if (m == (uint64_t)(-1) && PyErr_Occurred()) { - return NULL; - } - u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c); - u -= u * u - 1U >= m; - return PyLong_FromUnsignedLongLong((unsigned long long)u); - } - - /* Slow path: n >= 2**64. We perform the first five iterations in C integer - arithmetic, then switch to using Python long integers. */ - - /* From n >= 2**64 it follows that c.bit_length() >= 6. */ - c_bit_length = 6; - while ((c >> c_bit_length) > 0U) { - ++c_bit_length; - } - - /* Initialise d and a. */ - d = c >> (c_bit_length - 5); - b = _PyLong_Rshift(n, 2U*c - 62U); - if (b == NULL) { - goto error; - } - m = (uint64_t)PyLong_AsUnsignedLongLong(b); - Py_DECREF(b); - if (m == (uint64_t)(-1) && PyErr_Occurred()) { - goto error; - } - u = _approximate_isqrt(m) >> (31U - d); - a = PyLong_FromUnsignedLongLong((unsigned long long)u); - if (a == NULL) { - goto error; - } - - for (int s = c_bit_length - 6; s >= 0; --s) { - PyObject *q; - size_t e = d; - - d = c >> s; - - /* q = (n >> 2*c - e - d + 1) // a */ - q = _PyLong_Rshift(n, 2U*c - d - e + 1U); - if (q == NULL) { - goto error; - } - Py_SETREF(q, PyNumber_FloorDivide(q, a)); - if (q == NULL) { - goto error; - } - - /* a = (a << d - 1 - e) + q */ - Py_SETREF(a, _PyLong_Lshift(a, d - 1U - e)); - if (a == NULL) { - Py_DECREF(q); - goto error; - } - Py_SETREF(a, PyNumber_Add(a, q)); - Py_DECREF(q); - if (a == NULL) { - goto error; - } - } - - /* The correct result is either a or a - 1. Figure out which, and - decrement a if necessary. */ - - /* a_too_large = n < a * a */ - b = PyNumber_Multiply(a, a); - if (b == NULL) { - goto error; - } - a_too_large = PyObject_RichCompareBool(n, b, Py_LT); - Py_DECREF(b); - if (a_too_large == -1) { - goto error; - } - - if (a_too_large) { - Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One)); - } - Py_DECREF(n); - return a; - - error: - Py_XDECREF(a); - Py_DECREF(n); - return NULL; -} - -/* Divide-and-conquer factorial algorithm - * - * Based on the formula and pseudo-code provided at: - * http://www.luschny.de/math/factorial/binarysplitfact.html - * - * Faster algorithms exist, but they're more complicated and depend on - * a fast prime factorization algorithm. - * - * Notes on the algorithm - * ---------------------- - * - * factorial(n) is written in the form 2**k * m, with m odd. k and m are - * computed separately, and then combined using a left shift. - * - * The function factorial_odd_part computes the odd part m (i.e., the greatest - * odd divisor) of factorial(n), using the formula: - * - * factorial_odd_part(n) = - * - * product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j - * - * Example: factorial_odd_part(20) = - * - * (1) * - * (1) * - * (1 * 3 * 5) * - * (1 * 3 * 5 * 7 * 9) - * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19) - * - * Here i goes from large to small: the first term corresponds to i=4 (any - * larger i gives an empty product), and the last term corresponds to i=0. - * Each term can be computed from the last by multiplying by the extra odd - * numbers required: e.g., to get from the penultimate term to the last one, - * we multiply by (11 * 13 * 15 * 17 * 19). - * - * To see a hint of why this formula works, here are the same numbers as above - * but with the even parts (i.e., the appropriate powers of 2) included. For - * each subterm in the product for i, we multiply that subterm by 2**i: - * - * factorial(20) = - * - * (16) * - * (8) * - * (4 * 12 * 20) * - * (2 * 6 * 10 * 14 * 18) * - * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19) - * - * The factorial_partial_product function computes the product of all odd j in - * range(start, stop) for given start and stop. It's used to compute the - * partial products like (11 * 13 * 15 * 17 * 19) in the example above. It - * operates recursively, repeatedly splitting the range into two roughly equal - * pieces until the subranges are small enough to be computed using only C - * integer arithmetic. - * - * The two-valuation k (i.e., the exponent of the largest power of 2 dividing - * the factorial) is computed independently in the main math_factorial - * function. By standard results, its value is: - * - * two_valuation = n//2 + n//4 + n//8 + .... - * - * It can be shown (e.g., by complete induction on n) that two_valuation is - * equal to n - count_set_bits(n), where count_set_bits(n) gives the number of - * '1'-bits in the binary expansion of n. - */ - -/* factorial_partial_product: Compute product(range(start, stop, 2)) using - * divide and conquer. Assumes start and stop are odd and stop > start. - * max_bits must be >= bit_length(stop - 2). */ - -static PyObject * -factorial_partial_product(unsigned long start, unsigned long stop, - unsigned long max_bits) -{ - unsigned long midpoint, num_operands; - PyObject *left = NULL, *right = NULL, *result = NULL; - - /* If the return value will fit an unsigned long, then we can - * multiply in a tight, fast loop where each multiply is O(1). - * Compute an upper bound on the number of bits required to store - * the answer. - * - * Storing some integer z requires floor(lg(z))+1 bits, which is - * conveniently the value returned by bit_length(z). The - * product x*y will require at most - * bit_length(x) + bit_length(y) bits to store, based - * on the idea that lg product = lg x + lg y. - * - * We know that stop - 2 is the largest number to be multiplied. From - * there, we have: bit_length(answer) <= num_operands * - * bit_length(stop - 2) - */ - - num_operands = (stop - start) / 2; - /* The "num_operands <= 8 * SIZEOF_LONG" check guards against the - * unlikely case of an overflow in num_operands * max_bits. */ - if (num_operands <= 8 * SIZEOF_LONG && - num_operands * max_bits <= 8 * SIZEOF_LONG) { - unsigned long j, total; - for (total = start, j = start + 2; j < stop; j += 2) - total *= j; - return PyLong_FromUnsignedLong(total); - } - - /* find midpoint of range(start, stop), rounded up to next odd number. */ - midpoint = (start + num_operands) | 1; - left = factorial_partial_product(start, midpoint, - bit_length(midpoint - 2)); - if (left == NULL) - goto error; - right = factorial_partial_product(midpoint, stop, max_bits); - if (right == NULL) - goto error; - result = PyNumber_Multiply(left, right); - - error: - Py_XDECREF(left); - Py_XDECREF(right); - return result; -} - -/* factorial_odd_part: compute the odd part of factorial(n). */ - -static PyObject * -factorial_odd_part(unsigned long n) -{ - long i; - unsigned long v, lower, upper; - PyObject *partial, *tmp, *inner, *outer; - - inner = PyLong_FromLong(1); - if (inner == NULL) - return NULL; - outer = inner; - Py_INCREF(outer); - - upper = 3; - for (i = bit_length(n) - 2; i >= 0; i--) { - v = n >> i; - if (v <= 2) - continue; - lower = upper; - /* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */ - upper = (v + 1) | 1; - /* Here inner is the product of all odd integers j in the range (0, - n/2**(i+1)]. The factorial_partial_product call below gives the - product of all odd integers j in the range (n/2**(i+1), n/2**i]. */ - partial = factorial_partial_product(lower, upper, bit_length(upper-2)); - /* inner *= partial */ - if (partial == NULL) - goto error; - tmp = PyNumber_Multiply(inner, partial); - Py_DECREF(partial); - if (tmp == NULL) - goto error; - Py_DECREF(inner); - inner = tmp; - /* Now inner is the product of all odd integers j in the range (0, - n/2**i], giving the inner product in the formula above. */ - - /* outer *= inner; */ - tmp = PyNumber_Multiply(outer, inner); - if (tmp == NULL) - goto error; - Py_DECREF(outer); - outer = tmp; - } - Py_DECREF(inner); - return outer; - - error: - Py_DECREF(outer); - Py_DECREF(inner); - return NULL; -} - - -/* Lookup table for small factorial values */ - -static const unsigned long SmallFactorials[] = { - 1, 1, 2, 6, 24, 120, 720, 5040, 40320, - 362880, 3628800, 39916800, 479001600, -#if SIZEOF_LONG >= 8 - 6227020800, 87178291200, 1307674368000, - 20922789888000, 355687428096000, 6402373705728000, - 121645100408832000, 2432902008176640000 -#endif -}; - /*[clinic input] math.factorial @@ -1970,66 +1456,33 @@ math.factorial Find x!. Raise a ValueError if x is negative or non-integral. + +See also imath.factorial(). [clinic start generated code]*/ static PyObject * math_factorial(PyObject *module, PyObject *arg) -/*[clinic end generated code: output=6686f26fae00e9ca input=6d1c8105c0d91fb4]*/ +/*[clinic end generated code: output=6686f26fae00e9ca input=3a6ec477a80c807b]*/ { - long x, two_valuation; - int overflow; - PyObject *result, *odd_part, *pyint_form; + PyObject *result; if (PyFloat_Check(arg)) { - PyObject *lx; double dx = PyFloat_AS_DOUBLE((PyFloatObject *)arg); if (!(Py_IS_FINITE(dx) && dx == floor(dx))) { PyErr_SetString(PyExc_ValueError, "factorial() only accepts integral values"); return NULL; } - lx = PyLong_FromDouble(dx); - if (lx == NULL) + arg = PyLong_FromDouble(dx); + if (arg == NULL) return NULL; - x = PyLong_AsLongAndOverflow(lx, &overflow); - Py_DECREF(lx); } else { - pyint_form = PyNumber_Index(arg); - if (pyint_form == NULL) { - return NULL; - } - x = PyLong_AsLongAndOverflow(pyint_form, &overflow); - Py_DECREF(pyint_form); + Py_INCREF(arg); } - if (x == -1 && PyErr_Occurred()) { - return NULL; - } - else if (overflow == 1) { - PyErr_Format(PyExc_OverflowError, - "factorial() argument should not exceed %ld", - LONG_MAX); - return NULL; - } - else if (overflow == -1 || x < 0) { - PyErr_SetString(PyExc_ValueError, - "factorial() not defined for negative values"); - return NULL; - } - - /* use lookup table if x is small */ - if (x < (long)Py_ARRAY_LENGTH(SmallFactorials)) - return PyLong_FromUnsignedLong(SmallFactorials[x]); - - /* else express in the form odd_part * 2**two_valuation, and compute as - odd_part << two_valuation. */ - odd_part = factorial_odd_part(x); - if (odd_part == NULL) - return NULL; - two_valuation = x - count_set_bits(x); - result = _PyLong_Lshift(odd_part, two_valuation); - Py_DECREF(odd_part); + result = _PyObject_FastCall(imath_factorial, &arg, 1); + Py_DECREF(arg); return result; } @@ -2998,260 +2451,6 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) } -/*[clinic input] -math.perm - - n: object - k: object - / - -Number of ways to choose k items from n items without repetition and with order. - -It is mathematically equal to the expression n! / (n - k)!. - -Raises TypeError if the arguments are not integers. -Raises ValueError if the arguments are negative or if k > n. -[clinic start generated code]*/ - -static PyObject * -math_perm_impl(PyObject *module, PyObject *n, PyObject *k) -/*[clinic end generated code: output=e021a25469653e23 input=f71ee4f6ff26be24]*/ -{ - PyObject *result = NULL, *factor = NULL; - int overflow, cmp; - long long i, factors; - - n = PyNumber_Index(n); - if (n == NULL) { - return NULL; - } - if (!PyLong_CheckExact(n)) { - Py_SETREF(n, _PyLong_Copy((PyLongObject *)n)); - if (n == NULL) { - return NULL; - } - } - k = PyNumber_Index(k); - if (k == NULL) { - Py_DECREF(n); - return NULL; - } - if (!PyLong_CheckExact(k)) { - Py_SETREF(k, _PyLong_Copy((PyLongObject *)k)); - if (k == NULL) { - Py_DECREF(n); - return NULL; - } - } - - if (Py_SIZE(n) < 0) { - PyErr_SetString(PyExc_ValueError, - "n must be a non-negative integer"); - goto error; - } - cmp = PyObject_RichCompareBool(n, k, Py_LT); - if (cmp != 0) { - if (cmp > 0) { - PyErr_SetString(PyExc_ValueError, - "k must be an integer less than or equal to n"); - } - goto error; - } - - factors = PyLong_AsLongLongAndOverflow(k, &overflow); - if (overflow > 0) { - PyErr_Format(PyExc_OverflowError, - "k must not exceed %lld", - LLONG_MAX); - goto error; - } - else if (overflow < 0 || factors < 0) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "k must be a non-negative integer"); - } - goto error; - } - - if (factors == 0) { - result = PyLong_FromLong(1); - goto done; - } - - result = n; - Py_INCREF(result); - if (factors == 1) { - goto done; - } - - factor = n; - Py_INCREF(factor); - for (i = 1; i < factors; ++i) { - Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One)); - if (factor == NULL) { - goto error; - } - Py_SETREF(result, PyNumber_Multiply(result, factor)); - if (result == NULL) { - goto error; - } - } - Py_DECREF(factor); - -done: - Py_DECREF(n); - Py_DECREF(k); - return result; - -error: - Py_XDECREF(factor); - Py_XDECREF(result); - Py_DECREF(n); - Py_DECREF(k); - return NULL; -} - - -/*[clinic input] -math.comb - - n: object - k: object - / - -Number of ways to choose k items from n items without repetition and without order. - -Also called the binomial coefficient. It is mathematically equal to the expression -n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in -polynomial expansion of the expression (1 + x)**n. - -Raises TypeError if the arguments are not integers. -Raises ValueError if the arguments are negative or if k > n. - -[clinic start generated code]*/ - -static PyObject * -math_comb_impl(PyObject *module, PyObject *n, PyObject *k) -/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/ -{ - PyObject *result = NULL, *factor = NULL, *temp; - int overflow, cmp; - long long i, factors; - - n = PyNumber_Index(n); - if (n == NULL) { - return NULL; - } - if (!PyLong_CheckExact(n)) { - Py_SETREF(n, _PyLong_Copy((PyLongObject *)n)); - if (n == NULL) { - return NULL; - } - } - k = PyNumber_Index(k); - if (k == NULL) { - Py_DECREF(n); - return NULL; - } - if (!PyLong_CheckExact(k)) { - Py_SETREF(k, _PyLong_Copy((PyLongObject *)k)); - if (k == NULL) { - Py_DECREF(n); - return NULL; - } - } - - if (Py_SIZE(n) < 0) { - PyErr_SetString(PyExc_ValueError, - "n must be a non-negative integer"); - goto error; - } - /* k = min(k, n - k) */ - temp = PyNumber_Subtract(n, k); - if (temp == NULL) { - goto error; - } - if (Py_SIZE(temp) < 0) { - Py_DECREF(temp); - PyErr_SetString(PyExc_ValueError, - "k must be an integer less than or equal to n"); - goto error; - } - cmp = PyObject_RichCompareBool(temp, k, Py_LT); - if (cmp > 0) { - Py_SETREF(k, temp); - } - else { - Py_DECREF(temp); - if (cmp < 0) { - goto error; - } - } - - factors = PyLong_AsLongLongAndOverflow(k, &overflow); - if (overflow > 0) { - PyErr_Format(PyExc_OverflowError, - "min(n - k, k) must not exceed %lld", - LLONG_MAX); - goto error; - } - else if (overflow < 0 || factors < 0) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, - "k must be a non-negative integer"); - } - goto error; - } - - if (factors == 0) { - result = PyLong_FromLong(1); - goto done; - } - - result = n; - Py_INCREF(result); - if (factors == 1) { - goto done; - } - - factor = n; - Py_INCREF(factor); - for (i = 1; i < factors; ++i) { - Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One)); - if (factor == NULL) { - goto error; - } - Py_SETREF(result, PyNumber_Multiply(result, factor)); - if (result == NULL) { - goto error; - } - - temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1); - if (temp == NULL) { - goto error; - } - Py_SETREF(result, PyNumber_FloorDivide(result, temp)); - Py_DECREF(temp); - if (result == NULL) { - goto error; - } - } - Py_DECREF(factor); - -done: - Py_DECREF(n); - Py_DECREF(k); - return result; - -error: - Py_XDECREF(factor); - Py_XDECREF(result); - Py_DECREF(n); - Py_DECREF(k); - return NULL; -} - - static PyMethodDef math_methods[] = { {"acos", math_acos, METH_O, math_acos_doc}, {"acosh", math_acosh, METH_O, math_acosh_doc}, @@ -3283,7 +2482,6 @@ static PyMethodDef math_methods[] = { MATH_ISFINITE_METHODDEF MATH_ISINF_METHODDEF MATH_ISNAN_METHODDEF - MATH_ISQRT_METHODDEF MATH_LDEXP_METHODDEF {"lgamma", math_lgamma, METH_O, math_lgamma_doc}, MATH_LOG_METHODDEF @@ -3301,8 +2499,6 @@ static PyMethodDef math_methods[] = { {"tanh", math_tanh, METH_O, math_tanh_doc}, MATH_TRUNC_METHODDEF MATH_PROD_METHODDEF - MATH_PERM_METHODDEF - MATH_COMB_METHODDEF {NULL, NULL} /* sentinel */ }; @@ -3327,7 +2523,7 @@ static struct PyModuleDef mathmodule = { PyMODINIT_FUNC PyInit_math(void) { - PyObject *m; + PyObject *m, *imath; m = PyModule_Create(&mathmodule); if (m == NULL) @@ -3341,6 +2537,20 @@ PyInit_math(void) PyModule_AddObject(m, "nan", PyFloat_FromDouble(m_nan())); #endif + imath = PyImport_ImportModule("imath"); + if (!imath) { + goto error; + } + imath_factorial = PyObject_GetAttrString(imath, "factorial"); + Py_DECREF(imath); + if (!imath_factorial) { + goto error; + } + finally: return m; + + error: + Py_DECREF(m); + return NULL; } diff --git a/PC/config.c b/PC/config.c index 6f34962bd72d4f..45b4e98a9c6700 100644 --- a/PC/config.c +++ b/PC/config.c @@ -14,6 +14,7 @@ extern PyObject* PyInit_errno(void); extern PyObject* PyInit_faulthandler(void); extern PyObject* PyInit__tracemalloc(void); extern PyObject* PyInit_gc(void); +extern PyObject* PyInit__imath(void); extern PyObject* PyInit_math(void); extern PyObject* PyInit__md5(void); extern PyObject* PyInit_nt(void); @@ -91,6 +92,7 @@ struct _inittab _PyImport_Inittab[] = { {"errno", PyInit_errno}, {"faulthandler", PyInit_faulthandler}, {"gc", PyInit_gc}, + {"imath", PyInit__imath}, {"math", PyInit_math}, {"nt", PyInit_nt}, /* Use the NT os functions, not posix */ {"_operator", PyInit__operator}, diff --git a/PCbuild/pythoncore.vcxproj b/PCbuild/pythoncore.vcxproj index 329f9feb2bdf02..d8eb30a05a0b29 100644 --- a/PCbuild/pythoncore.vcxproj +++ b/PCbuild/pythoncore.vcxproj @@ -316,6 +316,7 @@ + diff --git a/PCbuild/pythoncore.vcxproj.filters b/PCbuild/pythoncore.vcxproj.filters index d80d05fb15a0cf..54d7c1095cb564 100644 --- a/PCbuild/pythoncore.vcxproj.filters +++ b/PCbuild/pythoncore.vcxproj.filters @@ -635,6 +635,9 @@ Modules + + Modules + Modules diff --git a/setup.py b/setup.py index 7852c2dfa27e08..647b56e586ffb2 100644 --- a/setup.py +++ b/setup.py @@ -691,6 +691,9 @@ def detect_simple_extensions(self): # Context Variables self.add(Extension('_contextvars', ['_contextvarsmodule.c'])) + # integer math library functions, e.g. isqrt() + self.add(Extension('_imath', ['_imathmodule.c'])) + shared_math = 'Modules/_math.o' # math library functions, e.g. sin()