Skip to content

Commit 2e4b0ab

Browse files
Merge pull request #168 from oscarbenjamin/pr_comparisons
Make comparisons is_one, and is_zero consistent for polys
2 parents 3be51d9 + f24f9aa commit 2e4b0ab

15 files changed

+120
-41
lines changed

src/flint/test/test_all.py

+61-22
Original file line numberDiff line numberDiff line change
@@ -1604,7 +1604,7 @@ def test_pickling():
16041604

16051605
def test_fmpz_mod():
16061606
from flint import fmpz_mod_ctx, fmpz, fmpz_mod
1607-
1607+
16081608
p_sml = 163
16091609
p_med = 2**127 - 1
16101610
p_big = 2**255 - 19
@@ -1754,7 +1754,7 @@ def test_fmpz_mod():
17541754
assert raises(lambda: F_test(test_x) * "AAA", TypeError)
17551755
assert raises(lambda: F_test(test_x) * F_other(test_x), ValueError)
17561756

1757-
# Exponentiation
1757+
# Exponentiation
17581758

17591759
assert F_test(0)**0 == pow(0, 0, test_mod)
17601760
assert F_test(0)**1 == pow(0, 1, test_mod)
@@ -1804,7 +1804,7 @@ def test_fmpz_mod():
18041804

18051805
assert fmpz(test_y) / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod
18061806
assert test_y / F_test(test_x) == (test_y * pow(test_x, -1, test_mod)) % test_mod
1807-
1807+
18081808
def test_fmpz_mod_dlog():
18091809
from flint import fmpz, fmpz_mod_ctx
18101810

@@ -1826,7 +1826,7 @@ def test_fmpz_mod_dlog():
18261826
F = fmpz_mod_ctx(163)
18271827
g = F(2)
18281828
a = g**123
1829-
1829+
18301830
assert 123 == g.discrete_log(a)
18311831

18321832
a_int = pow(2, 123, 163)
@@ -1877,7 +1877,7 @@ def test_fmpz_mod_poly():
18771877
assert repr(R3) == "fmpz_mod_poly_ctx(13)"
18781878

18791879
assert R1.modulus() == 11
1880-
1880+
18811881
assert R1.is_prime()
18821882
assert R1.zero() == 0
18831883
assert R1.one() == 1
@@ -1946,7 +1946,7 @@ def test_fmpz_mod_poly():
19461946
assert str(f) == "8*x^3 + 7*x^2 + 6*x + 7"
19471947

19481948
# TODO: currently repr does pretty printing
1949-
# just like str, we should address this. Mainly,
1949+
# just like str, we should address this. Mainly,
19501950
# the issue is we want nice `repr` behaviour in
19511951
# interactive shells, which currently is why this
19521952
# choice has been made
@@ -1992,7 +1992,7 @@ def test_fmpz_mod_poly():
19921992
F_sml = fmpz_mod_ctx(p_sml)
19931993
F_med = fmpz_mod_ctx(p_med)
19941994
F_big = fmpz_mod_ctx(p_big)
1995-
1995+
19961996
R_sml = fmpz_mod_poly_ctx(F_sml)
19971997
R_med = fmpz_mod_poly_ctx(F_med)
19981998
R_big = fmpz_mod_poly_ctx(F_big)
@@ -2003,14 +2003,14 @@ def test_fmpz_mod_poly():
20032003
f_bad = R_cmp([2,2,2,2,2])
20042004

20052005
for (F_test, R_test) in [(F_sml, R_sml), (F_med, R_med), (F_big, R_big)]:
2006-
2006+
20072007
f = R_test([-1,-2])
20082008
g = R_test([-3,-4])
20092009

