Skip to content

Commit 98ecfe1

Browse files
committed
allow comparison with nmod to fmpz and fmpz_mod
1 parent 1178e20 commit 98ecfe1

File tree

2 files changed

+39
-16
lines changed

2 files changed

+39
-16
lines changed

src/flint/test/test_all.py

+18
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,24 @@ def test_nmod():
13571357
assert str(G(3,5)) == "3"
13581358
assert G(3,5).repr() == "nmod(3, 5)"
13591359

1360+
# We can compare to int and fmpz types
1361+
assert G(1, 5) == int(1)
1362+
assert G(4, 5) == int(-1)
1363+
assert G(1, 5) == flint.fmpz(1)
1364+
assert G(4, 5) == flint.fmpz(-1)
1365+
1366+
# When the modulus matches, we can compare fmpz_mod
1367+
R = flint.fmpz_mod_ctx(5)
1368+
assert G(1, 5) == R(1)
1369+
assert G(1, 5) != R(-1)
1370+
assert G(4, 5) == R(4)
1371+
assert G(4, 5) == R(-1)
1372+
# when the modulus doesnt match, everything fails
1373+
assert G(1, 7) != R(1)
1374+
assert G(1, 7) != R(-1)
1375+
assert G(4, 7) != R(4)
1376+
assert G(4, 7) != R(-1)
1377+
13601378
def test_nmod_poly():
13611379
N = flint.nmod
13621380
P = flint.nmod_poly

src/flint/types/nmod.pyx

+21-16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from flint.utils.typecheck cimport typecheck
33
from flint.types.fmpq cimport any_as_fmpq
44
from flint.types.fmpz cimport any_as_fmpz
55
from flint.types.fmpz cimport fmpz
6+
from flint.types.fmpz_mod cimport fmpz_mod
67
from flint.types.fmpq cimport fmpq
78

89
from flint.flintlib.flint cimport ulong
@@ -66,25 +67,29 @@ cdef class nmod(flint_scalar):
6667
def modulus(self):
6768
return self.mod.n
6869

69-
def __richcmp__(s, t, int op):
70-
cdef mp_limb_t v
70+
def __richcmp__(self, other, int op):
7171
cdef bint res
72+
7273
if op != 2 and op != 3:
7374
raise TypeError("nmods cannot be ordered")
74-
if typecheck(s, nmod) and typecheck(t, nmod):
75-
res = ((<nmod>s).val == (<nmod>t).val) and \
76-
((<nmod>s).mod.n == (<nmod>t).mod.n)
77-
if op == 2:
78-
return res
79-
else:
80-
return not res
81-
elif typecheck(s, nmod) and typecheck(t, int):
82-
res = s.val == (t % s.mod.n)
83-
if op == 2:
84-
return res
85-
else:
86-
return not res
87-
return NotImplemented
75+
76+
if typecheck(other, nmod):
77+
res = self.val == (<nmod>other).val and \
78+
self.mod.n == (<nmod>other).mod.n
79+
elif typecheck(other, int):
80+
res = self.val == (other % self.mod.n)
81+
elif typecheck(other, fmpz):
82+
res = self.val == (int(other) % self.mod.n)
83+
elif typecheck(other, fmpz_mod):
84+
res = self.mod.n == (<fmpz_mod>other).ctx.modulus() and \
85+
self.val == int(other)
86+
else:
87+
return NotImplemented
88+
89+
if op == 2:
90+
return res
91+
else:
92+
return not res
8893

8994
def __hash__(self):
9095
return hash((int(self.val), self.modulus))

0 commit comments

Comments
 (0)