Skip to content

Make comparisons is_one, and is_zero consistent for polys #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 61 additions & 22 deletions src/flint/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,7 @@ def test_pickling():

def test_fmpz_mod():
from flint import fmpz_mod_ctx, fmpz, fmpz_mod

p_sml = 163
p_med = 2**127 - 1
p_big = 2**255 - 19
Expand Down Expand Up @@ -1754,7 +1754,7 @@ def test_fmpz_mod():
assert raises(lambda: F_test(test_x) * "AAA", TypeError)
assert raises(lambda: F_test(test_x) * F_other(test_x), ValueError)

# Exponentiation
# Exponentiation

assert F_test(0)**0 == pow(0, 0, test_mod)
assert F_test(0)**1 == pow(0, 1, test_mod)
Expand Down Expand Up @@ -1804,7 +1804,7 @@ def test_fmpz_mod():

assert fmpz(test_y) / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod
assert test_y / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod

def test_fmpz_mod_dlog():
from flint import fmpz, fmpz_mod_ctx

Expand All @@ -1826,7 +1826,7 @@ def test_fmpz_mod_dlog():
F = fmpz_mod_ctx(163)
g = F(2)
a = g**123

assert 123 == g.discrete_log(a)

a_int = pow(2, 123, 163)
Expand Down Expand Up @@ -1877,7 +1877,7 @@ def test_fmpz_mod_poly():
assert repr(R3) == "fmpz_mod_poly_ctx(13)"

assert R1.modulus() == 11

assert R1.is_prime()
assert R1.zero() == 0
assert R1.one() == 1
Expand Down Expand Up @@ -1946,7 +1946,7 @@ def test_fmpz_mod_poly():
assert str(f) == "8*x^3 + 7*x^2 + 6*x + 7"

# TODO: currently repr does pretty printing
# just like str, we should address this. Mainly,
# just like str, we should address this. Mainly,
# the issue is we want nice `repr` behaviour in
# interactive shells, which currently is why this
# choice has been made
Expand Down Expand Up @@ -1992,7 +1992,7 @@ def test_fmpz_mod_poly():
F_sml = fmpz_mod_ctx(p_sml)
F_med = fmpz_mod_ctx(p_med)
F_big = fmpz_mod_ctx(p_big)

R_sml = fmpz_mod_poly_ctx(F_sml)
R_med = fmpz_mod_poly_ctx(F_med)
R_big = fmpz_mod_poly_ctx(F_big)
Expand All @@ -2003,14 +2003,14 @@ def test_fmpz_mod_poly():
f_bad = R_cmp([2,2,2,2,2])

for (F_test, R_test) in [(F_sml, R_sml), (F_med, R_med), (F_big, R_big)]:

f = R_test([-1,-2])
g = R_test([-3,-4])

# pos, neg
assert f is +f
assert -f == R_test([1,2])

# add
assert raises(lambda: f + f_cmp, ValueError)
assert raises(lambda: f + "AAA", TypeError)
Expand Down Expand Up @@ -2063,7 +2063,7 @@ def test_fmpz_mod_poly():
assert raises(lambda: f / "AAA", TypeError)
assert raises(lambda: f / 0, ZeroDivisionError)
assert raises(lambda: f_cmp / 2, ZeroDivisionError)

assert (f + f) / 2 == f
assert (f + f) / fmpz(2) == f
assert (f + f) / F_test(2) == f
Expand All @@ -2077,7 +2077,7 @@ def test_fmpz_mod_poly():
assert (f + f) // 2 == f
assert (f + f) // fmpz(2) == f
assert (f + f) // F_test(2) == f
assert 2 // R_test(2) == 1
assert 2 // R_test(2) == 1
assert (f + 1) // f == 1

# pow
Expand Down Expand Up @@ -2171,7 +2171,7 @@ def test_fmpz_mod_poly():
f1 = R_test([-3, 1])
f2 = R_test([-5, 1])
assert f1.resultant(f2) == (3 - 5)
assert raises(lambda: f.resultant("AAA"), TypeError)
assert raises(lambda: f.resultant("AAA"), TypeError)

# sqrt
f1 = R_test.random_element(irreducible=True)
Expand Down Expand Up @@ -2428,14 +2428,14 @@ def _all_polys():
(flint.fmpz_poly, flint.fmpz, False),
(flint.fmpq_poly, flint.fmpq, True),
(lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)),
True),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)),
True),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)),
True),
]