20102010
# pos, neg
20112011
assert f is +f
20122012
assert -f == R_test([1,2])
2013-
2013+
20142014
# add
20152015
assert raises(lambda: f + f_cmp, ValueError)
20162016
assert raises(lambda: f + "AAA", TypeError)
@@ -2063,7 +2063,7 @@ def test_fmpz_mod_poly():
20632063
assert raises(lambda: f / "AAA", TypeError)
20642064
assert raises(lambda: f / 0, ZeroDivisionError)
20652065
assert raises(lambda: f_cmp / 2, ZeroDivisionError)
2066-
2066+
20672067
assert (f + f) / 2 == f
20682068
assert (f + f) / fmpz(2) == f
20692069
assert (f + f) / F_test(2) == f
@@ -2077,7 +2077,7 @@ def test_fmpz_mod_poly():
20772077
assert (f + f) // 2 == f
20782078
assert (f + f) // fmpz(2) == f
20792079
assert (f + f) // F_test(2) == f
2080-
assert 2 // R_test(2) == 1
2080+
assert 2 // R_test(2) == 1
20812081
assert (f + 1) // f == 1
20822082

20832083
# pow
@@ -2171,7 +2171,7 @@ def test_fmpz_mod_poly():
21712171
f1 = R_test([-3, 1])
21722172
f2 = R_test([-5, 1])
21732173
assert f1.resultant(f2) == (3 - 5)
2174-
assert raises(lambda: f.resultant("AAA"), TypeError)
2174+
assert raises(lambda: f.resultant("AAA"), TypeError)
21752175

21762176
# sqrt
21772177
f1 = R_test.random_element(irreducible=True)
@@ -2428,14 +2428,14 @@ def _all_polys():
24282428
(flint.fmpz_poly, flint.fmpz, False),
24292429
(flint.fmpq_poly, flint.fmpq, True),
24302430
(lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True),
2431-
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)),
2432-
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)),
2431+
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)),
2432+
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)),
24332433
True),
2434-
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)),
2435-
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)),
2434+
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)),
2435+
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)),
24362436
True),
2437-
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)),
2438-
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)),
2437+
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)),
2438+
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)),
24392439
True),
24402440
]
24412441

@@ -2467,6 +2467,28 @@ def test_polys():
24672467
assert (P([1]) == P([2])) is False
24682468
assert (P([1]) != P([2])) is True
24692469

2470+
assert (P([1]) == 1) is True
2471+
assert (P([1]) != 1) is False
2472+
assert (P([1]) == 2) is False
2473+
assert (P([1]) != 2) is True
2474+
2475+
assert (1 == P([1])) is True
2476+
assert (1 != P([1])) is False
2477+
assert (2 == P([1])) is False
2478+
assert (2 != P([1])) is True
2479+
2480+
s1, s2 = S(1), S(2)
2481+
2482+
assert (P([s1]) == s1) is True
2483+
assert (P([s1]) != s1) is False
2484+
assert (P([s1]) == s2) is False
2485+
assert (P([s1]) != s2) is True
2486+
2487+
assert (s1 == P([s1])) is True
2488+
assert (s1 != P([s1])) is False
2489+
assert (s1 == P([s2])) is False
2490+
assert (s1 != P([s2])) is True
2491+
24702492
assert (P([1]) == None) is False
24712493
assert (P([1]) != None) is True
24722494
assert (None == P([1])) is False
@@ -2500,12 +2522,17 @@ def setbad(obj, i, val):
25002522
assert raises(lambda: setbad(p, -1, 1), ValueError)
25012523

25022524
for v in [], [1], [1, 2]:
2503-
if P == flint.fmpz_poly:
2525+
p = P(v)
2526+
if type(p) == flint.fmpz_poly:
25042527
assert P(v).repr() == f'fmpz_poly({v!r})'
2505-
elif P == flint.fmpq_poly:
2528+
elif type(p) == flint.fmpq_poly:
25062529
assert P(v).repr() == f'fmpq_poly({v!r})'
2507-
elif P == flint.nmod_poly:
2530+
elif type(p) == flint.nmod_poly:
25082531
assert P(v).repr() == f'nmod_poly({v!r}, 17)'
2532+
elif type(p) == flint.fmpz_mod_poly:
2533+
pass # fmpz_mod_poly does not have .repr() ...
2534+
else:
2535+
assert False
25092536

25102537
assert repr(P([])) == '0'
25112538
assert repr(P([1])) == '1'
@@ -2521,6 +2548,12 @@ def setbad(obj, i, val):
25212548
assert bool(P([])) is False
25222549
assert bool(P([1])) is True
25232550

