diff --git a/Doc/library/imath.rst b/Doc/library/imath.rst
new file mode 100644
index 00000000000000..d3dfce1befc1d7
--- /dev/null
+++ b/Doc/library/imath.rst
@@ -0,0 +1,77 @@
+:mod:`imath` --- Mathematical functions for integer numbers
+===========================================================
+
+.. module:: imath
+ :synopsis: Mathematical functions for integer numbers.
+
+.. versionadded:: 3.8
+
+**Source code:** :source:`Lib/imath.py`
+
+--------------
+
+This module provides access to the mathematical functions for integer arguments.
+These functions accept integers and objects that implement the
+:meth:`__index__` method which is used to convert the object to an integer
+number. They cannot be used with floating-point numbers or complex
+numbers.
+
+The following functions are provided by this module. All return values are
+integers.
+
+
+.. function:: comb(n, k)
+
+ Return the number of ways to choose *k* items from *n* items without repetition
+ and without order.
+
+ Also called the binomial coefficient. It is mathematically equal to the expression
+ ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the
+ polynomial expansion of the expression ``(1 + x) ** n``.
+
+ Raises :exc:`TypeError` if the arguments not integers.
+ Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*.
+
+
+.. function:: gcd(a, b)
+
+ Return the greatest common divisor of the integers *a* and *b*. If either
+ *a* or *b* is nonzero, then the value of ``gcd(a, b)`` is the largest
+ positive integer that divides both *a* and *b*. ``gcd(0, 0)`` returns
+ ``0``.
+
+
+.. function:: ilog2(n)
+
+ Return the integer base 2 logarithm of the positive integer *n*. This is the
+ floor of the exact base 2 logarithm root of *n*, or equivalently the
+ greatest integer *k* such that
+ 2\ :sup:`k` |nbsp| ≤ |nbsp| *n* |nbsp| < |nbsp| 2\ :sup:`k+1`.
+
+ It is equivalent to ``n.bit_length() - 1`` for positive *n*.
+
+
+.. function:: isqrt(n)
+
+ Return the integer square root of the nonnegative integer *n*. This is the
+ floor of the exact square root of *n*, or equivalently the greatest integer
+ *a* such that *a*\ ² |nbsp| ≤ |nbsp| *n*.
+
+ For some applications, it may be more convenient to have the least integer
+ *a* such that *n* |nbsp| ≤ |nbsp| *a*\ ², or in other words the ceiling of
+ the exact square root of *n*. For positive *n*, this can be computed using
+ ``a = 1 + isqrt(n - 1)``.
+
+
+.. function:: perm(n, k)
+
+ Return the number of ways to choose *k* items from *n* items
+ without repetition and with order.
+
+ It is mathematically equal to the expression ``n! / (n - k)!``.
+
+ Raises :exc:`TypeError` if the arguments not integers.
+ Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*.
+
+.. |nbsp| unicode:: 0xA0
+ :trim:
diff --git a/Doc/library/math.rst b/Doc/library/math.rst
index c5a77f1fab9fd6..4798f3074ea6f1 100644
--- a/Doc/library/math.rst
+++ b/Doc/library/math.rst
@@ -36,21 +36,6 @@ Number-theoretic and representation functions
:class:`~numbers.Integral` value.
-.. function:: comb(n, k)
-
- Return the number of ways to choose *k* items from *n* items without repetition
- and without order.
-
- Also called the binomial coefficient. It is mathematically equal to the expression
- ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the
- polynomial expansion of the expression ``(1 + x) ** n``.
-
- Raises :exc:`TypeError` if the arguments not integers.
- Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*.
-
- .. versionadded:: 3.8
-
-
.. function:: copysign(x, y)
Return a float with the magnitude (absolute value) of *x* but the sign of
@@ -65,8 +50,8 @@ Number-theoretic and representation functions
.. function:: factorial(x)
- Return *x* factorial as an integer. Raises :exc:`ValueError` if *x* is not integral or
- is negative.
+ Similar to :func:`imath.factorial`, but accepts also floating-point numbers
+ with integer value (like ``3.0``).
.. function:: floor(x)
@@ -122,10 +107,7 @@ Number-theoretic and representation functions
.. function:: gcd(a, b)
- Return the greatest common divisor of the integers *a* and *b*. If either
- *a* or *b* is nonzero, then the value of ``gcd(a, b)`` is the largest
- positive integer that divides both *a* and *b*. ``gcd(0, 0)`` returns
- ``0``.
+ An alias of :func:`imath.gcd`.
.. versionadded:: 3.5
@@ -181,20 +163,6 @@ Number-theoretic and representation functions
Return ``True`` if *x* is a NaN (not a number), and ``False`` otherwise.
-.. function:: isqrt(n)
-
- Return the integer square root of the nonnegative integer *n*. This is the
- floor of the exact square root of *n*, or equivalently the greatest integer
- *a* such that *a*\ ² |nbsp| ≤ |nbsp| *n*.
-
- For some applications, it may be more convenient to have the least integer
- *a* such that *n* |nbsp| ≤ |nbsp| *a*\ ², or in other words the ceiling of
- the exact square root of *n*. For positive *n*, this can be computed using
- ``a = 1 + isqrt(n - 1)``.
-
- .. versionadded:: 3.8
-
-
.. function:: ldexp(x, i)
Return ``x * (2**i)``. This is essentially the inverse of function
@@ -207,19 +175,6 @@ Number-theoretic and representation functions
of *x* and are floats.
-.. function:: perm(n, k)
-
- Return the number of ways to choose *k* items from *n* items
- without repetition and with order.
-
- It is mathematically equal to the expression ``n! / (n - k)!``.
-
- Raises :exc:`TypeError` if the arguments not integers.
- Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*.
-
- .. versionadded:: 3.8
-
-
.. function:: prod(iterable, *, start=1)
Calculate the product of all the elements in the input *iterable*.
@@ -580,6 +535,3 @@ Constants
Module :mod:`cmath`
Complex number versions of many of these functions.
-
-.. |nbsp| unicode:: 0xA0
- :trim:
diff --git a/Lib/test/test_imath.py b/Lib/test/test_imath.py
new file mode 100644
index 00000000000000..376da0cdaf4a4a
--- /dev/null
+++ b/Lib/test/test_imath.py
@@ -0,0 +1,344 @@
+from decimal import Decimal
+from fractions import Fraction
+import unittest
+from test import support
+
+py_imath = support.import_fresh_module('imath', blocked=['_imath'])
+c_imath = support.import_fresh_module('imath', fresh=['_imath'])
+
+class IntSubclass(int):
+ pass
+
+# Class providing an __index__ method.
+class MyIndexable(object):
+ def __init__(self, value):
+ self.value = value
+
+ def __index__(self):
+ return self.value
+
+# Here's a pure Python version of the math.factorial algorithm, for
+# documentation and comparison purposes.
+#
+# Formula:
+#
+# factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n))
+#
+# where
+#
+# factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j
+#
+# The outer product above is an infinite product, but once i >= n.bit_length,
+# (n >> i) < 1 and the corresponding term of the product is empty. So only the
+# finitely many terms for 0 <= i < n.bit_length() contribute anything.
+#
+# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner
+# product in the formula above starts at 1 for i == n.bit_length(); for each i
+# < n.bit_length() we get the inner product for i from that for i + 1 by
+# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms,
+# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2).
+
+def count_set_bits(n):
+ """Number of '1' bits in binary expansion of a nonnnegative integer."""
+ return 1 + count_set_bits(n & n - 1) if n else 0
+
+def partial_product(start, stop):
+ """Product of integers in range(start, stop, 2), computed recursively.
+ start and stop should both be odd, with start <= stop.
+
+ """
+ numfactors = (stop - start) >> 1
+ if not numfactors:
+ return 1
+ elif numfactors == 1:
+ return start
+ else:
+ mid = (start + numfactors) | 1
+ return partial_product(start, mid) * partial_product(mid, stop)
+
+def py_factorial(n):
+ """Factorial of nonnegative integer n, via "Binary Split Factorial Formula"
+ described at http://www.luschny.de/math/factorial/binarysplitfact.html
+
+ """
+ inner = outer = 1
+ for i in reversed(range(n.bit_length())):
+ inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1)
+ outer *= inner
+ return outer << (n - count_set_bits(n))
+
+
+class IMathTests:
+
+ def assertIntEqual(self, actual, expected):
+ self.assertEqual(actual, expected)
+ self.assertIs(type(actual), int)
+
+ def testFactorial(self):
+ factorial = self.module.factorial
+ self.assertEqual(factorial(0), 1)
+ total = 1
+ for i in range(1, 1000):
+ total *= i
+ self.assertEqual(factorial(i), total)
+ self.assertEqual(factorial(i), py_factorial(i))
+
+ self.assertIntEqual(factorial(False), 1)
+ self.assertIntEqual(factorial(True), 1)
+ for i in range(3):
+ expected = factorial(i)
+ self.assertIntEqual(factorial(IntSubclass(i)), expected)
+ self.assertIntEqual(factorial(MyIndexable(i)), expected)
+
+ self.assertRaises(ValueError, factorial, -1)
+ self.assertRaises(ValueError, factorial, -10**1000)
+
+ self.assertRaises(TypeError, factorial, 5.0)
+ self.assertRaises(TypeError, factorial, -1.0)
+ self.assertRaises(TypeError, factorial, -1.0e100)
+ self.assertRaises(TypeError, factorial, Decimal(5.0))
+ self.assertRaises(TypeError, factorial, Fraction(5, 1))
+ self.assertRaises(TypeError, factorial, "5")
+
+ if self.module is c_imath:
+ self.assertRaises((OverflowError, MemoryError), factorial, 10**100)
+
+ def testGcd(self):
+ gcd = self.module.gcd
+ self.assertEqual(gcd(0, 0), 0)
+ self.assertEqual(gcd(1, 0), 1)
+ self.assertEqual(gcd(-1, 0), 1)
+ self.assertEqual(gcd(0, 1), 1)
+ self.assertEqual(gcd(0, -1), 1)
+ self.assertEqual(gcd(7, 1), 1)
+ self.assertEqual(gcd(7, -1), 1)
+ self.assertEqual(gcd(-23, 15), 1)
+ self.assertEqual(gcd(120, 84), 12)
+ self.assertEqual(gcd(84, -120), 12)
+ self.assertEqual(gcd(1216342683557601535506311712,
+ 436522681849110124616458784), 32)
+ c = 652560
+ x = 434610456570399902378880679233098819019853229470286994367836600566
+ y = 1064502245825115327754847244914921553977
+ a = x * c
+ b = y * c
+ self.assertEqual(gcd(a, b), c)
+ self.assertEqual(gcd(b, a), c)
+ self.assertEqual(gcd(-a, b), c)
+ self.assertEqual(gcd(b, -a), c)
+ self.assertEqual(gcd(a, -b), c)
+ self.assertEqual(gcd(-b, a), c)
+ self.assertEqual(gcd(-a, -b), c)
+ self.assertEqual(gcd(-b, -a), c)
+ c = 576559230871654959816130551884856912003141446781646602790216406874
+ a = x * c
+ b = y * c
+ self.assertEqual(gcd(a, b), c)
+ self.assertEqual(gcd(b, a), c)
+ self.assertEqual(gcd(-a, b), c)
+ self.assertEqual(gcd(b, -a), c)
+ self.assertEqual(gcd(a, -b), c)
+ self.assertEqual(gcd(-b, a), c)
+ self.assertEqual(gcd(-a, -b), c)
+ self.assertEqual(gcd(-b, -a), c)
+
+ self.assertRaises(TypeError, gcd, 120.0, 84)
+ self.assertRaises(TypeError, gcd, 120, 84.0)
+ self.assertIntEqual(gcd(IntSubclass(120), IntSubclass(84)), 12)
+ self.assertIntEqual(gcd(MyIndexable(120), MyIndexable(84)), 12)
+
+ def testIlog(self):
+ ilog2 = self.module.ilog2
+ for value in range(1, 1000):
+ k = ilog2(value)
+ self.assertLessEqual(2**k, value)
+ self.assertLess(value, 2**(k+1))
+ self.assertRaises(ValueError, ilog2, 0)
+ self.assertRaises(ValueError, ilog2, -1)
+ self.assertRaises(ValueError, ilog2, -2**1000)
+
+ self.assertIntEqual(ilog2(True), 0)
+ self.assertIntEqual(ilog2(IntSubclass(5)), 2)
+ self.assertIntEqual(ilog2(MyIndexable(5)), 2)
+
+ self.assertRaises(TypeError, ilog2, 5.0)
+ self.assertRaises(TypeError, ilog2, Decimal('5'))
+ self.assertRaises(TypeError, ilog2, Fraction(5, 1))
+ self.assertRaises(TypeError, ilog2, '5')
+
+ def testIsqrt(self):
+ isqrt = self.module.isqrt
+ # Test a variety of inputs, large and small.
+ test_values = (
+ list(range(1000))
+ + list(range(10**6 - 1000, 10**6 + 1000))
+ + [2**e + i for e in range(60, 200) for i in range(-40, 40)]
+ + [3**9999, 10**5001]
+ )
+
+ for value in test_values:
+ with self.subTest(value=value):
+ s = isqrt(value)
+ self.assertIs(type(s), int)
+ self.assertLessEqual(s*s, value)
+ self.assertLess(value, (s+1)*(s+1))
+
+ # Negative values
+ with self.assertRaises(ValueError):
+ isqrt(-1)
+
+ # Integer-like things
+ self.assertIntEqual(isqrt(False), 0)
+ self.assertIntEqual(isqrt(True), 1)
+ self.assertIntEqual(isqrt(MyIndexable(1729)), 41)
+
+ with self.assertRaises(ValueError):
+ isqrt(MyIndexable(-3))
+
+ # Non-integer-like things
+ for value in 4.0, "4", Decimal("4.0"), Fraction(4, 1), 4j, -4.0:
+ with self.subTest(value=value):
+ with self.assertRaises(TypeError):
+ isqrt(value)
+
+ def testPerm(self):
+ perm = self.module.perm
+ factorial = self.module.factorial
+ # Test if factorial defintion is satisfied
+ for n in range(100):
+ for k in range(n + 1):
+ self.assertEqual(perm(n, k),
+ factorial(n) // factorial(n - k))
+
+ # Test for Pascal's identity
+ for n in range(1, 100):
+ for k in range(1, n):
+ self.assertEqual(perm(n, k), perm(n - 1, k - 1) * k + perm(n - 1, k))
+
+ # Test corner cases
+ for n in range(1, 100):
+ self.assertEqual(perm(n, 0), 1)
+ self.assertEqual(perm(n, 1), n)
+ self.assertEqual(perm(n, n), factorial(n))
+
+ # Raises TypeError if any argument is non-integer or argument count is
+ # not 2
+ self.assertRaises(TypeError, perm, 10, 1.0)
+ self.assertRaises(TypeError, perm, 10, Decimal(1.0))
+ self.assertRaises(TypeError, perm, 10, Fraction(1, 1))
+ self.assertRaises(TypeError, perm, 10, "1")
+ self.assertRaises(TypeError, perm, 10.0, 1)
+ self.assertRaises(TypeError, perm, Decimal(10.0), 1)
+ self.assertRaises(TypeError, perm, Fraction(10, 1), 1)
+ self.assertRaises(TypeError, perm, "10", 1)
+
+ self.assertRaises(TypeError, perm, 10)
+ self.assertRaises(TypeError, perm, 10, 1, 3)
+ self.assertRaises(TypeError, perm)
+
+ # Raises Value error if not k or n are negative numbers
+ self.assertRaises(ValueError, perm, -1, 1)
+ self.assertRaises(ValueError, perm, -2**1000, 1)
+ self.assertRaises(ValueError, perm, 1, -1)
+ self.assertRaises(ValueError, perm, 1, -2**1000)
+
+ # Raises value error if k is greater than n
+ self.assertRaises(ValueError, perm, 1, 2)
+ self.assertRaises(ValueError, perm, 1, 2**1000)
+
+ n = 2**1000
+ self.assertEqual(perm(n, 0), 1)
+ self.assertEqual(perm(n, 1), n)
+ self.assertEqual(perm(n, 2), n * (n-1))
+ if self.module is c_imath:
+ self.assertRaises((OverflowError, MemoryError), perm, n, n)
+
+ for n, k in (True, True), (True, False), (False, False):
+ self.assertIntEqual(perm(n, k), 1)
+ self.assertEqual(perm(IntSubclass(5), IntSubclass(2)), 20)
+ self.assertEqual(perm(MyIndexable(5), MyIndexable(2)), 20)
+ for k in range(3):
+ self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int)
+ self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int)
+
+ def testComb(self):
+ comb = self.module.comb
+ factorial = self.module.factorial
+ # Test if factorial defintion is satisfied
+ for n in range(100):
+ for k in range(n + 1):
+ self.assertEqual(comb(n, k), factorial(n)
+ // (factorial(k) * factorial(n - k)))
+
+ # Test for Pascal's identity
+ for n in range(1, 100):
+ for k in range(1, n):
+ self.assertEqual(comb(n, k), comb(n - 1, k - 1) + comb(n - 1, k))
+
+ # Test corner cases
+ for n in range(100):
+ self.assertEqual(comb(n, 0), 1)
+ self.assertEqual(comb(n, n), 1)
+
+ for n in range(1, 100):
+ self.assertEqual(comb(n, 1), n)
+ self.assertEqual(comb(n, n - 1), n)
+
+ # Test Symmetry
+ for n in range(100):
+ for k in range(n // 2):
+ self.assertEqual(comb(n, k), comb(n, n - k))
+
+ # Raises TypeError if any argument is non-integer or argument count is
+ # not 2
+ self.assertRaises(TypeError, comb, 10, 1.0)
+ self.assertRaises(TypeError, comb, 10, Decimal(1.0))
+ self.assertRaises(TypeError, comb, 10, Fraction(1, 1))
+ self.assertRaises(TypeError, comb, 10, "1")
+ self.assertRaises(TypeError, comb, 10.0, 1)
+ self.assertRaises(TypeError, comb, Fraction(10, 1), 1)
+ self.assertRaises(TypeError, comb, "10", 1)
+
+ self.assertRaises(TypeError, comb, 10)
+ self.assertRaises(TypeError, comb, 10, 1, 3)
+ self.assertRaises(TypeError, comb)
+
+ # Raises Value error if not k or n are negative numbers
+ self.assertRaises(ValueError, comb, -1, 1)
+ self.assertRaises(ValueError, comb, -2**1000, 1)
+ self.assertRaises(ValueError, comb, 1, -1)
+ self.assertRaises(ValueError, comb, 1, -2**1000)
+
+ # Raises value error if k is greater than n
+ self.assertRaises(ValueError, comb, 1, 2)
+ self.assertRaises(ValueError, comb, 1, 2**1000)
+
+ n = 2**1000
+ self.assertEqual(comb(n, 0), 1)
+ self.assertEqual(comb(n, 1), n)
+ self.assertEqual(comb(n, 2), n * (n-1) // 2)
+ self.assertEqual(comb(n, n), 1)
+ self.assertEqual(comb(n, n-1), n)
+ self.assertEqual(comb(n, n-2), n * (n-1) // 2)
+ if self.module is c_imath:
+ self.assertRaises((OverflowError, MemoryError), comb, n, n//2)
+
+ for n, k in (True, True), (True, False), (False, False):
+ self.assertIntEqual(comb(n, k), 1)
+ self.assertIntEqual(comb(IntSubclass(5), IntSubclass(2)), 10)
+ self.assertIntEqual(comb(MyIndexable(5), MyIndexable(2)), 10)
+ for k in range(3):
+ self.assertIs(type(comb(IntSubclass(5), IntSubclass(k))), int)
+ self.assertIs(type(comb(MyIndexable(5), MyIndexable(k))), int)
+
+
+class PyIMathTests(IMathTests, unittest.TestCase):
+ module = py_imath
+
+@unittest.skipUnless(c_imath, 'requires _imath')
+class CIMathTests(IMathTests, unittest.TestCase):
+ module = c_imath
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 96e0cf2fe67197..b0532d0bbac925 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -3,6 +3,7 @@
from test.support import run_unittest, verbose, requires_IEEE_754
from test import support
+from .test_imath import py_factorial
import unittest
import itertools
import decimal
@@ -77,56 +78,6 @@ def ulp(x):
else:
return x_next - x
-# Here's a pure Python version of the math.factorial algorithm, for
-# documentation and comparison purposes.
-#
-# Formula:
-#
-# factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n))
-#
-# where
-#
-# factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j
-#
-# The outer product above is an infinite product, but once i >= n.bit_length,
-# (n >> i) < 1 and the corresponding term of the product is empty. So only the
-# finitely many terms for 0 <= i < n.bit_length() contribute anything.
-#
-# We iterate downwards from i == n.bit_length() - 1 to i == 0. The inner
-# product in the formula above starts at 1 for i == n.bit_length(); for each i
-# < n.bit_length() we get the inner product for i from that for i + 1 by
-# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}. In Python terms,
-# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2).
-
-def count_set_bits(n):
- """Number of '1' bits in binary expansion of a nonnnegative integer."""
- return 1 + count_set_bits(n & n - 1) if n else 0
-
-def partial_product(start, stop):
- """Product of integers in range(start, stop, 2), computed recursively.
- start and stop should both be odd, with start <= stop.
-
- """
- numfactors = (stop - start) >> 1
- if not numfactors:
- return 1
- elif numfactors == 1:
- return start
- else:
- mid = (start + numfactors) | 1
- return partial_product(start, mid) * partial_product(mid, stop)
-
-def py_factorial(n):
- """Factorial of nonnegative integer n, via "Binary Split Factorial Formula"
- described at http://www.luschny.de/math/factorial/binarysplitfact.html
-
- """
- inner = outer = 1
- for i in reversed(range(n.bit_length())):
- inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1)
- outer *= inner
- return outer << (n - count_set_bits(n))
-
def ulp_abs_check(expected, got, ulp_tol, abs_tol):
"""Given finite floats `expected` and `got`, check that they're
approximately equal to within the given number of ulps or the
@@ -240,9 +191,6 @@ def result_check(expected, got, ulp_tol=5, abs_tol=0.0):
else:
return None
-class IntSubclass(int):
- pass
-
# Class providing an __index__ method.
class MyIndexable(object):
def __init__(self, value):
@@ -915,59 +863,6 @@ class T(tuple):
self.assertEqual(math.dist(p, q), 5*scale)
self.assertEqual(math.dist(q, p), 5*scale)
- def testIsqrt(self):
- # Test a variety of inputs, large and small.
- test_values = (
- list(range(1000))
- + list(range(10**6 - 1000, 10**6 + 1000))
- + [2**e + i for e in range(60, 200) for i in range(-40, 40)]
- + [3**9999, 10**5001]
- )
-
- for value in test_values:
- with self.subTest(value=value):
- s = math.isqrt(value)
- self.assertIs(type(s), int)
- self.assertLessEqual(s*s, value)
- self.assertLess(value, (s+1)*(s+1))
-
- # Negative values
- with self.assertRaises(ValueError):
- math.isqrt(-1)
-
- # Integer-like things
- s = math.isqrt(True)
- self.assertIs(type(s), int)
- self.assertEqual(s, 1)
-
- s = math.isqrt(False)
- self.assertIs(type(s), int)
- self.assertEqual(s, 0)
-
- class IntegerLike(object):
- def __init__(self, value):
- self.value = value
-
- def __index__(self):
- return self.value
-
- s = math.isqrt(IntegerLike(1729))
- self.assertIs(type(s), int)
- self.assertEqual(s, 41)
-
- with self.assertRaises(ValueError):
- math.isqrt(IntegerLike(-3))
-
- # Non-integer-like things
- bad_values = [
- 3.5, "a string", decimal.Decimal("3.5"), 3.5j,
- 100.0, -4.0,
- ]
- for value in bad_values:
- with self.subTest(value=value):
- with self.assertRaises(TypeError):
- math.isqrt(value)
-
def testLdexp(self):
self.assertRaises(TypeError, math.ldexp)
self.ftest('ldexp(0,1)', math.ldexp(0,1), 0)
@@ -1865,133 +1760,6 @@ def test_fractions(self):
self.assertAllClose(fraction_examples, rel_tol=1e-8)
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
- def testPerm(self):
- perm = math.perm
- factorial = math.factorial
- # Test if factorial defintion is satisfied
- for n in range(100):
- for k in range(n + 1):
- self.assertEqual(perm(n, k),
- factorial(n) // factorial(n - k))
-
- # Test for Pascal's identity
- for n in range(1, 100):
- for k in range(1, n):
- self.assertEqual(perm(n, k), perm(n - 1, k - 1) * k + perm(n - 1, k))
-
- # Test corner cases
- for n in range(1, 100):
- self.assertEqual(perm(n, 0), 1)
- self.assertEqual(perm(n, 1), n)
- self.assertEqual(perm(n, n), factorial(n))
-
- # Raises TypeError if any argument is non-integer or argument count is
- # not 2
- self.assertRaises(TypeError, perm, 10, 1.0)
- self.assertRaises(TypeError, perm, 10, decimal.Decimal(1.0))
- self.assertRaises(TypeError, perm, 10, "1")
- self.assertRaises(TypeError, perm, 10.0, 1)
- self.assertRaises(TypeError, perm, decimal.Decimal(10.0), 1)
- self.assertRaises(TypeError, perm, "10", 1)
-
- self.assertRaises(TypeError, perm, 10)
- self.assertRaises(TypeError, perm, 10, 1, 3)
- self.assertRaises(TypeError, perm)
-
- # Raises Value error if not k or n are negative numbers
- self.assertRaises(ValueError, perm, -1, 1)
- self.assertRaises(ValueError, perm, -2**1000, 1)
- self.assertRaises(ValueError, perm, 1, -1)
- self.assertRaises(ValueError, perm, 1, -2**1000)
-
- # Raises value error if k is greater than n
- self.assertRaises(ValueError, perm, 1, 2)
- self.assertRaises(ValueError, perm, 1, 2**1000)
-
- n = 2**1000
- self.assertEqual(perm(n, 0), 1)
- self.assertEqual(perm(n, 1), n)
- self.assertEqual(perm(n, 2), n * (n-1))
- self.assertRaises((OverflowError, MemoryError), perm, n, n)
-
- for n, k in (True, True), (True, False), (False, False):
- self.assertEqual(perm(n, k), 1)
- self.assertIs(type(perm(n, k)), int)
- self.assertEqual(perm(IntSubclass(5), IntSubclass(2)), 20)
- self.assertEqual(perm(MyIndexable(5), MyIndexable(2)), 20)
- for k in range(3):
- self.assertIs(type(perm(IntSubclass(5), IntSubclass(k))), int)
- self.assertIs(type(perm(MyIndexable(5), MyIndexable(k))), int)
-
- def testComb(self):
- comb = math.comb
- factorial = math.factorial
- # Test if factorial defintion is satisfied
- for n in range(100):
- for k in range(n + 1):
- self.assertEqual(comb(n, k), factorial(n)
- // (factorial(k) * factorial(n - k)))
-
- # Test for Pascal's identity
- for n in range(1, 100):
- for k in range(1, n):
- self.assertEqual(comb(n, k), comb(n - 1, k - 1) + comb(n - 1, k))
-
- # Test corner cases
- for n in range(100):
- self.assertEqual(comb(n, 0), 1)
- self.assertEqual(comb(n, n), 1)
-
- for n in range(1, 100):
- self.assertEqual(comb(n, 1), n)
- self.assertEqual(comb(n, n - 1), n)
-
- # Test Symmetry
- for n in range(100):
- for k in range(n // 2):
- self.assertEqual(comb(n, k), comb(n, n - k))
-
- # Raises TypeError if any argument is non-integer or argument count is
- # not 2
- self.assertRaises(TypeError, comb, 10, 1.0)
- self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0))
- self.assertRaises(TypeError, comb, 10, "1")
- self.assertRaises(TypeError, comb, 10.0, 1)
- self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1)
- self.assertRaises(TypeError, comb, "10", 1)
-
- self.assertRaises(TypeError, comb, 10)
- self.assertRaises(TypeError, comb, 10, 1, 3)
- self.assertRaises(TypeError, comb)
-
- # Raises Value error if not k or n are negative numbers
- self.assertRaises(ValueError, comb, -1, 1)
- self.assertRaises(ValueError, comb, -2**1000, 1)
- self.assertRaises(ValueError, comb, 1, -1)
- self.assertRaises(ValueError, comb, 1, -2**1000)
-
- # Raises value error if k is greater than n
- self.assertRaises(ValueError, comb, 1, 2)
- self.assertRaises(ValueError, comb, 1, 2**1000)
-
- n = 2**1000
- self.assertEqual(comb(n, 0), 1)
- self.assertEqual(comb(n, 1), n)
- self.assertEqual(comb(n, 2), n * (n-1) // 2)
- self.assertEqual(comb(n, n), 1)
- self.assertEqual(comb(n, n-1), n)
- self.assertEqual(comb(n, n-2), n * (n-1) // 2)
- self.assertRaises((OverflowError, MemoryError), comb, n, n//2)
-
- for n, k in (True, True), (True, False), (False, False):
- self.assertEqual(comb(n, k), 1)
- self.assertIs(type(comb(n, k)), int)
- self.assertEqual(comb(IntSubclass(5), IntSubclass(2)), 10)
- self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10)
- for k in range(3):
- self.assertIs(type(comb(IntSubclass(5), IntSubclass(k))), int)
- self.assertIs(type(comb(MyIndexable(5), MyIndexable(k))), int)
-
def test_main():
from doctest import DocFileSuite
diff --git a/Misc/NEWS.d/next/Library/2019-06-02-13-56-16.bpo-37132.axawSH.rst b/Misc/NEWS.d/next/Library/2019-06-02-13-56-16.bpo-37132.axawSH.rst
new file mode 100644
index 00000000000000..e41edf69ef1cd0
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-06-02-13-56-16.bpo-37132.axawSH.rst
@@ -0,0 +1 @@
+Add the :mod:`imath` module.
diff --git a/Modules/Setup b/Modules/Setup
index e729ab883f410b..316f5c43500da9 100644
--- a/Modules/Setup
+++ b/Modules/Setup
@@ -169,6 +169,7 @@ _symtable symtablemodule.c
#array arraymodule.c # array objects
#cmath cmathmodule.c _math.c # -lm # complex math library functions
#math mathmodule.c _math.c # -lm # math library functions, e.g. sin()
+#_imath _imathmodule.c # integer math library functions, e.g. isqrt()
#_contextvars _contextvarsmodule.c # Context Variables
#_struct _struct.c # binary structure packing/unpacking
#_weakref _weakref.c # basic weak reference support
diff --git a/Modules/_imathmodule.c b/Modules/_imathmodule.c
new file mode 100644
index 00000000000000..13f872ca149866
--- /dev/null
+++ b/Modules/_imathmodule.c
@@ -0,0 +1,943 @@
+/* imath module -- integer-related mathematical functions */
+
+#include "Python.h"
+
+#include "clinic/_imathmodule.c.h"
+
+/*[clinic input]
+module imath
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=0b5ec353335010dd]*/
+
+
+/*[clinic input]
+imath.gcd
+
+ x: object
+ y: object
+ /
+
+Greatest common divisor of x and y.
+[clinic start generated code]*/
+
+static PyObject *
+imath_gcd_impl(PyObject *module, PyObject *x, PyObject *y)
+/*[clinic end generated code: output=14eee3e4a3bd7a1d input=11612898ad79c57c]*/
+{
+ PyObject *result;
+
+ x = PyNumber_Index(x);
+ if (x == NULL)
+ return NULL;
+ y = PyNumber_Index(y);
+ if (y == NULL) {
+ Py_DECREF(x);
+ return NULL;
+ }
+ result = _PyLong_GCD(x, y);
+ Py_DECREF(x);
+ Py_DECREF(y);
+ return result;
+}
+
+
+/*[clinic input]
+imath.ilog2
+
+ n: object
+ /
+
+Return the integer part of the base 2 logarithm of the input.
+[clinic start generated code]*/
+
+static PyObject *
+imath_ilog2(PyObject *module, PyObject *n)
+/*[clinic end generated code: output=6ab48d1a7f5160c2 input=e2d8e8631ec5c29b]*/
+{
+ size_t bits;
+
+ n = PyNumber_Index(n);
+ if (n == NULL) {
+ return NULL;
+ }
+
+ if (Py_SIZE(n) <= 0) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ "ilog() argument must be positive");
+ Py_DECREF(n);
+ return NULL;
+ }
+
+ bits = _PyLong_NumBits(n);
+ Py_DECREF(n);
+ if (bits == (size_t)(-1)) {
+ return NULL;
+ }
+ return PyLong_FromSize_t(bits - 1);
+}
+
+
+/* Integer square root
+
+Given a nonnegative integer `n`, we want to compute the largest integer
+`a` for which `a * a <= n`, or equivalently the integer part of the exact
+square root of `n`.
+
+We use an adaptive-precision pure-integer version of Newton's iteration. Given
+a positive integer `n`, the algorithm produces at each iteration an integer
+approximation `a` to the square root of `n >> s` for some even integer `s`,
+with `s` decreasing as the iterations progress. On the final iteration, `s` is
+zero and we have an approximation to the square root of `n` itself.
+
+At every step, the approximation `a` is strictly within 1.0 of the true square
+root, so we have
+
+ (a - 1)**2 < (n >> s) < (a + 1)**2
+
+After the final iteration, a check-and-correct step is needed to determine
+whether `a` or `a - 1` gives the desired integer square root of `n`.
+
+The algorithm is remarkable in its simplicity. There's no need for a
+per-iteration check-and-correct step, and termination is straightforward: the
+number of iterations is known in advance (it's exactly `floor(log2(log2(n)))`
+for `n > 1`). The only tricky part of the correctness proof is in establishing
+that the bound `(a - 1)**2 < (n >> s) < (a + 1)**2` is maintained from one
+iteration to the next. A sketch of the proof of this is given below.
+
+In addition to the proof sketch, a formal, computer-verified proof
+of correctness (using Lean) of an equivalent recursive algorithm can be found
+here:
+
+ https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
+
+
+Here's Python code equivalent to the C implementation below:
+
+ def isqrt(n):
+ """
+ Return the integer part of the square root of the input.
+ """
+ n = operator.index(n)
+
+ if n < 0:
+ raise ValueError("isqrt() argument must be nonnegative")
+ if n == 0:
+ return 0
+
+ c = (n.bit_length() - 1) // 2
+ a = 1
+ d = 0
+ for s in reversed(range(c.bit_length())):
+ e = d
+ d = c >> s
+ a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
+ assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2
+
+ return a - (a*a > n)
+
+
+Sketch of proof of correctness
+------------------------------
+
+The delicate part of the correctness proof is showing that the loop invariant
+is preserved from one iteration to the next. That is, just before the line
+
+ a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
+
+is executed in the above code, we know that
+
+ (1) (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2.
+
+(since `e` is always the value of `d` from the previous iteration). We must
+prove that after that line is executed, we have
+
+ (a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2
+
+To faciliate the proof, we make some changes of notation. Write `m` for
+`n >> 2*(c-d)`, and write `b` for the new value of `a`, so
+
+ b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
+
+or equivalently:
+
+ (2) b = (a << d - e - 1) + (m >> d - e + 1) // a
+
+Then we can rewrite (1) as:
+
+ (3) (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2
+
+and we must show that (b - 1)**2 < m < (b + 1)**2.
+
+From this point on, we switch to mathematical notation, so `/` means exact
+division rather than integer division and `^` is used for exponentiation. We
+use the `√` symbol for the exact square root. In (3), we can remove the
+implicit floor operation to give:
+
+ (4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2
+
+Taking square roots throughout (4), scaling by `2^(d-e)`, and rearranging gives
+
+ (5) 0 <= | 2^(d-e)a - √m | < 2^(d-e)
+
+Squaring and dividing through by `2^(d-e+1) a` gives
+
+ (6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a
+
+We'll show below that `2^(d-e-1) <= a`. Given that, we can replace the
+right-hand side of (6) with `1`, and now replacing the central
+term `m / (2^(d-e+1) a)` with its floor in (6) gives
+
+ (7) -1 < 2^(d-e-1) a + m // 2^(d-e+1) a - √m < 1
+
+Or equivalently, from (2):
+
+ (7) -1 < b - √m < 1
+
+and rearranging gives that `(b-1)^2 < m < (b+1)^2`, which is what we needed
+to prove.
+
+We're not quite done: we still have to prove the inequality `2^(d - e - 1) <=
+a` that was used to get line (7) above. From the definition of `c`, we have
+`4^c <= n`, which implies
+
+ (8) 4^d <= m
+
+also, since `e == d >> 1`, `d` is at most `2e + 1`, from which it follows
+that `2d - 2e - 1 <= d` and hence that
+
+ (9) 4^(2d - 2e - 1) <= m
+
+Dividing both sides by `4^(d - e)` gives
+
+ (10) 4^(d - e - 1) <= m / 4^(d - e)
+
+But we know from (4) that `m / 4^(d-e) < (a + 1)^2`, hence
+
+ (11) 4^(d - e - 1) < (a + 1)^2
+
+Now taking square roots of both sides and observing that both `2^(d-e-1)` and
+`a` are integers gives `2^(d - e - 1) <= a`, which is what we needed. This
+completes the proof sketch.
+
+*/
+
+
+/* Approximate square root of a large 64-bit integer.
+
+ Given `n` satisfying `2**62 <= n < 2**64`, return `a`
+ satisfying `(a - 1)**2 < n < (a + 1)**2`. */
+
+static uint64_t
+_approximate_isqrt(uint64_t n)
+{
+ uint32_t u = 1U + (n >> 62);
+ u = (u << 1) + (n >> 59) / u;
+ u = (u << 3) + (n >> 53) / u;
+ u = (u << 7) + (n >> 41) / u;
+ return (u << 15) + (n >> 17) / u;
+}
+
+/*[clinic input]
+imath.isqrt
+
+ n: object
+ /
+
+Return the integer part of the square root of the input.
+[clinic start generated code]*/
+
+static PyObject *
+imath_isqrt(PyObject *module, PyObject *n)
+/*[clinic end generated code: output=5ad16a80dd47c888 input=64f5b3e586986fc9]*/
+{
+ int a_too_large, c_bit_length;
+ size_t c, d;
+ uint64_t m, u;
+ PyObject *a = NULL, *b;
+
+ n = PyNumber_Index(n);
+ if (n == NULL) {
+ return NULL;
+ }
+
+ if (_PyLong_Sign(n) < 0) {
+ PyErr_SetString(
+ PyExc_ValueError,
+ "isqrt() argument must be nonnegative");
+ goto error;
+ }
+ if (_PyLong_Sign(n) == 0) {
+ Py_DECREF(n);
+ return PyLong_FromLong(0);
+ }
+
+ /* c = (n.bit_length() - 1) // 2 */
+ c = _PyLong_NumBits(n);
+ if (c == (size_t)(-1)) {
+ goto error;
+ }
+ c = (c - 1U) / 2U;
+
+ /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a
+ fast, almost branch-free algorithm. In the final correction, we use `u*u
+ - 1 >= m` instead of the simpler `u*u > m` in order to get the correct
+ result in the corner case where `u=2**32`. */
+ if (c <= 31U) {
+ m = (uint64_t)PyLong_AsUnsignedLongLong(n);
+ Py_DECREF(n);
+ if (m == (uint64_t)(-1) && PyErr_Occurred()) {
+ return NULL;
+ }
+ u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c);
+ u -= u * u - 1U >= m;
+ return PyLong_FromUnsignedLongLong((unsigned long long)u);
+ }
+
+ /* Slow path: n >= 2**64. We perform the first five iterations in C integer
+ arithmetic, then switch to using Python long integers. */
+
+ /* From n >= 2**64 it follows that c.bit_length() >= 6. */
+ c_bit_length = 6;
+ while ((c >> c_bit_length) > 0U) {
+ ++c_bit_length;
+ }
+
+ /* Initialise d and a. */
+ d = c >> (c_bit_length - 5);
+ b = _PyLong_Rshift(n, 2U*c - 62U);
+ if (b == NULL) {
+ goto error;
+ }
+ m = (uint64_t)PyLong_AsUnsignedLongLong(b);
+ Py_DECREF(b);
+ if (m == (uint64_t)(-1) && PyErr_Occurred()) {
+ goto error;
+ }
+ u = _approximate_isqrt(m) >> (31U - d);
+ a = PyLong_FromUnsignedLongLong((unsigned long long)u);
+ if (a == NULL) {
+ goto error;
+ }
+
+ for (int s = c_bit_length - 6; s >= 0; --s) {
+ PyObject *q;
+ size_t e = d;
+
+ d = c >> s;
+
+ /* q = (n >> 2*c - e - d + 1) // a */
+ q = _PyLong_Rshift(n, 2U*c - d - e + 1U);
+ if (q == NULL) {
+ goto error;
+ }
+ Py_SETREF(q, PyNumber_FloorDivide(q, a));
+ if (q == NULL) {
+ goto error;
+ }
+
+ /* a = (a << d - 1 - e) + q */
+ Py_SETREF(a, _PyLong_Lshift(a, d - 1U - e));
+ if (a == NULL) {
+ Py_DECREF(q);
+ goto error;
+ }
+ Py_SETREF(a, PyNumber_Add(a, q));
+ Py_DECREF(q);
+ if (a == NULL) {
+ goto error;
+ }
+ }
+
+ /* The correct result is either a or a - 1. Figure out which, and
+ decrement a if necessary. */
+
+ /* a_too_large = n < a * a */
+ b = PyNumber_Multiply(a, a);
+ if (b == NULL) {
+ goto error;
+ }
+ a_too_large = PyObject_RichCompareBool(n, b, Py_LT);
+ Py_DECREF(b);
+ if (a_too_large == -1) {
+ goto error;
+ }
+
+ if (a_too_large) {
+ Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One));
+ }
+ Py_DECREF(n);
+ return a;
+
+ error:
+ Py_XDECREF(a);
+ Py_DECREF(n);
+ return NULL;
+}
+
+
+/* Return the smallest integer k such that n < 2**k, or 0 if n == 0.
+ * Equivalent to floor(lg(x))+1. Also equivalent to: bitwidth_of_type -
+ * count_leading_zero_bits(x)
+ */
+
+/* XXX: This routine does more or less the same thing as
+ * bits_in_digit() in Objects/longobject.c. Someday it would be nice to
+ * consolidate them. On BSD, there's a library function called fls()
+ * that we could use, and GCC provides __builtin_clz().
+ */
+
+static unsigned long
+bit_length(unsigned long n)
+{
+ unsigned long len = 0;
+ while (n != 0) {
+ ++len;
+ n >>= 1;
+ }
+ return len;
+}
+
+static unsigned long
+count_set_bits(unsigned long n)
+{
+ unsigned long count = 0;
+ while (n != 0) {
+ ++count;
+ n &= n - 1; /* clear least significant bit */
+ }
+ return count;
+}
+
+
+/* Divide-and-conquer factorial algorithm
+ *
+ * Based on the formula and pseudo-code provided at:
+ * http://www.luschny.de/math/factorial/binarysplitfact.html
+ *
+ * Faster algorithms exist, but they're more complicated and depend on
+ * a fast prime factorization algorithm.
+ *
+ * Notes on the algorithm
+ * ----------------------
+ *
+ * factorial(n) is written in the form 2**k * m, with m odd. k and m are
+ * computed separately, and then combined using a left shift.
+ *
+ * The function factorial_odd_part computes the odd part m (i.e., the greatest
+ * odd divisor) of factorial(n), using the formula:
+ *
+ * factorial_odd_part(n) =
+ *
+ * product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j
+ *
+ * Example: factorial_odd_part(20) =
+ *
+ * (1) *
+ * (1) *
+ * (1 * 3 * 5) *
+ * (1 * 3 * 5 * 7 * 9)
+ * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
+ *
+ * Here i goes from large to small: the first term corresponds to i=4 (any
+ * larger i gives an empty product), and the last term corresponds to i=0.
+ * Each term can be computed from the last by multiplying by the extra odd
+ * numbers required: e.g., to get from the penultimate term to the last one,
+ * we multiply by (11 * 13 * 15 * 17 * 19).
+ *
+ * To see a hint of why this formula works, here are the same numbers as above
+ * but with the even parts (i.e., the appropriate powers of 2) included. For
+ * each subterm in the product for i, we multiply that subterm by 2**i:
+ *
+ * factorial(20) =
+ *
+ * (16) *
+ * (8) *
+ * (4 * 12 * 20) *
+ * (2 * 6 * 10 * 14 * 18) *
+ * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
+ *
+ * The factorial_partial_product function computes the product of all odd j in
+ * range(start, stop) for given start and stop. It's used to compute the
+ * partial products like (11 * 13 * 15 * 17 * 19) in the example above. It
+ * operates recursively, repeatedly splitting the range into two roughly equal
+ * pieces until the subranges are small enough to be computed using only C
+ * integer arithmetic.
+ *
+ * The two-valuation k (i.e., the exponent of the largest power of 2 dividing
+ * the factorial) is computed independently in the main math_factorial
+ * function. By standard results, its value is:
+ *
+ * two_valuation = n//2 + n//4 + n//8 + ....
+ *
+ * It can be shown (e.g., by complete induction on n) that two_valuation is
+ * equal to n - count_set_bits(n), where count_set_bits(n) gives the number of
+ * '1'-bits in the binary expansion of n.
+ */
+
+/* factorial_partial_product: Compute product(range(start, stop, 2)) using
+ * divide and conquer. Assumes start and stop are odd and stop > start.
+ * max_bits must be >= bit_length(stop - 2). */
+
+static PyObject *
+factorial_partial_product(unsigned long start, unsigned long stop,
+ unsigned long max_bits)
+{
+ unsigned long midpoint, num_operands;
+ PyObject *left = NULL, *right = NULL, *result = NULL;
+
+ /* If the return value will fit an unsigned long, then we can
+ * multiply in a tight, fast loop where each multiply is O(1).
+ * Compute an upper bound on the number of bits required to store
+ * the answer.
+ *
+ * Storing some integer z requires floor(lg(z))+1 bits, which is
+ * conveniently the value returned by bit_length(z). The
+ * product x*y will require at most
+ * bit_length(x) + bit_length(y) bits to store, based
+ * on the idea that lg product = lg x + lg y.
+ *
+ * We know that stop - 2 is the largest number to be multiplied. From
+ * there, we have: bit_length(answer) <= num_operands *
+ * bit_length(stop - 2)
+ */
+
+ num_operands = (stop - start) / 2;
+ /* The "num_operands <= 8 * SIZEOF_LONG" check guards against the
+ * unlikely case of an overflow in num_operands * max_bits. */
+ if (num_operands <= 8 * SIZEOF_LONG &&
+ num_operands * max_bits <= 8 * SIZEOF_LONG) {
+ unsigned long j, total;
+ for (total = start, j = start + 2; j < stop; j += 2)
+ total *= j;
+ return PyLong_FromUnsignedLong(total);
+ }
+
+ /* find midpoint of range(start, stop), rounded up to next odd number. */
+ midpoint = (start + num_operands) | 1;
+ left = factorial_partial_product(start, midpoint,
+ bit_length(midpoint - 2));
+ if (left == NULL)
+ goto error;
+ right = factorial_partial_product(midpoint, stop, max_bits);
+ if (right == NULL)
+ goto error;
+ result = PyNumber_Multiply(left, right);
+
+ error:
+ Py_XDECREF(left);
+ Py_XDECREF(right);
+ return result;
+}
+
+/* factorial_odd_part: compute the odd part of factorial(n). */
+
+static PyObject *
+factorial_odd_part(unsigned long n)
+{
+ long i;
+ unsigned long v, lower, upper;
+ PyObject *partial, *tmp, *inner, *outer;
+
+ inner = PyLong_FromLong(1);
+ if (inner == NULL)
+ return NULL;
+ outer = inner;
+ Py_INCREF(outer);
+
+ upper = 3;
+ for (i = bit_length(n) - 2; i >= 0; i--) {
+ v = n >> i;
+ if (v <= 2)
+ continue;
+ lower = upper;
+ /* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */
+ upper = (v + 1) | 1;
+ /* Here inner is the product of all odd integers j in the range (0,
+ n/2**(i+1)]. The factorial_partial_product call below gives the
+ product of all odd integers j in the range (n/2**(i+1), n/2**i]. */
+ partial = factorial_partial_product(lower, upper, bit_length(upper-2));
+ /* inner *= partial */
+ if (partial == NULL)
+ goto error;
+ tmp = PyNumber_Multiply(inner, partial);
+ Py_DECREF(partial);
+ if (tmp == NULL)
+ goto error;
+ Py_DECREF(inner);
+ inner = tmp;
+ /* Now inner is the product of all odd integers j in the range (0,
+ n/2**i], giving the inner product in the formula above. */
+
+ /* outer *= inner; */
+ tmp = PyNumber_Multiply(outer, inner);
+ if (tmp == NULL)
+ goto error;
+ Py_DECREF(outer);
+ outer = tmp;
+ }
+ Py_DECREF(inner);
+ return outer;
+
+ error:
+ Py_DECREF(outer);
+ Py_DECREF(inner);
+ return NULL;
+}
+
+
+/* Lookup table for small factorial values */
+
+static const unsigned long SmallFactorials[] = {
+ 1, 1, 2, 6, 24, 120, 720, 5040, 40320,
+ 362880, 3628800, 39916800, 479001600,
+#if SIZEOF_LONG >= 8
+ 6227020800, 87178291200, 1307674368000,
+ 20922789888000, 355687428096000, 6402373705728000,
+ 121645100408832000, 2432902008176640000
+#endif
+};
+
+/*[clinic input]
+imath.factorial
+
+ x as arg: object
+ /
+
+Find x!.
+
+Raise a TypeError if x is not an integer.
+Raise a ValueError if x is negative integer.
+[clinic start generated code]*/
+
+static PyObject *
+imath_factorial(PyObject *module, PyObject *arg)
+/*[clinic end generated code: output=73f1879dcbd64aea input=d5f41d496efcaf51]*/
+{
+ long x, two_valuation;
+ int overflow;
+ PyObject *result, *odd_part, *pyint_form;
+
+ pyint_form = PyNumber_Index(arg);
+ if (pyint_form == NULL) {
+ return NULL;
+ }
+ x = PyLong_AsLongAndOverflow(pyint_form, &overflow);
+ Py_DECREF(pyint_form);
+ if (x == -1 && PyErr_Occurred()) {
+ return NULL;
+ }
+ else if (overflow == 1) {
+ PyErr_Format(PyExc_OverflowError,
+ "factorial() argument should not exceed %ld",
+ LONG_MAX);
+ return NULL;
+ }
+ else if (overflow == -1 || x < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "factorial() not defined for negative values");
+ return NULL;
+ }
+
+ /* use lookup table if x is small */
+ if (x < (long)Py_ARRAY_LENGTH(SmallFactorials))
+ return PyLong_FromUnsignedLong(SmallFactorials[x]);
+
+ /* else express in the form odd_part * 2**two_valuation, and compute as
+ odd_part << two_valuation. */
+ odd_part = factorial_odd_part(x);
+ if (odd_part == NULL)
+ return NULL;
+ two_valuation = x - count_set_bits(x);
+ result = _PyLong_Lshift(odd_part, two_valuation);
+ Py_DECREF(odd_part);
+ return result;
+}
+
+
+/*[clinic input]
+imath.perm
+
+ n: object
+ k: object
+ /
+
+Number of ways to choose k items from n items without repetition and with order.
+
+It is mathematically equal to the expression n! / (n - k)!.
+
+Raises TypeError if the arguments are not integers.
+Raises ValueError if the arguments are negative or if k > n.
+[clinic start generated code]*/
+
+static PyObject *
+imath_perm_impl(PyObject *module, PyObject *n, PyObject *k)
+/*[clinic end generated code: output=4b4f3aa22c47911e input=24aa606dab86900c]*/
+{
+ PyObject *result = NULL, *factor = NULL;
+ int overflow, cmp;
+ long long i, factors;
+
+ n = PyNumber_Index(n);
+ if (n == NULL) {
+ return NULL;
+ }
+ if (!PyLong_CheckExact(n)) {
+ Py_SETREF(n, _PyLong_Copy((PyLongObject *)n));
+ if (n == NULL) {
+ return NULL;
+ }
+ }
+ k = PyNumber_Index(k);
+ if (k == NULL) {
+ Py_DECREF(n);
+ return NULL;
+ }
+ if (!PyLong_CheckExact(k)) {
+ Py_SETREF(k, _PyLong_Copy((PyLongObject *)k));
+ if (k == NULL) {
+ Py_DECREF(n);
+ return NULL;
+ }
+ }
+
+ if (Py_SIZE(n) < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "n must be a non-negative integer");
+ goto error;
+ }
+ cmp = PyObject_RichCompareBool(n, k, Py_LT);
+ if (cmp != 0) {
+ if (cmp > 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "k must be an integer less than or equal to n");
+ }
+ goto error;
+ }
+
+ factors = PyLong_AsLongLongAndOverflow(k, &overflow);
+ if (overflow > 0) {
+ PyErr_Format(PyExc_OverflowError,
+ "k must not exceed %lld",
+ LLONG_MAX);
+ goto error;
+ }
+ else if (overflow < 0 || factors < 0) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_ValueError,
+ "k must be a non-negative integer");
+ }
+ goto error;
+ }
+
+ if (factors == 0) {
+ result = PyLong_FromLong(1);
+ goto done;
+ }
+
+ result = n;
+ Py_INCREF(result);
+ if (factors == 1) {
+ goto done;
+ }
+
+ factor = n;
+ Py_INCREF(factor);
+ for (i = 1; i < factors; ++i) {
+ Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
+ if (factor == NULL) {
+ goto error;
+ }
+ Py_SETREF(result, PyNumber_Multiply(result, factor));
+ if (result == NULL) {
+ goto error;
+ }
+ }
+ Py_DECREF(factor);
+
+done:
+ Py_DECREF(n);
+ Py_DECREF(k);
+ return result;
+
+error:
+ Py_XDECREF(factor);
+ Py_XDECREF(result);
+ Py_DECREF(n);
+ Py_DECREF(k);
+ return NULL;
+}
+
+
+/*[clinic input]
+imath.comb
+
+ n: object
+ k: object
+ /
+
+Number of ways to choose k items from n items without repetition and without order.
+
+Also called the binomial coefficient. It is mathematically equal to the expression
+n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in
+polynomial expansion of the expression (1 + x)**n.
+
+Raises TypeError if the arguments are not integers.
+Raises ValueError if the arguments are negative or if k > n.
+
+[clinic start generated code]*/
+
+static PyObject *
+imath_comb_impl(PyObject *module, PyObject *n, PyObject *k)
+/*[clinic end generated code: output=87cda746ba0145ef input=f6376b6622fdc123]*/
+{
+ PyObject *result = NULL, *factor = NULL, *temp;
+ int overflow, cmp;
+ long long i, factors;
+
+ n = PyNumber_Index(n);
+ if (n == NULL) {
+ return NULL;
+ }
+ if (!PyLong_CheckExact(n)) {
+ Py_SETREF(n, _PyLong_Copy((PyLongObject *)n));
+ if (n == NULL) {
+ return NULL;
+ }
+ }
+ k = PyNumber_Index(k);
+ if (k == NULL) {
+ Py_DECREF(n);
+ return NULL;
+ }
+ if (!PyLong_CheckExact(k)) {
+ Py_SETREF(k, _PyLong_Copy((PyLongObject *)k));
+ if (k == NULL) {
+ Py_DECREF(n);
+ return NULL;
+ }
+ }
+
+ if (Py_SIZE(n) < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "n must be a non-negative integer");
+ goto error;
+ }
+ /* k = min(k, n - k) */
+ temp = PyNumber_Subtract(n, k);
+ if (temp == NULL) {
+ goto error;
+ }
+ if (Py_SIZE(temp) < 0) {
+ Py_DECREF(temp);
+ PyErr_SetString(PyExc_ValueError,
+ "k must be an integer less than or equal to n");
+ goto error;
+ }
+ cmp = PyObject_RichCompareBool(temp, k, Py_LT);
+ if (cmp > 0) {
+ Py_SETREF(k, temp);
+ }
+ else {
+ Py_DECREF(temp);
+ if (cmp < 0) {
+ goto error;
+ }
+ }
+
+ factors = PyLong_AsLongLongAndOverflow(k, &overflow);
+ if (overflow > 0) {
+ PyErr_Format(PyExc_OverflowError,
+ "min(n - k, k) must not exceed %lld",
+ LLONG_MAX);
+ goto error;
+ }
+ else if (overflow < 0 || factors < 0) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_ValueError,
+ "k must be a non-negative integer");
+ }
+ goto error;
+ }
+
+ if (factors == 0) {
+ result = PyLong_FromLong(1);
+ goto done;
+ }
+
+ result = n;
+ Py_INCREF(result);
+ if (factors == 1) {
+ goto done;
+ }
+
+ factor = n;
+ Py_INCREF(factor);
+ for (i = 1; i < factors; ++i) {
+ Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
+ if (factor == NULL) {
+ goto error;
+ }
+ Py_SETREF(result, PyNumber_Multiply(result, factor));
+ if (result == NULL) {
+ goto error;
+ }
+
+ temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
+ if (temp == NULL) {
+ goto error;
+ }
+ Py_SETREF(result, PyNumber_FloorDivide(result, temp));
+ Py_DECREF(temp);
+ if (result == NULL) {
+ goto error;
+ }
+ }
+ Py_DECREF(factor);
+
+done:
+ Py_DECREF(n);
+ Py_DECREF(k);
+ return result;
+
+error:
+ Py_XDECREF(factor);
+ Py_XDECREF(result);
+ Py_DECREF(n);
+ Py_DECREF(k);
+ return NULL;
+}
+
+
+static PyMethodDef imath_methods[] = {
+ IMATH_COMB_METHODDEF
+ IMATH_FACTORIAL_METHODDEF
+ IMATH_GCD_METHODDEF
+ IMATH_ILOG2_METHODDEF
+ IMATH_ISQRT_METHODDEF
+ IMATH_PERM_METHODDEF
+ {NULL, NULL} /* sentinel */
+};
+
+
+PyDoc_STRVAR(module_doc,
+"This module provides access to integer related mathematical functions.");
+
+
+static struct PyModuleDef imathmodule = {
+ PyModuleDef_HEAD_INIT,
+ "imath",
+ module_doc,
+ -1,
+ imath_methods,
+ NULL,
+ NULL,
+ NULL,
+ NULL
+};
+
+PyMODINIT_FUNC
+PyInit__imath(void)
+{
+ return PyModule_Create(&imathmodule);
+}
diff --git a/Modules/clinic/_imathmodule.c.h b/Modules/clinic/_imathmodule.c.h
new file mode 100644
index 00000000000000..54f75c6367f5ce
--- /dev/null
+++ b/Modules/clinic/_imathmodule.c.h
@@ -0,0 +1,136 @@
+/*[clinic input]
+preserve
+[clinic start generated code]*/
+
+PyDoc_STRVAR(imath_gcd__doc__,
+"gcd($module, x, y, /)\n"
+"--\n"
+"\n"
+"Greatest common divisor of x and y.");
+
+#define IMATH_GCD_METHODDEF \
+ {"gcd", (PyCFunction)(void(*)(void))imath_gcd, METH_FASTCALL, imath_gcd__doc__},
+
+static PyObject *
+imath_gcd_impl(PyObject *module, PyObject *x, PyObject *y);
+
+static PyObject *
+imath_gcd(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
+{
+ PyObject *return_value = NULL;
+ PyObject *x;
+ PyObject *y;
+
+ if (!_PyArg_CheckPositional("gcd", nargs, 2, 2)) {
+ goto exit;
+ }
+ x = args[0];
+ y = args[1];
+ return_value = imath_gcd_impl(module, x, y);
+
+exit:
+ return return_value;
+}
+
+PyDoc_STRVAR(imath_ilog2__doc__,
+"ilog2($module, n, /)\n"
+"--\n"
+"\n"
+"Return the integer part of the base 2 logarithm of the input.");
+
+#define IMATH_ILOG2_METHODDEF \
+ {"ilog2", (PyCFunction)imath_ilog2, METH_O, imath_ilog2__doc__},
+
+PyDoc_STRVAR(imath_isqrt__doc__,
+"isqrt($module, n, /)\n"
+"--\n"
+"\n"
+"Return the integer part of the square root of the input.");
+
+#define IMATH_ISQRT_METHODDEF \
+ {"isqrt", (PyCFunction)imath_isqrt, METH_O, imath_isqrt__doc__},
+
+PyDoc_STRVAR(imath_factorial__doc__,
+"factorial($module, x, /)\n"
+"--\n"
+"\n"
+"Find x!.\n"
+"\n"
+"Raise a TypeError if x is not an integer.\n"
+"Raise a ValueError if x is negative integer.");
+
+#define IMATH_FACTORIAL_METHODDEF \
+ {"factorial", (PyCFunction)imath_factorial, METH_O, imath_factorial__doc__},
+
+PyDoc_STRVAR(imath_perm__doc__,
+"perm($module, n, k, /)\n"
+"--\n"
+"\n"
+"Number of ways to choose k items from n items without repetition and with order.\n"
+"\n"
+"It is mathematically equal to the expression n! / (n - k)!.\n"
+"\n"
+"Raises TypeError if the arguments are not integers.\n"
+"Raises ValueError if the arguments are negative or if k > n.");
+
+#define IMATH_PERM_METHODDEF \
+ {"perm", (PyCFunction)(void(*)(void))imath_perm, METH_FASTCALL, imath_perm__doc__},
+
+static PyObject *
+imath_perm_impl(PyObject *module, PyObject *n, PyObject *k);
+
+static PyObject *
+imath_perm(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
+{
+ PyObject *return_value = NULL;
+ PyObject *n;
+ PyObject *k;
+
+ if (!_PyArg_CheckPositional("perm", nargs, 2, 2)) {
+ goto exit;
+ }
+ n = args[0];
+ k = args[1];
+ return_value = imath_perm_impl(module, n, k);
+
+exit:
+ return return_value;
+}
+
+PyDoc_STRVAR(imath_comb__doc__,
+"comb($module, n, k, /)\n"
+"--\n"
+"\n"
+"Number of ways to choose k items from n items without repetition and without order.\n"
+"\n"
+"Also called the binomial coefficient. It is mathematically equal to the expression\n"
+"n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n"
+"polynomial expansion of the expression (1 + x)**n.\n"
+"\n"
+"Raises TypeError if the arguments are not integers.\n"
+"Raises ValueError if the arguments are negative or if k > n.");
+
+#define IMATH_COMB_METHODDEF \
+ {"comb", (PyCFunction)(void(*)(void))imath_comb, METH_FASTCALL, imath_comb__doc__},
+
+static PyObject *
+imath_comb_impl(PyObject *module, PyObject *n, PyObject *k);
+
+static PyObject *
+imath_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
+{
+ PyObject *return_value = NULL;
+ PyObject *n;
+ PyObject *k;
+
+ if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) {
+ goto exit;
+ }
+ n = args[0];
+ k = args[1];
+ return_value = imath_comb_impl(module, n, k);
+
+exit:
+ return return_value;
+}
+/*[clinic end generated code: output=47736df1c2f249c7 input=a9049054013a1b77]*/
diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h
index 0efe5cc409ceb1..388a6400db5f6a 100644
--- a/Modules/clinic/mathmodule.c.h
+++ b/Modules/clinic/mathmodule.c.h
@@ -6,7 +6,9 @@ PyDoc_STRVAR(math_gcd__doc__,
"gcd($module, x, y, /)\n"
"--\n"
"\n"
-"greatest common divisor of x and y");
+"greatest common divisor of x and y\n"
+"\n"
+"See also imath.gcd().");
#define MATH_GCD_METHODDEF \
{"gcd", (PyCFunction)(void(*)(void))math_gcd, METH_FASTCALL, math_gcd__doc__},
@@ -65,22 +67,15 @@ PyDoc_STRVAR(math_fsum__doc__,
#define MATH_FSUM_METHODDEF \
{"fsum", (PyCFunction)math_fsum, METH_O, math_fsum__doc__},
-PyDoc_STRVAR(math_isqrt__doc__,
-"isqrt($module, n, /)\n"
-"--\n"
-"\n"
-"Return the integer part of the square root of the input.");
-
-#define MATH_ISQRT_METHODDEF \
- {"isqrt", (PyCFunction)math_isqrt, METH_O, math_isqrt__doc__},
-
PyDoc_STRVAR(math_factorial__doc__,
"factorial($module, x, /)\n"
"--\n"
"\n"
"Find x!.\n"
"\n"
-"Raise a ValueError if x is negative or non-integral.");
+"Raise a ValueError if x is negative or non-integral.\n"
+"\n"
+"See also imath.factorial().");
#define MATH_FACTORIAL_METHODDEF \
{"factorial", (PyCFunction)math_factorial, METH_O, math_factorial__doc__},
@@ -637,76 +632,4 @@ math_prod(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *k
exit:
return return_value;
}
-
-PyDoc_STRVAR(math_perm__doc__,
-"perm($module, n, k, /)\n"
-"--\n"
-"\n"
-"Number of ways to choose k items from n items without repetition and with order.\n"
-"\n"
-"It is mathematically equal to the expression n! / (n - k)!.\n"
-"\n"
-"Raises TypeError if the arguments are not integers.\n"
-"Raises ValueError if the arguments are negative or if k > n.");
-
-#define MATH_PERM_METHODDEF \
- {"perm", (PyCFunction)(void(*)(void))math_perm, METH_FASTCALL, math_perm__doc__},
-
-static PyObject *
-math_perm_impl(PyObject *module, PyObject *n, PyObject *k);
-
-static PyObject *
-math_perm(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
-{
- PyObject *return_value = NULL;
- PyObject *n;
- PyObject *k;
-
- if (!_PyArg_CheckPositional("perm", nargs, 2, 2)) {
- goto exit;
- }
- n = args[0];
- k = args[1];
- return_value = math_perm_impl(module, n, k);
-
-exit:
- return return_value;
-}
-
-PyDoc_STRVAR(math_comb__doc__,
-"comb($module, n, k, /)\n"
-"--\n"
-"\n"
-"Number of ways to choose k items from n items without repetition and without order.\n"
-"\n"
-"Also called the binomial coefficient. It is mathematically equal to the expression\n"
-"n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n"
-"polynomial expansion of the expression (1 + x)**n.\n"
-"\n"
-"Raises TypeError if the arguments are not integers.\n"
-"Raises ValueError if the arguments are negative or if k > n.");
-
-#define MATH_COMB_METHODDEF \
- {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__},
-
-static PyObject *
-math_comb_impl(PyObject *module, PyObject *n, PyObject *k);
-
-static PyObject *
-math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
-{
- PyObject *return_value = NULL;
- PyObject *n;
- PyObject *k;
-
- if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) {
- goto exit;
- }
- n = args[0];
- k = args[1];
- return_value = math_comb_impl(module, n, k);
-
-exit:
- return return_value;
-}
-/*[clinic end generated code: output=a82b0e705b6d0ec0 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=e9a8e18bfc89e9b1 input=a9049054013a1b77]*/
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 6e1099321c5495..b2fdddabe29835 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -76,6 +76,8 @@ static const double logpi = 1.144729885849400174143427351353058711647;
static const double sqrtpi = 1.772453850905516027298167483341145182798;
#endif /* !defined(HAVE_ERF) || !defined(HAVE_ERFC) */
+static PyObject *imath_factorial = NULL;
+
/* Version of PyFloat_AsDouble() with in-line fast paths
for exact floats and integers. Gives a substantial
@@ -833,11 +835,13 @@ math.gcd
/
greatest common divisor of x and y
+
+See also imath.gcd().
[clinic start generated code]*/
static PyObject *
math_gcd_impl(PyObject *module, PyObject *a, PyObject *b)
-/*[clinic end generated code: output=7b2e0c151bd7a5d8 input=c2691e57fb2a98fa]*/
+/*[clinic end generated code: output=7b2e0c151bd7a5d8 input=10d2c441d48df152]*/
{
PyObject *g;
@@ -1443,524 +1447,6 @@ math_fsum(PyObject *module, PyObject *seq)
#undef NUM_PARTIALS
-/* Return the smallest integer k such that n < 2**k, or 0 if n == 0.
- * Equivalent to floor(lg(x))+1. Also equivalent to: bitwidth_of_type -
- * count_leading_zero_bits(x)
- */
-
-/* XXX: This routine does more or less the same thing as
- * bits_in_digit() in Objects/longobject.c. Someday it would be nice to
- * consolidate them. On BSD, there's a library function called fls()
- * that we could use, and GCC provides __builtin_clz().
- */
-
-static unsigned long
-bit_length(unsigned long n)
-{
- unsigned long len = 0;
- while (n != 0) {
- ++len;
- n >>= 1;
- }
- return len;
-}
-
-static unsigned long
-count_set_bits(unsigned long n)
-{
- unsigned long count = 0;
- while (n != 0) {
- ++count;
- n &= n - 1; /* clear least significant bit */
- }
- return count;
-}
-
-/* Integer square root
-
-Given a nonnegative integer `n`, we want to compute the largest integer
-`a` for which `a * a <= n`, or equivalently the integer part of the exact
-square root of `n`.
-
-We use an adaptive-precision pure-integer version of Newton's iteration. Given
-a positive integer `n`, the algorithm produces at each iteration an integer
-approximation `a` to the square root of `n >> s` for some even integer `s`,
-with `s` decreasing as the iterations progress. On the final iteration, `s` is
-zero and we have an approximation to the square root of `n` itself.
-
-At every step, the approximation `a` is strictly within 1.0 of the true square
-root, so we have
-
- (a - 1)**2 < (n >> s) < (a + 1)**2
-
-After the final iteration, a check-and-correct step is needed to determine
-whether `a` or `a - 1` gives the desired integer square root of `n`.
-
-The algorithm is remarkable in its simplicity. There's no need for a
-per-iteration check-and-correct step, and termination is straightforward: the
-number of iterations is known in advance (it's exactly `floor(log2(log2(n)))`
-for `n > 1`). The only tricky part of the correctness proof is in establishing
-that the bound `(a - 1)**2 < (n >> s) < (a + 1)**2` is maintained from one
-iteration to the next. A sketch of the proof of this is given below.
-
-In addition to the proof sketch, a formal, computer-verified proof
-of correctness (using Lean) of an equivalent recursive algorithm can be found
-here:
-
- https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
-
-
-Here's Python code equivalent to the C implementation below:
-
- def isqrt(n):
- """
- Return the integer part of the square root of the input.
- """
- n = operator.index(n)
-
- if n < 0:
- raise ValueError("isqrt() argument must be nonnegative")
- if n == 0:
- return 0
-
- c = (n.bit_length() - 1) // 2
- a = 1
- d = 0
- for s in reversed(range(c.bit_length())):
- e = d
- d = c >> s
- a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
- assert (a-1)**2 < n >> 2*(c - d) < (a+1)**2
-
- return a - (a*a > n)
-
-
-Sketch of proof of correctness
-------------------------------
-
-The delicate part of the correctness proof is showing that the loop invariant
-is preserved from one iteration to the next. That is, just before the line
-
- a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
-
-is executed in the above code, we know that
-
- (1) (a - 1)**2 < (n >> 2*(c - e)) < (a + 1)**2.
-
-(since `e` is always the value of `d` from the previous iteration). We must
-prove that after that line is executed, we have
-
- (a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2
-
-To faciliate the proof, we make some changes of notation. Write `m` for
-`n >> 2*(c-d)`, and write `b` for the new value of `a`, so
-
- b = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
-
-or equivalently:
-
- (2) b = (a << d - e - 1) + (m >> d - e + 1) // a
-
-Then we can rewrite (1) as:
-
- (3) (a - 1)**2 < (m >> 2*(d - e)) < (a + 1)**2
-
-and we must show that (b - 1)**2 < m < (b + 1)**2.
-
-From this point on, we switch to mathematical notation, so `/` means exact
-division rather than integer division and `^` is used for exponentiation. We
-use the `√` symbol for the exact square root. In (3), we can remove the
-implicit floor operation to give:
-
- (4) (a - 1)^2 < m / 4^(d - e) < (a + 1)^2
-
-Taking square roots throughout (4), scaling by `2^(d-e)`, and rearranging gives
-
- (5) 0 <= | 2^(d-e)a - √m | < 2^(d-e)
-
-Squaring and dividing through by `2^(d-e+1) a` gives
-
- (6) 0 <= 2^(d-e-1) a + m / (2^(d-e+1) a) - √m < 2^(d-e-1) / a
-
-We'll show below that `2^(d-e-1) <= a`. Given that, we can replace the
-right-hand side of (6) with `1`, and now replacing the central
-term `m / (2^(d-e+1) a)` with its floor in (6) gives
-
- (7) -1 < 2^(d-e-1) a + m // 2^(d-e+1) a - √m < 1
-
-Or equivalently, from (2):
-
- (7) -1 < b - √m < 1
-
-and rearranging gives that `(b-1)^2 < m < (b+1)^2`, which is what we needed
-to prove.
-
-We're not quite done: we still have to prove the inequality `2^(d - e - 1) <=
-a` that was used to get line (7) above. From the definition of `c`, we have
-`4^c <= n`, which implies
-
- (8) 4^d <= m
-
-also, since `e == d >> 1`, `d` is at most `2e + 1`, from which it follows
-that `2d - 2e - 1 <= d` and hence that
-
- (9) 4^(2d - 2e - 1) <= m
-
-Dividing both sides by `4^(d - e)` gives
-
- (10) 4^(d - e - 1) <= m / 4^(d - e)
-
-But we know from (4) that `m / 4^(d-e) < (a + 1)^2`, hence
-
- (11) 4^(d - e - 1) < (a + 1)^2
-
-Now taking square roots of both sides and observing that both `2^(d-e-1)` and
-`a` are integers gives `2^(d - e - 1) <= a`, which is what we needed. This
-completes the proof sketch.
-
-*/
-
-
-/* Approximate square root of a large 64-bit integer.
-
- Given `n` satisfying `2**62 <= n < 2**64`, return `a`
- satisfying `(a - 1)**2 < n < (a + 1)**2`. */
-
-static uint64_t
-_approximate_isqrt(uint64_t n)
-{
- uint32_t u = 1U + (n >> 62);
- u = (u << 1) + (n >> 59) / u;
- u = (u << 3) + (n >> 53) / u;
- u = (u << 7) + (n >> 41) / u;
- return (u << 15) + (n >> 17) / u;
-}
-
-/*[clinic input]
-math.isqrt
-
- n: object
- /
-
-Return the integer part of the square root of the input.
-[clinic start generated code]*/
-
-static PyObject *
-math_isqrt(PyObject *module, PyObject *n)
-/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/
-{
- int a_too_large, c_bit_length;
- size_t c, d;
- uint64_t m, u;
- PyObject *a = NULL, *b;
-
- n = PyNumber_Index(n);
- if (n == NULL) {
- return NULL;
- }
-
- if (_PyLong_Sign(n) < 0) {
- PyErr_SetString(
- PyExc_ValueError,
- "isqrt() argument must be nonnegative");
- goto error;
- }
- if (_PyLong_Sign(n) == 0) {
- Py_DECREF(n);
- return PyLong_FromLong(0);
- }
-
- /* c = (n.bit_length() - 1) // 2 */
- c = _PyLong_NumBits(n);
- if (c == (size_t)(-1)) {
- goto error;
- }
- c = (c - 1U) / 2U;
-
- /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a
- fast, almost branch-free algorithm. In the final correction, we use `u*u
- - 1 >= m` instead of the simpler `u*u > m` in order to get the correct
- result in the corner case where `u=2**32`. */
- if (c <= 31U) {
- m = (uint64_t)PyLong_AsUnsignedLongLong(n);
- Py_DECREF(n);
- if (m == (uint64_t)(-1) && PyErr_Occurred()) {
- return NULL;
- }
- u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c);
- u -= u * u - 1U >= m;
- return PyLong_FromUnsignedLongLong((unsigned long long)u);
- }
-
- /* Slow path: n >= 2**64. We perform the first five iterations in C integer
- arithmetic, then switch to using Python long integers. */
-
- /* From n >= 2**64 it follows that c.bit_length() >= 6. */
- c_bit_length = 6;
- while ((c >> c_bit_length) > 0U) {
- ++c_bit_length;
- }
-
- /* Initialise d and a. */
- d = c >> (c_bit_length - 5);
- b = _PyLong_Rshift(n, 2U*c - 62U);
- if (b == NULL) {
- goto error;
- }
- m = (uint64_t)PyLong_AsUnsignedLongLong(b);
- Py_DECREF(b);
- if (m == (uint64_t)(-1) && PyErr_Occurred()) {
- goto error;
- }
- u = _approximate_isqrt(m) >> (31U - d);
- a = PyLong_FromUnsignedLongLong((unsigned long long)u);
- if (a == NULL) {
- goto error;
- }
-
- for (int s = c_bit_length - 6; s >= 0; --s) {
- PyObject *q;
- size_t e = d;
-
- d = c >> s;
-
- /* q = (n >> 2*c - e - d + 1) // a */
- q = _PyLong_Rshift(n, 2U*c - d - e + 1U);
- if (q == NULL) {
- goto error;
- }
- Py_SETREF(q, PyNumber_FloorDivide(q, a));
- if (q == NULL) {
- goto error;
- }
-
- /* a = (a << d - 1 - e) + q */
- Py_SETREF(a, _PyLong_Lshift(a, d - 1U - e));
- if (a == NULL) {
- Py_DECREF(q);
- goto error;
- }
- Py_SETREF(a, PyNumber_Add(a, q));
- Py_DECREF(q);
- if (a == NULL) {
- goto error;
- }
- }
-
- /* The correct result is either a or a - 1. Figure out which, and
- decrement a if necessary. */
-
- /* a_too_large = n < a * a */
- b = PyNumber_Multiply(a, a);
- if (b == NULL) {
- goto error;
- }
- a_too_large = PyObject_RichCompareBool(n, b, Py_LT);
- Py_DECREF(b);
- if (a_too_large == -1) {
- goto error;
- }
-
- if (a_too_large) {
- Py_SETREF(a, PyNumber_Subtract(a, _PyLong_One));
- }
- Py_DECREF(n);
- return a;
-
- error:
- Py_XDECREF(a);
- Py_DECREF(n);
- return NULL;
-}
-
-/* Divide-and-conquer factorial algorithm
- *
- * Based on the formula and pseudo-code provided at:
- * http://www.luschny.de/math/factorial/binarysplitfact.html
- *
- * Faster algorithms exist, but they're more complicated and depend on
- * a fast prime factorization algorithm.
- *
- * Notes on the algorithm
- * ----------------------
- *
- * factorial(n) is written in the form 2**k * m, with m odd. k and m are
- * computed separately, and then combined using a left shift.
- *
- * The function factorial_odd_part computes the odd part m (i.e., the greatest
- * odd divisor) of factorial(n), using the formula:
- *
- * factorial_odd_part(n) =
- *
- * product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j
- *
- * Example: factorial_odd_part(20) =
- *
- * (1) *
- * (1) *
- * (1 * 3 * 5) *
- * (1 * 3 * 5 * 7 * 9)
- * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
- *
- * Here i goes from large to small: the first term corresponds to i=4 (any
- * larger i gives an empty product), and the last term corresponds to i=0.
- * Each term can be computed from the last by multiplying by the extra odd
- * numbers required: e.g., to get from the penultimate term to the last one,
- * we multiply by (11 * 13 * 15 * 17 * 19).
- *
- * To see a hint of why this formula works, here are the same numbers as above
- * but with the even parts (i.e., the appropriate powers of 2) included. For
- * each subterm in the product for i, we multiply that subterm by 2**i:
- *
- * factorial(20) =
- *
- * (16) *
- * (8) *
- * (4 * 12 * 20) *
- * (2 * 6 * 10 * 14 * 18) *
- * (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
- *
- * The factorial_partial_product function computes the product of all odd j in
- * range(start, stop) for given start and stop. It's used to compute the
- * partial products like (11 * 13 * 15 * 17 * 19) in the example above. It
- * operates recursively, repeatedly splitting the range into two roughly equal
- * pieces until the subranges are small enough to be computed using only C
- * integer arithmetic.
- *
- * The two-valuation k (i.e., the exponent of the largest power of 2 dividing
- * the factorial) is computed independently in the main math_factorial
- * function. By standard results, its value is:
- *
- * two_valuation = n//2 + n//4 + n//8 + ....
- *
- * It can be shown (e.g., by complete induction on n) that two_valuation is
- * equal to n - count_set_bits(n), where count_set_bits(n) gives the number of
- * '1'-bits in the binary expansion of n.
- */
-
-/* factorial_partial_product: Compute product(range(start, stop, 2)) using
- * divide and conquer. Assumes start and stop are odd and stop > start.
- * max_bits must be >= bit_length(stop - 2). */
-
-static PyObject *
-factorial_partial_product(unsigned long start, unsigned long stop,
- unsigned long max_bits)
-{
- unsigned long midpoint, num_operands;
- PyObject *left = NULL, *right = NULL, *result = NULL;
-
- /* If the return value will fit an unsigned long, then we can
- * multiply in a tight, fast loop where each multiply is O(1).
- * Compute an upper bound on the number of bits required to store
- * the answer.
- *
- * Storing some integer z requires floor(lg(z))+1 bits, which is
- * conveniently the value returned by bit_length(z). The
- * product x*y will require at most
- * bit_length(x) + bit_length(y) bits to store, based
- * on the idea that lg product = lg x + lg y.
- *
- * We know that stop - 2 is the largest number to be multiplied. From
- * there, we have: bit_length(answer) <= num_operands *
- * bit_length(stop - 2)
- */
-
- num_operands = (stop - start) / 2;
- /* The "num_operands <= 8 * SIZEOF_LONG" check guards against the
- * unlikely case of an overflow in num_operands * max_bits. */
- if (num_operands <= 8 * SIZEOF_LONG &&
- num_operands * max_bits <= 8 * SIZEOF_LONG) {
- unsigned long j, total;
- for (total = start, j = start + 2; j < stop; j += 2)
- total *= j;
- return PyLong_FromUnsignedLong(total);
- }
-
- /* find midpoint of range(start, stop), rounded up to next odd number. */
- midpoint = (start + num_operands) | 1;
- left = factorial_partial_product(start, midpoint,
- bit_length(midpoint - 2));
- if (left == NULL)
- goto error;
- right = factorial_partial_product(midpoint, stop, max_bits);
- if (right == NULL)
- goto error;
- result = PyNumber_Multiply(left, right);
-
- error:
- Py_XDECREF(left);
- Py_XDECREF(right);
- return result;
-}
-
-/* factorial_odd_part: compute the odd part of factorial(n). */
-
-static PyObject *
-factorial_odd_part(unsigned long n)
-{
- long i;
- unsigned long v, lower, upper;
- PyObject *partial, *tmp, *inner, *outer;
-
- inner = PyLong_FromLong(1);
- if (inner == NULL)
- return NULL;
- outer = inner;
- Py_INCREF(outer);
-
- upper = 3;
- for (i = bit_length(n) - 2; i >= 0; i--) {
- v = n >> i;
- if (v <= 2)
- continue;
- lower = upper;
- /* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */
- upper = (v + 1) | 1;
- /* Here inner is the product of all odd integers j in the range (0,
- n/2**(i+1)]. The factorial_partial_product call below gives the
- product of all odd integers j in the range (n/2**(i+1), n/2**i]. */
- partial = factorial_partial_product(lower, upper, bit_length(upper-2));
- /* inner *= partial */
- if (partial == NULL)
- goto error;
- tmp = PyNumber_Multiply(inner, partial);
- Py_DECREF(partial);
- if (tmp == NULL)
- goto error;
- Py_DECREF(inner);
- inner = tmp;
- /* Now inner is the product of all odd integers j in the range (0,
- n/2**i], giving the inner product in the formula above. */
-
- /* outer *= inner; */
- tmp = PyNumber_Multiply(outer, inner);
- if (tmp == NULL)
- goto error;
- Py_DECREF(outer);
- outer = tmp;
- }
- Py_DECREF(inner);
- return outer;
-
- error:
- Py_DECREF(outer);
- Py_DECREF(inner);
- return NULL;
-}
-
-
-/* Lookup table for small factorial values */
-
-static const unsigned long SmallFactorials[] = {
- 1, 1, 2, 6, 24, 120, 720, 5040, 40320,
- 362880, 3628800, 39916800, 479001600,
-#if SIZEOF_LONG >= 8
- 6227020800, 87178291200, 1307674368000,
- 20922789888000, 355687428096000, 6402373705728000,
- 121645100408832000, 2432902008176640000
-#endif
-};
-
/*[clinic input]
math.factorial
@@ -1970,66 +1456,33 @@ math.factorial
Find x!.
Raise a ValueError if x is negative or non-integral.
+
+See also imath.factorial().
[clinic start generated code]*/
static PyObject *
math_factorial(PyObject *module, PyObject *arg)
-/*[clinic end generated code: output=6686f26fae00e9ca input=6d1c8105c0d91fb4]*/
+/*[clinic end generated code: output=6686f26fae00e9ca input=3a6ec477a80c807b]*/
{
- long x, two_valuation;
- int overflow;
- PyObject *result, *odd_part, *pyint_form;
+ PyObject *result;
if (PyFloat_Check(arg)) {
- PyObject *lx;
double dx = PyFloat_AS_DOUBLE((PyFloatObject *)arg);
if (!(Py_IS_FINITE(dx) && dx == floor(dx))) {
PyErr_SetString(PyExc_ValueError,
"factorial() only accepts integral values");
return NULL;
}
- lx = PyLong_FromDouble(dx);
- if (lx == NULL)
+ arg = PyLong_FromDouble(dx);
+ if (arg == NULL)
return NULL;
- x = PyLong_AsLongAndOverflow(lx, &overflow);
- Py_DECREF(lx);
}
else {
- pyint_form = PyNumber_Index(arg);
- if (pyint_form == NULL) {
- return NULL;
- }
- x = PyLong_AsLongAndOverflow(pyint_form, &overflow);
- Py_DECREF(pyint_form);
+ Py_INCREF(arg);
}
- if (x == -1 && PyErr_Occurred()) {
- return NULL;
- }
- else if (overflow == 1) {
- PyErr_Format(PyExc_OverflowError,
- "factorial() argument should not exceed %ld",
- LONG_MAX);
- return NULL;
- }
- else if (overflow == -1 || x < 0) {
- PyErr_SetString(PyExc_ValueError,
- "factorial() not defined for negative values");
- return NULL;
- }
-
- /* use lookup table if x is small */
- if (x < (long)Py_ARRAY_LENGTH(SmallFactorials))
- return PyLong_FromUnsignedLong(SmallFactorials[x]);
-
- /* else express in the form odd_part * 2**two_valuation, and compute as
- odd_part << two_valuation. */
- odd_part = factorial_odd_part(x);
- if (odd_part == NULL)
- return NULL;
- two_valuation = x - count_set_bits(x);
- result = _PyLong_Lshift(odd_part, two_valuation);
- Py_DECREF(odd_part);
+ result = _PyObject_FastCall(imath_factorial, &arg, 1);
+ Py_DECREF(arg);
return result;
}
@@ -2998,260 +2451,6 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
}
-/*[clinic input]
-math.perm
-
- n: object
- k: object
- /
-
-Number of ways to choose k items from n items without repetition and with order.
-
-It is mathematically equal to the expression n! / (n - k)!.
-
-Raises TypeError if the arguments are not integers.
-Raises ValueError if the arguments are negative or if k > n.
-[clinic start generated code]*/
-
-static PyObject *
-math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
-/*[clinic end generated code: output=e021a25469653e23 input=f71ee4f6ff26be24]*/
-{
- PyObject *result = NULL, *factor = NULL;
- int overflow, cmp;
- long long i, factors;
-
- n = PyNumber_Index(n);
- if (n == NULL) {
- return NULL;
- }
- if (!PyLong_CheckExact(n)) {
- Py_SETREF(n, _PyLong_Copy((PyLongObject *)n));
- if (n == NULL) {
- return NULL;
- }
- }
- k = PyNumber_Index(k);
- if (k == NULL) {
- Py_DECREF(n);
- return NULL;
- }
- if (!PyLong_CheckExact(k)) {
- Py_SETREF(k, _PyLong_Copy((PyLongObject *)k));
- if (k == NULL) {
- Py_DECREF(n);
- return NULL;
- }
- }
-
- if (Py_SIZE(n) < 0) {
- PyErr_SetString(PyExc_ValueError,
- "n must be a non-negative integer");
- goto error;
- }
- cmp = PyObject_RichCompareBool(n, k, Py_LT);
- if (cmp != 0) {
- if (cmp > 0) {
- PyErr_SetString(PyExc_ValueError,
- "k must be an integer less than or equal to n");
- }
- goto error;
- }
-
- factors = PyLong_AsLongLongAndOverflow(k, &overflow);
- if (overflow > 0) {
- PyErr_Format(PyExc_OverflowError,
- "k must not exceed %lld",
- LLONG_MAX);
- goto error;
- }
- else if (overflow < 0 || factors < 0) {
- if (!PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "k must be a non-negative integer");
- }
- goto error;
- }
-
- if (factors == 0) {
- result = PyLong_FromLong(1);
- goto done;
- }
-
- result = n;
- Py_INCREF(result);
- if (factors == 1) {
- goto done;
- }
-
- factor = n;
- Py_INCREF(factor);
- for (i = 1; i < factors; ++i) {
- Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
- if (factor == NULL) {
- goto error;
- }
- Py_SETREF(result, PyNumber_Multiply(result, factor));
- if (result == NULL) {
- goto error;
- }
- }
- Py_DECREF(factor);
-
-done:
- Py_DECREF(n);
- Py_DECREF(k);
- return result;
-
-error:
- Py_XDECREF(factor);
- Py_XDECREF(result);
- Py_DECREF(n);
- Py_DECREF(k);
- return NULL;
-}
-
-
-/*[clinic input]
-math.comb
-
- n: object
- k: object
- /
-
-Number of ways to choose k items from n items without repetition and without order.
-
-Also called the binomial coefficient. It is mathematically equal to the expression
-n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in
-polynomial expansion of the expression (1 + x)**n.
-
-Raises TypeError if the arguments are not integers.
-Raises ValueError if the arguments are negative or if k > n.
-
-[clinic start generated code]*/
-
-static PyObject *
-math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
-/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/
-{
- PyObject *result = NULL, *factor = NULL, *temp;
- int overflow, cmp;
- long long i, factors;
-
- n = PyNumber_Index(n);
- if (n == NULL) {
- return NULL;
- }
- if (!PyLong_CheckExact(n)) {
- Py_SETREF(n, _PyLong_Copy((PyLongObject *)n));
- if (n == NULL) {
- return NULL;
- }
- }
- k = PyNumber_Index(k);
- if (k == NULL) {
- Py_DECREF(n);
- return NULL;
- }
- if (!PyLong_CheckExact(k)) {
- Py_SETREF(k, _PyLong_Copy((PyLongObject *)k));
- if (k == NULL) {
- Py_DECREF(n);
- return NULL;
- }
- }
-
- if (Py_SIZE(n) < 0) {
- PyErr_SetString(PyExc_ValueError,
- "n must be a non-negative integer");
- goto error;
- }
- /* k = min(k, n - k) */
- temp = PyNumber_Subtract(n, k);
- if (temp == NULL) {
- goto error;
- }
- if (Py_SIZE(temp) < 0) {
- Py_DECREF(temp);
- PyErr_SetString(PyExc_ValueError,
- "k must be an integer less than or equal to n");
- goto error;
- }
- cmp = PyObject_RichCompareBool(temp, k, Py_LT);
- if (cmp > 0) {
- Py_SETREF(k, temp);
- }
- else {
- Py_DECREF(temp);
- if (cmp < 0) {
- goto error;
- }
- }
-
- factors = PyLong_AsLongLongAndOverflow(k, &overflow);
- if (overflow > 0) {
- PyErr_Format(PyExc_OverflowError,
- "min(n - k, k) must not exceed %lld",
- LLONG_MAX);
- goto error;
- }
- else if (overflow < 0 || factors < 0) {
- if (!PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError,
- "k must be a non-negative integer");
- }
- goto error;
- }
-
- if (factors == 0) {
- result = PyLong_FromLong(1);
- goto done;
- }
-
- result = n;
- Py_INCREF(result);
- if (factors == 1) {
- goto done;
- }
-
- factor = n;
- Py_INCREF(factor);
- for (i = 1; i < factors; ++i) {
- Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
- if (factor == NULL) {
- goto error;
- }
- Py_SETREF(result, PyNumber_Multiply(result, factor));
- if (result == NULL) {
- goto error;
- }
-
- temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
- if (temp == NULL) {
- goto error;
- }
- Py_SETREF(result, PyNumber_FloorDivide(result, temp));
- Py_DECREF(temp);
- if (result == NULL) {
- goto error;
- }
- }
- Py_DECREF(factor);
-
-done:
- Py_DECREF(n);
- Py_DECREF(k);
- return result;
-
-error:
- Py_XDECREF(factor);
- Py_XDECREF(result);
- Py_DECREF(n);
- Py_DECREF(k);
- return NULL;
-}
-
-
static PyMethodDef math_methods[] = {
{"acos", math_acos, METH_O, math_acos_doc},
{"acosh", math_acosh, METH_O, math_acosh_doc},
@@ -3283,7 +2482,6 @@ static PyMethodDef math_methods[] = {
MATH_ISFINITE_METHODDEF
MATH_ISINF_METHODDEF
MATH_ISNAN_METHODDEF
- MATH_ISQRT_METHODDEF
MATH_LDEXP_METHODDEF
{"lgamma", math_lgamma, METH_O, math_lgamma_doc},
MATH_LOG_METHODDEF
@@ -3301,8 +2499,6 @@ static PyMethodDef math_methods[] = {
{"tanh", math_tanh, METH_O, math_tanh_doc},
MATH_TRUNC_METHODDEF
MATH_PROD_METHODDEF
- MATH_PERM_METHODDEF
- MATH_COMB_METHODDEF
{NULL, NULL} /* sentinel */
};
@@ -3327,7 +2523,7 @@ static struct PyModuleDef mathmodule = {
PyMODINIT_FUNC
PyInit_math(void)
{
- PyObject *m;
+ PyObject *m, *imath;
m = PyModule_Create(&mathmodule);
if (m == NULL)
@@ -3341,6 +2537,20 @@ PyInit_math(void)
PyModule_AddObject(m, "nan", PyFloat_FromDouble(m_nan()));
#endif
+ imath = PyImport_ImportModule("imath");
+ if (!imath) {
+ goto error;
+ }
+ imath_factorial = PyObject_GetAttrString(imath, "factorial");
+ Py_DECREF(imath);
+ if (!imath_factorial) {
+ goto error;
+ }
+
finally:
return m;
+
+ error:
+ Py_DECREF(m);
+ return NULL;
}
diff --git a/PC/config.c b/PC/config.c
index 6f34962bd72d4f..45b4e98a9c6700 100644
--- a/PC/config.c
+++ b/PC/config.c
@@ -14,6 +14,7 @@ extern PyObject* PyInit_errno(void);
extern PyObject* PyInit_faulthandler(void);
extern PyObject* PyInit__tracemalloc(void);
extern PyObject* PyInit_gc(void);
+extern PyObject* PyInit__imath(void);
extern PyObject* PyInit_math(void);
extern PyObject* PyInit__md5(void);
extern PyObject* PyInit_nt(void);
@@ -91,6 +92,7 @@ struct _inittab _PyImport_Inittab[] = {
{"errno", PyInit_errno},
{"faulthandler", PyInit_faulthandler},
{"gc", PyInit_gc},
+ {"imath", PyInit__imath},
{"math", PyInit_math},
{"nt", PyInit_nt}, /* Use the NT os functions, not posix */
{"_operator", PyInit__operator},
diff --git a/PCbuild/pythoncore.vcxproj b/PCbuild/pythoncore.vcxproj
index 329f9feb2bdf02..d8eb30a05a0b29 100644
--- a/PCbuild/pythoncore.vcxproj
+++ b/PCbuild/pythoncore.vcxproj
@@ -316,6 +316,7 @@
+
diff --git a/PCbuild/pythoncore.vcxproj.filters b/PCbuild/pythoncore.vcxproj.filters
index d80d05fb15a0cf..54d7c1095cb564 100644
--- a/PCbuild/pythoncore.vcxproj.filters
+++ b/PCbuild/pythoncore.vcxproj.filters
@@ -635,6 +635,9 @@
Modules
+
+ Modules
+
Modules
diff --git a/setup.py b/setup.py
index 7852c2dfa27e08..647b56e586ffb2 100644
--- a/setup.py
+++ b/setup.py
@@ -691,6 +691,9 @@ def detect_simple_extensions(self):
# Context Variables
self.add(Extension('_contextvars', ['_contextvarsmodule.c']))
+ # integer math library functions, e.g. isqrt()
+ self.add(Extension('_imath', ['_imathmodule.c']))
+
shared_math = 'Modules/_math.o'
# math library functions, e.g. sin()