Expand Down Expand Up @@ -2467,6 +2467,28 @@ def test_polys():
assert (P([1]) == P([2])) is False
assert (P([1]) != P([2])) is True

assert (P([1]) == 1) is True
assert (P([1]) != 1) is False
assert (P([1]) == 2) is False
assert (P([1]) != 2) is True

assert (1 == P([1])) is True
assert (1 != P([1])) is False
assert (2 == P([1])) is False
assert (2 != P([1])) is True

s1, s2 = S(1), S(2)

assert (P([s1]) == s1) is True
assert (P([s1]) != s1) is False
assert (P([s1]) == s2) is False
assert (P([s1]) != s2) is True

assert (s1 == P([s1])) is True
assert (s1 != P([s1])) is False
assert (s1 == P([s2])) is False
assert (s1 != P([s2])) is True

assert (P([1]) == None) is False
assert (P([1]) != None) is True
assert (None == P([1])) is False
Expand Down Expand Up @@ -2500,12 +2522,17 @@ def setbad(obj, i, val):
assert raises(lambda: setbad(p, -1, 1), ValueError)

for v in [], [1], [1, 2]:
if P == flint.fmpz_poly:
p = P(v)
if type(p) == flint.fmpz_poly:
assert P(v).repr() == f'fmpz_poly({v!r})'
elif P == flint.fmpq_poly:
elif type(p) == flint.fmpq_poly:
assert P(v).repr() == f'fmpq_poly({v!r})'
elif P == flint.nmod_poly:
elif type(p) == flint.nmod_poly:
assert P(v).repr() == f'nmod_poly({v!r}, 17)'
elif type(p) == flint.fmpz_mod_poly:
pass # fmpz_mod_poly does not have .repr() ...
else:
assert False

assert repr(P([])) == '0'
assert repr(P([1])) == '1'
Expand All @@ -2521,6 +2548,12 @@ def setbad(obj, i, val):
assert bool(P([])) is False
assert bool(P([1])) is True

assert P([]).is_zero() is True
assert P([1]).is_zero() is False

assert P([]).is_one() is False
assert P([1]).is_one() is True

assert +P([1, 2, 3]) == P([1, 2, 3])
assert -P([1, 2, 3]) == P([-1, -2, -3])

Expand Down Expand Up @@ -2600,7 +2633,7 @@ def setbad(obj, i, val):
assert P([1, 1]) ** 2 == P([1, 2, 1])
assert raises(lambda: P([1, 1]) ** -1, ValueError)
assert raises(lambda: P([1, 1]) ** None, TypeError)

# # XXX: Not sure what this should do in general:
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)

Expand Down Expand Up @@ -2825,6 +2858,12 @@ def quick_poly():
assert bool(P(ctx=ctx)) is False
assert bool(P(1, ctx=ctx)) is True

assert P(ctx=ctx).is_zero() is True
assert P(1, ctx=ctx).is_zero() is False

assert P(ctx=ctx).is_one() is False
assert P(1, ctx=ctx).is_one() is True

