Skip to content

Commit 68c240c

Browse files
Merge pull request #109 from oscarbenjamin/pr_exact_division
Use exact division for `a/b` for `fmpz`, `fmpz_mat`, and `*_poly`
2 parents 1ae8ebf + 607387e commit 68c240c

File tree

7 files changed

+331
-79
lines changed

7 files changed

+331
-79
lines changed

src/flint/test/test.py

Lines changed: 167 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def test_fmpz():
9090
assert int(f) == i
9191
assert flint.fmpz(f) == f
9292
assert flint.fmpz(str(i)) == f
93+
assert raises(lambda: flint.fmpz(1,2), TypeError)
9394
assert raises(lambda: flint.fmpz("qwe"), ValueError)
9495
assert raises(lambda: flint.fmpz([]), TypeError)
9596
for s in L:
@@ -162,6 +163,9 @@ def test_fmpz():
162163
# XXX: Handle negative modulus like int?
163164
assert raises(lambda: pow(flint.fmpz(2), 2, -1), ValueError)
164165

166+
assert raises(lambda: pow(flint.fmpz(2), "asd", 2), TypeError)
167+
assert raises(lambda: pow(flint.fmpz(2), 2, "asd"), TypeError)
168+
165169
f = flint.fmpz(2)
166170
assert f.numerator == f
167171
assert type(f.numerator) is flint.fmpz
@@ -543,7 +547,8 @@ def test_fmpz_mat():
543547
assert str(M(2,2,[1,2,3,4])) == '[1, 2]\n[3, 4]'
544548
assert M(1,2,[3,4]) * flint.fmpq(1,3) == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)])
545549
assert flint.fmpq(1,3) * M(1,2,[3,4]) == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)])
546-
assert M(1,2,[3,4]) / 3 == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)])
550+
assert raises(lambda: M(1,2,[3,4]) / 3, DomainError)
551+
assert M(1,2,[2,4]) / 2 == M(1,2,[1,2])
547552
assert M(2,2,[1,2,3,4]).inv().det() == flint.fmpq(1) / M(2,2,[1,2,3,4]).det()
548553
assert M(2,2,[1,2,3,4]).inv().inv() == M(2,2,[1,2,3,4])
549554
assert raises(lambda: M.randrank(4,3,4,1), ValueError)
@@ -965,7 +970,8 @@ def test_fmpq_poly():
965970
assert 3 * Q([1,2,3]) == Q([3,6,9])
966971
assert Q([1,2,3]) * flint.fmpq(2,3) == (Q([1,2,3]) * 2) / 3
967972
assert flint.fmpq(2,3) * Q([1,2,3]) == (Q([1,2,3]) * 2) / 3
968-
assert raises(lambda: Q([1,2]) / Q([1,2]), TypeError)
973+
assert Q([1,2]) / Q([1,2]) == Q([1])
974+
assert raises(lambda: Q([1,2]) / Q([2,2]), DomainError)
969975
assert Q([1,2,3]) / flint.fmpq(2,3) == Q([1,2,3]) * flint.fmpq(3,2)
970976
assert Q([1,2,3]) ** 2 == Q([1,2,3]) * Q([1,2,3])
971977
assert raises(lambda: pow(Q([1,2]), 3, 5), NotImplementedError)
@@ -2056,7 +2062,7 @@ def test_fmpz_mod_poly():
20562062
assert raises(lambda: f.exact_division(0), ZeroDivisionError)
20572063

20582064
assert (f * g).exact_division(g) == f
2059-
assert raises(lambda: f.exact_division(g), ValueError)
2065+
assert raises(lambda: f.exact_division(g), DomainError)
20602066

20612067
# true div
20622068
assert raises(lambda: f / "AAA", TypeError)
@@ -2276,6 +2282,151 @@ def test_fmpz_mod_mat():
22762282
assert raises(lambda: flint.fmpz_mod_mat(A, c11), TypeError)
22772283

22782284