2551+
assert P([]).is_zero() is True
2552+
assert P([1]).is_zero() is False
2553+
2554+
assert P([]).is_one() is False
2555+
assert P([1]).is_one() is True
2556+
25242557
assert +P([1, 2, 3]) == P([1, 2, 3])
25252558
assert -P([1, 2, 3]) == P([-1, -2, -3])
25262559

@@ -2600,7 +2633,7 @@ def setbad(obj, i, val):
26002633
assert P([1, 1]) ** 2 == P([1, 2, 1])
26012634
assert raises(lambda: P([1, 1]) ** -1, ValueError)
26022635
assert raises(lambda: P([1, 1]) ** None, TypeError)
2603-
2636+
26042637
# # XXX: Not sure what this should do in general:
26052638
assert raises(lambda: pow(P([1, 1]), 2, 3), NotImplementedError)
26062639

@@ -2825,6 +2858,12 @@ def quick_poly():
28252858
assert bool(P(ctx=ctx)) is False
28262859
assert bool(P(1, ctx=ctx)) is True
28272860

2861+
assert P(ctx=ctx).is_zero() is True
2862+
assert P(1, ctx=ctx).is_zero() is False
2863+
2864+
assert P(ctx=ctx).is_one() is False
2865+
assert P(1, ctx=ctx).is_one() is True
2866+
28282867
assert +quick_poly() \
28292868
== quick_poly()
28302869
assert -quick_poly() \

src/flint/types/acb_mat.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ cdef class acb_mat(flint_mat):
150150
else:
151151
raise ValueError("acb_mat: expected 1-3 arguments")
152152

153-
def __nonzero__(self):
153+
def __bool__(self):
154154
raise NotImplementedError
155155

156156
cpdef long nrows(s):

src/flint/types/arb_mat.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ cdef class arb_mat(flint_mat):
148148
else:
149149
raise ValueError("arb_mat: expected 1-3 arguments")
150150

151-
def __nonzero__(self):
151+
def __bool__(self):
152152
raise NotImplementedError
153153

154154
cpdef long nrows(s):

src/flint/types/fmpq.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ cdef class fmpq(flint_scalar):
186186
def __trunc__(self):
187187
return self.trunc()
188188

189-
def __nonzero__(self):
189+
def __bool__(self):
190190
return not fmpq_is_zero(self.val)
191191

192192
def __round__(self, ndigits=None):

src/flint/types/fmpq_mat.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ cdef class fmpq_mat(flint_mat):
9292
else:
9393
raise TypeError("fmpq_mat: expected 1-3 arguments")
9494

95-
def __nonzero__(self):
95+
def __bool__(self):
9696
return not fmpq_mat_is_zero(self.val)
9797

9898
def __richcmp__(s, t, int op):

src/flint/types/fmpq_mpoly.pyx

+6-3
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ cdef class fmpq_mpoly(flint_mpoly):
243243
def __bool__(self):
244244
return not fmpq_mpoly_is_zero(self.val, self.ctx.val)
245245

246+
def is_zero(self):
247+
return <bint>fmpq_mpoly_is_zero(self.val, self.ctx.val)
248+
249+
def is_one(self):
250+
return <bint>fmpq_mpoly_is_one(self.val, self.ctx.val)
251+
246252
def __richcmp__(self, other, int op):
247253
if not (op == Py_EQ or op == Py_NE):
248254
return NotImplemented
@@ -782,9 +788,6 @@ cdef class fmpq_mpoly(flint_mpoly):
782788
"""
783789
return self.ctx
784790

785-
def is_one(self):
786-
return fmpq_mpoly_is_one(self.val, self.ctx.val)
787-
788791
def coefficient(self, slong i):
789792
"""
790793
Return the coefficient at index `i`.

src/flint/types/fmpq_poly.pyx

+7-1
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,15 @@ cdef class fmpq_poly(flint_poly):
167167
else:
168168
return "fmpq_poly(%s, %s)" % ([int(c) for c in n.coeffs()], d)
169169

