Skip to content

Commit 607387e

Browse files
committed
exact division for fmpq_poly, nmod_poly and fmpz_mod_poly
1 parent d92524d commit 607387e

File tree

5 files changed

+138
-42
lines changed

5 files changed

+138
-42
lines changed

src/flint/test/test.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,8 @@ def test_fmpq_poly():
970970
assert 3 * Q([1,2,3]) == Q([3,6,9])
971971
assert Q([1,2,3]) * flint.fmpq(2,3) == (Q([1,2,3]) * 2) / 3
972972
assert flint.fmpq(2,3) * Q([1,2,3]) == (Q([1,2,3]) * 2) / 3
973-
assert raises(lambda: Q([1,2]) / Q([1,2]), TypeError)
973+
assert Q([1,2]) / Q([1,2]) == Q([1])
974+
assert raises(lambda: Q([1,2]) / Q([2,2]), DomainError)
974975
assert Q([1,2,3]) / flint.fmpq(2,3) == Q([1,2,3]) * flint.fmpq(3,2)
975976
assert Q([1,2,3]) ** 2 == Q([1,2,3]) * Q([1,2,3])
976977
assert raises(lambda: pow(Q([1,2]), 3, 5), NotImplementedError)
@@ -2061,7 +2062,7 @@ def test_fmpz_mod_poly():
20612062
assert raises(lambda: f.exact_division(0), ZeroDivisionError)
20622063

20632064
assert (f * g).exact_division(g) == f
2064-
assert raises(lambda: f.exact_division(g), ValueError)
2065+
assert raises(lambda: f.exact_division(g), DomainError)
20652066

20662067
# true div
20672068
assert raises(lambda: f / "AAA", TypeError)
@@ -2340,10 +2341,52 @@ def test_division_poly():
23402341
assert PK([2, 5]) / 2 == PK([K(2)/K(2), K(5)/K(2)])
23412342
# No other scalar division is allowed
23422343
for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]:
2343-
assert raises(lambda: R(2) / PR([2, 5]), TypeError)
2344-
assert raises(lambda: 2 / PR([2, 5]), TypeError)
2344+
assert raises(lambda: R(2) / PR([2, 5]), DomainError)
2345+
assert raises(lambda: 2 / PR([2, 5]), DomainError)
23452346
assert raises(lambda: PR([2, 5]) / 0, ZeroDivisionError)
23462347
assert raises(lambda: PR([2, 5]) / R(0), ZeroDivisionError)
2348+
# exact polynomial division
2349+
for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]:
2350+
assert PR([2, 4]) / PR([1, 2]) == PR([2])
2351+
assert PR([2, -3, 1]) / PR([-1, 1]) == PR([-2, 1])
2352+
assert raises(lambda: PR([2, 4]) / PR([1, 3]), DomainError)
2353+
assert PR([2]) / PR([2]) == 2 / PR([2]) == PR([1])
2354+
assert PR([0]) / PR([1, 2]) == 0 / PR([1, 2]) == PR([0])
2355+
if R is Z:
2356+
assert raises(lambda: PR([1, 2]) / PR([2, 4]), DomainError)
2357+
assert raises(lambda: 1 / PR([2]), DomainError)
2358+
else:
2359+
assert PR([1, 2]) / PR([2, 4]) == PR([R(1)/R(2)])
2360+
assert 1 / PR([2]) == PR([R(1)/R(2)])
2361+
assert raises(lambda: PR([1, 2]) / PR([0]), ZeroDivisionError)
2362+
# Euclidean polynomial division
2363+
for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]:
2364+
assert PR([2, 4]) // PR([1, 2]) == PR([2])
2365+
assert PR([2, 4]) % PR([1, 2]) == PR([0])
2366+
assert divmod(PR([2, 4]), PR([1, 2])) == (PR([2]), PR([0]))
2367+
assert PR([3, -3, 1]) // PR([-1, 1]) == PR([-2, 1])
2368+
assert PR([3, -3, 1]) % PR([-1, 1]) == PR([1])
2369+
assert divmod(PR([3, -3, 1]), PR([-1, 1])) == (PR([-2, 1]), PR([1]))
2370+
assert PR([2]) // PR([2]) == 2 // PR([2]) == PR([1])
2371+
assert PR([2]) % PR([2]) == 2 % PR([2]) == PR([0])
2372+
assert divmod(PR([2]), PR([2])) == (PR([1]), PR([0]))
2373+
assert PR([0]) // PR([1, 2]) == 0 // PR([1, 2]) == PR([0])
2374+
assert PR([0]) % PR([1, 2]) == 0 % PR([1, 2]) == PR([0])
2375+
assert divmod(PR([0]), PR([1, 2])) == (PR([0]), PR([0]))
2376+
if R is Z:
2377+
assert PR([2, 2]) // PR([2, 4]) == PR([2, 2]) // PR([2, 4]) == PR([0])
2378+
assert PR([2, 2]) % PR([2, 4]) == PR([2, 2]) % PR([2, 4]) == PR([2, 2])
2379+
assert divmod(PR([2, 2]), PR([2, 4])) == (PR([0]), PR([2, 2]))
2380+
assert 1 // PR([2]) == PR([1]) // PR([2]) == PR([0])
2381+
assert 1 % PR([2]) == PR([1]) % PR([2]) == PR([1])
2382+
assert divmod(1, PR([2])) == (PR([0]), PR([1]))
2383+
else:
2384+
assert PR([2, 2]) // PR([2, 4]) == PR([R(1)/R(2)])
2385+
assert PR([2, 2]) % PR([2, 4]) == PR([1])
2386+
assert divmod(PR([2, 2]), PR([2, 4])) == (PR([R(1)/R(2)]), PR([1]))
2387+
assert 1 // PR([2]) == PR([R(1)/R(2)])
2388+
assert 1 % PR([2]) == PR([0])
2389+
assert divmod(1, PR([2])) == (PR([R(1)/R(2)]), PR([0]))
23472390