2285+
def test_division_scalar():
2286+
Z = flint.fmpz
2287+
Q = flint.fmpq
2288+
F17 = lambda x: flint.nmod(x, 17)
2289+
ctx = flint.fmpz_mod_ctx(163)
2290+
F163 = lambda a: flint.fmpz_mod(a, ctx)
2291+
# fmpz exact division
2292+
for (a, b) in [(Z(4), Z(2)), (Z(4), 2), (4, Z(2))]:
2293+
assert a / b == Z(2)
2294+
for (a, b) in [(Z(5), Z(2)), (Z(5), 2), (5, Z(2))]:
2295+
assert raises(lambda: a / b, DomainError)
2296+
# fmpz Euclidean division
2297+
for (a, b) in [(Z(5), Z(2)), (Z(5), 2), (5, Z(2))]:
2298+
assert a // b == 2
2299+
assert a % b == 1
2300+
assert divmod(a, b) == (2, 1)
2301+
# field division
2302+
for (a, b) in [(Q(5), Q(2)), (Q(5), 2), (5, Q(2))]:
2303+
assert a / b == Q(5,2)
2304+
for (a, b) in [(F17(5), F17(2)), (F17(5), 2), (5, F17(2))]:
2305+
assert a / b == F17(11)
2306+
for (a, b) in [(F163(5), F163(2)), (F163(5), 2), (5, F163(2))]:
2307+
assert a / b == F163(84)
2308+
# divmod with fields - should this give remainder zero instead of error?
2309+
for K in [Q, F17, F163]:
2310+
for (a, b) in [(K(5), K(2)), (K(5), 2), (5, K(2))]:
2311+
assert raises(lambda: divmod(a, b), TypeError)
2312+
# Zero division
2313+
for R in [Z, Q, F17, F163]:
2314+
assert raises(lambda: R(5) / 0, ZeroDivisionError)
2315+
assert raises(lambda: R(5) / R(0), ZeroDivisionError)
2316+
assert raises(lambda: 5 / R(0), ZeroDivisionError)
2317+
# Bad types
2318+
for R in [Z, Q, F17, F163]:
2319+
assert raises(lambda: R(5) / "AAA", TypeError)
2320+
assert raises(lambda: "AAA" / R(5), TypeError)
2321+
2322+
2323+
def test_division_poly():
2324+
Z = flint.fmpz
2325+
Q = flint.fmpq
2326+
F17 = lambda x: flint.nmod(x, 17)
2327+
ctx = flint.fmpz_mod_ctx(163)
2328+
F163 = lambda a: flint.fmpz_mod(a, ctx)
2329+
PZ = lambda x: flint.fmpz_poly(x)
2330+
PQ = lambda x: flint.fmpq_poly(x)
2331+
PF17 = lambda x: flint.nmod_poly(x, 17)
2332+
PF163 = lambda x: flint.fmpz_mod_poly(x, flint.fmpz_mod_poly_ctx(163))
2333+
# fmpz exact scalar division
2334+
assert PZ([2, 4]) / Z(2) == PZ([1, 2])
2335+
assert PZ([2, 4]) / 2 == PZ([1, 2])
2336+
assert raises(lambda: PZ([2, 5]) / Z(2), DomainError)
2337+
assert raises(lambda: PZ([2, 5]) / 2, DomainError)
2338+
# field division by scalar
2339+
for (K, PK) in [(Q, PQ), (F17, PF17), (F163, PF163)]:
2340+
assert PK([2, 5]) / K(2) == PK([K(2)/K(2), K(5)/K(2)])
2341+
assert PK([2, 5]) / 2 == PK([K(2)/K(2), K(5)/K(2)])
2342+
# No other scalar division is allowed
2343+
for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]:
2344+
assert raises(lambda: R(2) / PR([2, 5]), DomainError)
2345+
assert raises(lambda: 2 / PR([2, 5]), DomainError)
2346+
assert raises(lambda: PR([2, 5]) / 0, ZeroDivisionError)
2347+
assert raises(lambda: PR([2, 5]) / R(0), ZeroDivisionError)
2348+
# exact polynomial division
2349+
for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]:
2350+
assert PR([2, 4]) / PR([1, 2]) == PR([2])
2351+
assert PR([2, -3, 1]) / PR([-1, 1]) == PR([-2, 1])
2352+
assert raises(lambda: PR([2, 4]) / PR([1, 3]), DomainError)
2353+
assert PR([2]) / PR([2]) == 2 / PR([2]) == PR([1])
2354+
assert PR([0]) / PR([1, 2]) == 0 / PR([1, 2]) == PR([0])
2355+
if R is Z:
2356+
assert raises(lambda: PR([1, 2]) / PR([2, 4]), DomainError)
2357+
assert raises(lambda: 1 / PR([2]), DomainError)
2358+
else:
2359+
assert PR([1, 2]) / PR([2, 4]) == PR([R(1)/R(2)])
2360+
assert 1 / PR([2]) == PR([R(1)/R(2)])
2361+
assert raises(lambda: PR([1, 2]) / PR([0]), ZeroDivisionError)
2362+
# Euclidean polynomial division
2363+
for (R, PR) in [(Z, PZ), (Q, PQ), (F17, PF17), (F163, PF163)]:
2364+
assert PR([2, 4]) // PR([1, 2]) == PR([2])
2365+
assert PR([2, 4]) % PR([1, 2]) == PR([0])
2366+
assert divmod(PR([2, 4]), PR([1, 2])) == (PR([2]), PR([0]))
2367+
assert PR([3, -3, 1]) // PR([-1, 1]) == PR([-2, 1])
2368+
assert PR([3, -3, 1]) % PR([-1, 1]) == PR([1])
2369+
assert divmod(PR([3, -3, 1]), PR([-1, 1])) == (PR([-2, 1]), PR([1]))
2370+
assert PR([2]) // PR([2]) == 2 // PR([2]) == PR([1])
2371+
assert PR([2]) % PR([2]) == 2 % PR([2]) == PR([0])
2372+
assert divmod(PR([2]), PR([2])) == (PR([1]), PR([0]))
2373+
assert PR([0]) // PR([1, 2]) == 0 // PR([1, 2]) == PR([0])
2374+
assert PR([0]) % PR([1, 2]) == 0 % PR([1, 2]) == PR([0])
2375+
assert divmod(PR([0]), PR([1, 2])) == (PR([0]), PR([0]))
2376+
if R is Z:
2377+
assert PR([2, 2]) // PR([2, 4]) == PR([2, 2]) // PR([2, 4]) == PR([0])
2378+
assert PR([2, 2]) % PR([2, 4]) == PR([2, 2]) % PR([2, 4]) == PR([2, 2])
2379+
assert divmod(PR([2, 2]), PR([2, 4])) == (PR([0]), PR([2, 2]))
2380+
assert 1 // PR([2]) == PR([1]) // PR([2]) == PR([0])
2381+
assert 1 % PR([2]) == PR([1]) % PR([2]) == PR([1])
2382+
assert divmod(1, PR([2])) == (PR([0]), PR([1]))
2383+
else:
2384+
assert PR([2, 2]) // PR([2, 4]) == PR([R(1)/R(2)])
2385+
assert PR([2, 2]) % PR([2, 4]) == PR([1])
2386+
assert divmod(PR([2, 2]), PR([2, 4])) == (PR([R(1)/R(2)]), PR([1]))
2387+
assert 1 // PR([2]) == PR([R(1)/R(2)])
2388+
assert 1 % PR([2]) == PR([0])
2389+
assert divmod(1, PR([2])) == (PR([R(1)/R(2)]), PR([0]))
2390+
2391+
2392+
def test_division_matrix():
2393+
Z = flint.fmpz
2394+
Q = flint.fmpq
2395+
F17 = lambda x: flint.nmod(x, 17)
2396+
ctx = flint.fmpz_mod_ctx(163)
2397+
F163 = lambda a: flint.fmpz_mod(a, ctx)
2398+
MZ = lambda x: flint.fmpz_mat(x)
2399+
MQ = lambda x: flint.fmpq_mat(x)
2400+
MF17 = lambda x: flint.nmod_mat(x, 17)
2401+
MF163 = lambda x: flint.fmpz_mod_mat(x, ctx)
2402+
# fmpz exact division
2403+
assert MZ([[2, 4]]) / Z(2) == MZ([[1, 2]])
2404+
assert MZ([[2, 4]]) / 2 == MZ([[1, 2]])
2405+
assert raises(lambda: MZ([[2, 5]]) / Z(2), DomainError)
2406+
assert raises(lambda: MZ([[2, 5]]) / 2, DomainError)
2407+
# field division by scalar
2408+
for (K, MK) in [(Q, MQ), (F17, MF17), (F163, MF163)]:
2409+
assert MK([[2, 5]]) / K(2) == MK([[K(2)/K(2), K(5)/K(2)]])
2410+
assert MK([[2, 5]]) / 2 == MK([[K(2)/K(2), K(5)/K(2)]])
2411+
# No other division is allowed
2412+
for (R, MR) in [(Z, MZ), (Q, MQ), (F17, MF17), (F163, MF163)]:
2413+
M = MR([[2, 5]])
2414+
for s in (2, R(2)):
2415+
assert raises(lambda: s / M, TypeError)
2416+
assert raises(lambda: M // s, TypeError)
2417+
assert raises(lambda: s // M, TypeError)
2418+
assert raises(lambda: M % s, TypeError)
2419+
assert raises(lambda: s % M, TypeError)
2420+
assert raises(lambda: divmod(s, M), TypeError)
2421+
assert raises(lambda: divmod(M, s), TypeError)
2422+
assert raises(lambda: M / M, TypeError)
2423+
assert raises(lambda: M // M, TypeError)
2424+
assert raises(lambda: M % M, TypeError)
2425+
assert raises(lambda: divmod(M, M), TypeError)
2426+
assert raises(lambda: M / 0, ZeroDivisionError)
2427+
assert raises(lambda: M / R(0), ZeroDivisionError)
2428+
2429+
22792430
def _all_polys():
22802431
return [
22812432
# (poly_type, scalar_type, is_field)
@@ -2436,16 +2587,18 @@ def setbad(obj, i, val):
24362587
assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError)
24372588
assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError)
24382589

2590+
# Exact/field scalar division
24392591
if is_field:
24402592
assert P([2, 2]) / 2 == P([1, 1])
24412593
assert P([1, 2]) / 2 == P([S(1)/2, 1])
2442-
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
24432594
else:
2444-
assert raises(lambda: P([2, 2]) / 2, TypeError)
2595+
assert P([2, 2]) / 2 == P([1, 1])
2596+
assert raises(lambda: P([1, 2]) / 2, DomainError)
2597+
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
24452598

2446-
assert raises(lambda: 1 / P([1, 1]), TypeError)
2447-
assert raises(lambda: P([1, 2, 1]) / P([1, 1]), TypeError)
2448-
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), TypeError)
2599+
assert P([1, 2, 1]) / P([1, 1]) == P([1, 1])
2600+
assert raises(lambda: 1 / P([1, 1]), DomainError)
2601+
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError)
24492602