assert +quick_poly() \
== quick_poly()
assert -quick_poly() \
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/acb_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ cdef class acb_mat(flint_mat):
else:
raise ValueError("acb_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
raise NotImplementedError

cpdef long nrows(s):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/arb_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ cdef class arb_mat(flint_mat):
else:
raise ValueError("arb_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
raise NotImplementedError

cpdef long nrows(s):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ cdef class fmpq(flint_scalar):
def __trunc__(self):
return self.trunc()

def __nonzero__(self):
def __bool__(self):
return not fmpq_is_zero(self.val)

def __round__(self, ndigits=None):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpq_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ cdef class fmpq_mat(flint_mat):
else:
raise TypeError("fmpq_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
return not fmpq_mat_is_zero(self.val)

def __richcmp__(s, t, int op):
Expand Down
9 changes: 6 additions & 3 deletions src/flint/types/fmpq_mpoly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,12 @@ cdef class fmpq_mpoly(flint_mpoly):
def __bool__(self):
return not fmpq_mpoly_is_zero(self.val, self.ctx.val)

def is_zero(self):
return <bint>fmpq_mpoly_is_zero(self.val, self.ctx.val)

def is_one(self):
return <bint>fmpq_mpoly_is_one(self.val, self.ctx.val)

def __richcmp__(self, other, int op):
if not (op == Py_EQ or op == Py_NE):
return NotImplemented
Expand Down Expand Up @@ -782,9 +788,6 @@ cdef class fmpq_mpoly(flint_mpoly):
"""
return self.ctx

def is_one(self):
return fmpq_mpoly_is_one(self.val, self.ctx.val)

def coefficient(self, slong i):
"""
Return the coefficient at index `i`.
Expand Down
8 changes: 7 additions & 1 deletion src/flint/types/fmpq_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,15 @@ cdef class fmpq_poly(flint_poly):
else:
return "fmpq_poly(%s, %s)" % ([int(c) for c in n.coeffs()], d)

def __nonzero__(self):
def __bool__(self):
return not fmpq_poly_is_zero(self.val)

def is_zero(self):
return <bint>fmpq_poly_is_zero(self.val)

def is_one(self):
return <bint>fmpq_poly_is_one(self.val)

def __call__(self, other):
t = any_as_fmpz(other)
if t is not NotImplemented:
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpz.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ cdef class fmpz(flint_scalar):
def repr(self):
return "fmpz(%s)" % self.str()

def __nonzero__(self):
def __bool__(self):
return not fmpz_is_zero(self.val)

def __pos__(self):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpz_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ cdef class fmpz_mat(flint_mat):
else:
raise TypeError("fmpz_mat: expected 1-3 arguments")

def __nonzero__(self):
def __bool__(self):
return not fmpz_mat_is_zero(self.val)

def __richcmp__(fmpz_mat s, t, int op):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/fmpz_mod_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ cdef class fmpz_mod_mat(flint_mat):
e = self.ctx.any_as_fmpz_mod(value)
self._setitem(i, j, e.val)

def __nonzero__(self):
def __bool__(self):
"""Return ``True`` if the matrix has any nonzero entries."""
cdef bint zero
zero = compat_fmpz_mod_mat_is_zero(self.val, self.ctx.val)
Expand Down
9 changes: 6 additions & 3 deletions src/flint/types/fmpz_mpoly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ cdef class fmpz_mpoly(flint_mpoly):
def __bool__(self):
return not fmpz_mpoly_is_zero(self.val, self.ctx.val)

def is_zero(self):
return <bint>fmpz_mpoly_is_zero(self.val, self.ctx.val)

def is_one(self):
return <bint>fmpz_mpoly_is_one(self.val, self.ctx.val)

def __richcmp__(self, other, int op):
if not (op == Py_EQ or op == Py_NE):
return NotImplemented
Expand Down Expand Up @@ -764,9 +770,6 @@ cdef class fmpz_mpoly(flint_mpoly):
"""
return self.ctx

def is_one(self):
return fmpz_mpoly_is_one(self.val, self.ctx.val)

def coefficient(self, slong i):
"""
Return the coefficient at index `i`.
Expand Down
8 changes: 7 additions & 1 deletion src/flint/types/fmpz_poly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,15 @@ cdef class fmpz_poly(flint_poly):
def repr(self):
return "fmpz_poly([%s])" % (", ".join(map(str, self.coeffs())))

def __nonzero__(self):
def __bool__(self):
return not fmpz_poly_is_zero(self.val)

def is_zero(self):
return <bint>fmpz_poly_is_zero(self.val)

def is_one(self):
return <bint>fmpz_poly_is_one(self.val)

def __call__(self, other):
t = any_as_fmpz(other)
if t is not NotImplemented:
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/nmod.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ cdef class nmod(flint_scalar):
def __hash__(self):
return hash((int(self.val), self.modulus))

def __nonzero__(self):
def __bool__(self):
return self.val != 0

def __pos__(self):
Expand Down
2 changes: 1 addition & 1 deletion src/flint/types/nmod_mat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ cdef class nmod_mat(flint_mat):
else:
raise TypeError("nmod_mat: expected 1-3 arguments plus modulus")

def __nonzero__(self):
def __bool__(self):
return not nmod_mat_is_zero(self.val)

def __richcmp__(s, t, int op):
Expand Down
Loading
Loading