23482391

23492392
def test_division_matrix():
@@ -2553,9 +2596,9 @@ def setbad(obj, i, val):
25532596
assert raises(lambda: P([1, 2]) / 2, DomainError)
25542597
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
25552598

2556-
assert raises(lambda: 1 / P([1, 1]), TypeError)
2557-
assert raises(lambda: P([1, 2, 1]) / P([1, 1]), TypeError)
2558-
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), TypeError)
2599+
assert P([1, 2, 1]) / P([1, 1]) == P([1, 1])
2600+
assert raises(lambda: 1 / P([1, 1]), DomainError)
2601+
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError)
25592602

25602603
assert P([1, 1]) ** 0 == P([1])
25612604
assert P([1, 1]) ** 1 == P([1, 1])

src/flint/types/fmpq_poly.pyx

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ from flint.flintlib.arith cimport arith_bernoulli_polynomial
1616
from flint.flintlib.arith cimport arith_euler_polynomial
1717
from flint.flintlib.arith cimport arith_legendre_polynomial
1818

19+
from flint.utils.flint_exceptions import DomainError
20+
21+
1922
cdef any_as_fmpq_poly(obj):
2023
if typecheck(obj, fmpq_poly):
2124
return obj
@@ -295,23 +298,35 @@ cdef class fmpq_poly(flint_poly):
295298
return t
296299
return t._mod_(s)
297300

298-
@staticmethod
299-
def _div_(fmpq_poly s, t):
300-
cdef fmpq_poly r
301-
t = any_as_fmpq(t)
301+
def __truediv__(fmpq_poly s, t):
302+
cdef fmpq_poly res
303+
cdef fmpq_poly_t r
304+
t2 = any_as_fmpq(t)
305+
if t2 is NotImplemented:
306+
t2 = any_as_fmpq_poly(t)
307+
if t2 is NotImplemented:
308+
return t2
309+
if fmpq_poly_is_zero((<fmpq_poly>t2).val):
310+
raise ZeroDivisionError("fmpq_poly division by 0")
311+
res = fmpq_poly.__new__(fmpq_poly)
312+
fmpq_poly_init(r)
313+
fmpq_poly_divrem(res.val, r, (<fmpq_poly>s).val, (<fmpq_poly>t2).val)
314+
exact = fmpq_poly_is_zero(r)
315+
fmpq_poly_clear(r)
316+
if not exact:
317+
raise DomainError("fmpq_poly inexact division")
318+
else:
319+
if fmpq_is_zero((<fmpq>t2).val):
320+
raise ZeroDivisionError("fmpq_poly scalar division by 0")
321+
res = fmpq_poly.__new__(fmpq_poly)
322+
fmpq_poly_scalar_div_fmpq(res.val, (<fmpq_poly>s).val, (<fmpq>t2).val)
323+
return res
324+
325+
def __rtruediv__(fmpq_poly s, t):
326+
t = any_as_fmpq_poly(t)
302327
if t is NotImplemented:
303328
return t
304-
if fmpq_is_zero((<fmpq>t).val):
305-
raise ZeroDivisionError("fmpq_poly scalar division by 0")
306-
r = fmpq_poly.__new__(fmpq_poly)
307-
fmpq_poly_scalar_div_fmpq(r.val, (<fmpq_poly>s).val, (<fmpq>t).val)
308-
return r
309-
310-
def __div__(s, t):
311-
return fmpq_poly._div_(s, t)
312-
313-
def __truediv__(s, t):
314-
return fmpq_poly._div_(s, t)
329+
return t / s
315330

316331
def _divmod_(s, t):
317332
cdef fmpq_poly P, Q

src/flint/types/fmpz_mod_poly.pyx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,19 @@ cdef class fmpz_mod_poly(flint_poly):
452452
return res
453453