24502603
assert P([1, 1]) ** 0 == P([1])
24512604
assert P([1, 1]) ** 1 == P([1, 1])
@@ -3023,7 +3176,9 @@ def test_all_tests():
30233176
test_fmpz_mod_poly,
30243177
test_fmpz_mod_mat,
30253178

3026-
test_arb,
3179+
test_division_scalar,
3180+
test_division_poly,
3181+
test_division_matrix,
30273182

30283183
test_polys,
30293184

@@ -3048,7 +3203,9 @@ def test_all_tests():
30483203
test_matrices_rref,
30493204
test_matrices_solve,
30503205

3206+
test_arb,
3207+
30513208
test_pickling,
3052-
test_all_tests,
30533209

3210+
test_all_tests,
30543211
]

src/flint/types/fmpq_poly.pyx

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ from flint.flintlib.arith cimport arith_bernoulli_polynomial
1616
from flint.flintlib.arith cimport arith_euler_polynomial
1717
from flint.flintlib.arith cimport arith_legendre_polynomial
1818

19+
from flint.utils.flint_exceptions import DomainError
20+
21+
1922
cdef any_as_fmpq_poly(obj):
2023
if typecheck(obj, fmpq_poly):
2124
return obj
@@ -295,23 +298,35 @@ cdef class fmpq_poly(flint_poly):
295298
return t
296299
return t._mod_(s)
297300