170-
def __nonzero__(self):
170+
def __bool__(self):
171171
return not fmpq_poly_is_zero(self.val)
172172

173+
def is_zero(self):
174+
return <bint>fmpq_poly_is_zero(self.val)
175+
176+
def is_one(self):
177+
return <bint>fmpq_poly_is_one(self.val)
178+
173179
def __call__(self, other):
174180
t = any_as_fmpz(other)
175181
if t is not NotImplemented:

src/flint/types/fmpz.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ cdef class fmpz(flint_scalar):
168168
def repr(self):
169169
return "fmpz(%s)" % self.str()
170170

171-
def __nonzero__(self):
171+
def __bool__(self):
172172
return not fmpz_is_zero(self.val)
173173

174174
def __pos__(self):

src/flint/types/fmpz_mat.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ cdef class fmpz_mat(flint_mat):
131131
else:
132132
raise TypeError("fmpz_mat: expected 1-3 arguments")
133133

134-
def __nonzero__(self):
134+
def __bool__(self):
135135
return not fmpz_mat_is_zero(self.val)
136136

137137
def __richcmp__(fmpz_mat s, t, int op):

src/flint/types/fmpz_mod_mat.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ cdef class fmpz_mod_mat(flint_mat):
303303
e = self.ctx.any_as_fmpz_mod(value)
304304
self._setitem(i, j, e.val)
305305

306-
def __nonzero__(self):
306+
def __bool__(self):
307307
"""Return ``True`` if the matrix has any nonzero entries."""
308308
cdef bint zero
309309
zero = compat_fmpz_mod_mat_is_zero(self.val, self.ctx.val)

src/flint/types/fmpz_mpoly.pyx

+6-3
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ cdef class fmpz_mpoly(flint_mpoly):
223223
def __bool__(self):
224224
return not fmpz_mpoly_is_zero(self.val, self.ctx.val)
225225

226+
def is_zero(self):
227+
return <bint>fmpz_mpoly_is_zero(self.val, self.ctx.val)
228+
229+
def is_one(self):
230+
return <bint>fmpz_mpoly_is_one(self.val, self.ctx.val)
231+
226232
def __richcmp__(self, other, int op):
227233
if not (op == Py_EQ or op == Py_NE):
228234
return NotImplemented
@@ -764,9 +770,6 @@ cdef class fmpz_mpoly(flint_mpoly):
764770
"""
765771
return self.ctx
766772

767-
def is_one(self):
768-
return fmpz_mpoly_is_one(self.val, self.ctx.val)
769-
770773
def coefficient(self, slong i):
771774
"""
772775
Return the coefficient at index `i`.

src/flint/types/fmpz_poly.pyx

+7-1
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,15 @@ cdef class fmpz_poly(flint_poly):
137137
def repr(self):
138138
return "fmpz_poly([%s])" % (", ".join(map(str, self.coeffs())))
139139

140-
def __nonzero__(self):
140+
def __bool__(self):
141141
return not fmpz_poly_is_zero(self.val)
142142

143+
def is_zero(self):
144+
return <bint>fmpz_poly_is_zero(self.val)
145+
146+
def is_one(self):
147+
return <bint>fmpz_poly_is_one(self.val)
148+
143149
def __call__(self, other):
144150
t = any_as_fmpz(other)
145151
if t is not NotImplemented:

src/flint/types/nmod.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ cdef class nmod(flint_scalar):
8989
def __hash__(self):
9090
return hash((int(self.val), self.modulus))
9191

92-
def __nonzero__(self):
92+
def __bool__(self):
9393
return self.val != 0
9494

9595
def __pos__(self):

src/flint/types/nmod_mat.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ cdef class nmod_mat(flint_mat):
136136
else:
137137
raise TypeError("nmod_mat: expected 1-3 arguments plus modulus")
138138

139-
def __nonzero__(self):
139+
def __bool__(self):
140140
return not nmod_mat_is_zero(self.val)
141141

142142
def __richcmp__(s, t, int op):

0 commit comments

Comments
 (0)