454454
def __truediv__(s, t):
455-
return fmpz_mod_poly._div_(s, t)
455+
t2 = s.ctx.mod.any_as_fmpz_mod(t)
456+
if t2 is not NotImplemented:
457+
return s._div_(t2)
458+
t2 = s.ctx.any_as_fmpz_mod_poly(t)
459+
if t2 is NotImplemented:
460+
return NotImplemented
461+
return s.exact_division(t2)
462+
463+
def __rtruediv__(s, t):
464+
t = s.ctx.any_as_fmpz_mod_poly(t)
465+
if t is NotImplemented:
466+
return NotImplemented
467+
return t.exact_division(s)
456468

457469
def exact_division(self, right):
458470
"""
@@ -482,7 +494,7 @@ cdef class fmpz_mod_poly(flint_poly):
482494
res.val, self.val, (<fmpz_mod_poly>right).val, res.ctx.mod.val
483495
)
484496
if check == 0:
485-
raise ValueError(
497+
raise DomainError(
486498
f"{right} does not divide {self}"
487499
)
488500

src/flint/types/fmpz_poly.pyx

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,32 @@ cdef class fmpz_poly(flint_poly):
232232

233233
def __truediv__(fmpz_poly self, other):
234234
cdef fmpz_poly res
235-
other = any_as_fmpz(other)
236-
if other is NotImplemented:
237-
return other
238-
if fmpz_is_zero((<fmpz>other).val):
239-
raise ZeroDivisionError("fmpz_poly division by 0")
240-
res = fmpz_poly.__new__(fmpz_poly)
241-
fmpz_poly_scalar_divexact_fmpz(res.val, self.val, (<fmpz>other).val)
242-
# Check division is exact - there should be a better way to do this
243-
if res * other != self:
244-
raise DomainError("fmpz_poly division is not exact")
235+
o = any_as_fmpz(other)
236+
if o is NotImplemented:
237+
o = any_as_fmpz_poly(other)
238+
if o is NotImplemented:
239+
return NotImplemented
240+
if fmpz_poly_is_zero((<fmpz_poly>o).val):
241+
raise ZeroDivisionError("fmpz_poly division by 0")
242+
res, r = self._divmod_(o)
243+
if r:
244+
raise DomainError("fmpz_poly division is not exact")
245+
else:
246+
if fmpz_is_zero((<fmpz>o).val):
247+
raise ZeroDivisionError("fmpz_poly division by 0")
248+
res = fmpz_poly.__new__(fmpz_poly)
249+
fmpz_poly_scalar_divexact_fmpz(res.val, self.val, (<fmpz>o).val)
250+
# Check division is exact - there should be a better way to do this
251+
if res * o != self:
252+
raise DomainError("fmpz_poly division is not exact")
245253
return res
246254

255+
def __rtruediv__(fmpz_poly self, other):
256+
o = any_as_fmpz_poly(other)
257+
if o is NotImplemented:
258+
return NotImplemented
259+
return o / self
260+
247261
def _floordiv_(self, other):
248262
cdef fmpz_poly res
249263
if fmpz_poly_is_zero((<fmpz_poly>other).val):

src/flint/types/nmod_poly.pyx

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ from flint.flintlib.nmod_poly_factor cimport *
1212
from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
1313
from flint.flintlib.ulong_extras cimport n_gcdinv
1414

15+
from flint.utils.flint_exceptions import DomainError
16+
17+
1518
cdef any_as_nmod_poly(obj, nmod_t mod):
1619
cdef nmod_poly r
1720
cdef mp_limb_t v
@@ -258,7 +261,23 @@ cdef class nmod_poly(flint_poly):
258261
def __rmul__(s, t):
259262
return s._mul_(t)
260263

261-
# TODO: __div__, __truediv__
264+
def __truediv__(s, t):
265+
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
266+
if t is NotImplemented:
267+
return t
268+
res, r = s._divmod_(t)
269+
if not nmod_poly_is_zero((<nmod_poly>r).val):
270+
raise DomainError("nmod_poly inexact division")
271+
return res
272+
273+
def __rtruediv__(s, t):
274+
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
275+
if t is NotImplemented:
276+
return t
277+
res, r = t._divmod_(s)
278+
if not nmod_poly_is_zero((<nmod_poly>r).val):
279+
raise DomainError("nmod_poly inexact division")
280+
return res
262281

263282
def _floordiv_(s, t):
264283
cdef nmod_poly r
@@ -308,13 +327,6 @@ cdef class nmod_poly(flint_poly):
308327
return t
309328
return t._divmod_(s)
310329

311-
def __truediv__(s, t):
312-
try:
313-
t = nmod(t, (<nmod_poly>s).val.mod.n)
314-
except TypeError:
315-
return NotImplemented
316-
return s * t ** -1
317-
318330
def __mod__(s, t):
319331
return divmod(s, t)[1] # XXX
320332

0 commit comments

Comments
 (0)