298-
@staticmethod
299-
def _div_(fmpq_poly s, t):
300-
cdef fmpq_poly r
301-
t = any_as_fmpq(t)
301+
def __truediv__(fmpq_poly s, t):
302+
cdef fmpq_poly res
303+
cdef fmpq_poly_t r
304+
t2 = any_as_fmpq(t)
305+
if t2 is NotImplemented:
306+
t2 = any_as_fmpq_poly(t)
307+
if t2 is NotImplemented:
308+
return t2
309+
if fmpq_poly_is_zero((<fmpq_poly>t2).val):
310+
raise ZeroDivisionError("fmpq_poly division by 0")
311+
res = fmpq_poly.__new__(fmpq_poly)
312+
fmpq_poly_init(r)
313+
fmpq_poly_divrem(res.val, r, (<fmpq_poly>s).val, (<fmpq_poly>t2).val)
314+
exact = fmpq_poly_is_zero(r)
315+
fmpq_poly_clear(r)
316+
if not exact:
317+
raise DomainError("fmpq_poly inexact division")
318+
else:
319+
if fmpq_is_zero((<fmpq>t2).val):
320+
raise ZeroDivisionError("fmpq_poly scalar division by 0")
321+
res = fmpq_poly.__new__(fmpq_poly)
322+
fmpq_poly_scalar_div_fmpq(res.val, (<fmpq_poly>s).val, (<fmpq>t2).val)
323+
return res
324+
325+
def __rtruediv__(fmpq_poly s, t):
326+
t = any_as_fmpq_poly(t)
302327
if t is NotImplemented:
303328
return t
304-
if fmpq_is_zero((<fmpq>t).val):
305-
raise ZeroDivisionError("fmpq_poly scalar division by 0")
306-
r = fmpq_poly.__new__(fmpq_poly)
307-
fmpq_poly_scalar_div_fmpq(r.val, (<fmpq_poly>s).val, (<fmpq>t).val)
308-
return r
309-
310-
def __div__(s, t):
311-
return fmpq_poly._div_(s, t)
312-
313-
def __truediv__(s, t):
314-
return fmpq_poly._div_(s, t)
329+
return t / s
315330

316331
def _divmod_(s, t):
317332
cdef fmpq_poly P, Q

0 commit comments

Comments
 (0)