Skip to content

Commit bb25d9d

Browse files
committed
Change memory management for precomutations
1 parent e866add commit bb25d9d

File tree

2 files changed

+46
-42
lines changed

2 files changed

+46
-42
lines changed

src/flint/types/fmpz_mod.pxd

+2-4
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@ from flint.flintlib.fmpz_mod cimport (
88

99
cdef class fmpz_mod_ctx:
1010
cdef fmpz_mod_ctx_t val
11-
cdef fmpz_mod_discrete_log_pohlig_hellman_t L
12-
cdef bint _dlog_precomputed
11+
cdef fmpz_mod_discrete_log_pohlig_hellman_t *L
1312
cdef any_as_fmpz_mod(self, obj)
1413
cdef _precompute_dlog_prime(self)
1514

1615
cdef class fmpz_mod(flint_scalar):
1716
cdef fmpz_mod_ctx ctx
1817
cdef fmpz_t val
19-
cdef bint base_dlog_precomputed
20-
cdef fmpz_t x_g
18+
cdef fmpz_t *x_g

src/flint/types/fmpz_mod.pyx

+44-38
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ from flint.types.fmpz cimport(
2222
any_as_fmpz,
2323
fmpz_get_intlong
2424
)
25+
cimport cython
26+
cimport libc.stdlib
2527

2628
cdef class fmpz_mod_ctx:
2729
r"""
@@ -36,11 +38,13 @@ cdef class fmpz_mod_ctx:
3638
cdef fmpz one = fmpz.__new__(fmpz)
3739
fmpz_one(one.val)
3840
fmpz_mod_ctx_init(self.val, one.val)
39-
fmpz_mod_discrete_log_pohlig_hellman_init(self.L)
41+
self.L = NULL
42+
4043

4144
def __dealloc__(self):
4245
fmpz_mod_ctx_clear(self.val)
43-
fmpz_mod_discrete_log_pohlig_hellman_clear(self.L)
46+
if self.L:
47+
fmpz_mod_discrete_log_pohlig_hellman_clear(self.L[0])
4448

4549
def __init__(self, mod):
4650
# Ensure modulus is fmpz type
@@ -58,10 +62,6 @@ cdef class fmpz_mod_ctx:
5862
# Set the modulus
5963
fmpz_mod_ctx_set_modulus(self.val, (<fmpz>mod).val)
6064

61-
# Store whether the pohlig-hellman precomputation has
62-
# been performed
63-
self._dlog_precomputed = 0
64-
6565
def modulus(self):
6666
"""
6767
Return the modulus from the context as an fmpz
@@ -81,10 +81,13 @@ cdef class fmpz_mod_ctx:
8181
Initalise the dlog data, all discrete logs are solved with an
8282
internally chosen base `y`
8383
"""
84+
self.L = <fmpz_mod_discrete_log_pohlig_hellman_t *>libc.stdlib.malloc(
85+
cython.sizeof(fmpz_mod_discrete_log_pohlig_hellman_struct)
86+
)
87+
fmpz_mod_discrete_log_pohlig_hellman_init(self.L[0])
8488
fmpz_mod_discrete_log_pohlig_hellman_precompute_prime(
85-
self.L, self.val.n
89+
self.L[0], self.val.n
8690
)
87-
self._dlog_precomputed = 1
8891

8992
cdef any_as_fmpz_mod(self, obj):
9093
# If `obj` is an `fmpz_mod`, just check moduli
@@ -158,11 +161,12 @@ cdef class fmpz_mod(flint_scalar):
158161

159162
def __cinit__(self):
160163
fmpz_init(self.val)
161-
fmpz_init(self.x_g)
164+
self.x_g = NULL
162165

163166
def __dealloc__(self):
164167
fmpz_clear(self.val)
165-
fmpz_clear(self.x_g)
168+
if self.x_g:
169+
fmpz_clear(self.x_g[0])
166170

167171
def __init__(self, val, ctx):
168172
if not typecheck(ctx, fmpz_mod_ctx):
@@ -186,9 +190,6 @@ cdef class fmpz_mod(flint_scalar):
186190
raise NotImplementedError
187191
fmpz_mod_set_fmpz(self.val, (<fmpz>val).val, self.ctx.val)
188192

189-
# Bool to say whether x_g has been computed before
190-
self.base_dlog_precomputed = 0
191-
192193
def is_zero(self):
193194
"""
194195
Return whether an element is equal to zero
@@ -258,10 +259,6 @@ cdef class fmpz_mod(flint_scalar):
258259
259260
NOTE: Requires that the context modulus is prime.
260261
261-
TODO: This could instead be initalised as a class from a
262-
given base and the precomputations could be stored to allow
263-
faster computations for many discrete logs with the same base.
264-
265262
>>> F = fmpz_mod_ctx(163)
266263
>>> g = F(2)
267264
>>> x = 123
@@ -286,45 +283,54 @@ cdef class fmpz_mod(flint_scalar):
286283
if a is NotImplemented:
287284
raise TypeError
288285

286+
# First, Ensure that self.ctx.L has performed precomputations
287+
# This generates a `y` which is a primative root, and used as
288+
# the base in `fmpz_mod_discrete_log_pohlig_hellman_run`
289+
if not self.ctx.L:
290+
self.ctx._precompute_dlog_prime()
291+
289292
# Solve the discrete log for the chosen base and target
290293
# g = y^x_g and a = y^x_a
291294
# We want to find x such that a = g^x =>
292295
# (y^x_a) = (y^x_g)^x => x = (x_a / x_g) mod (p-1)
293-
cdef fmpz_t x_a
294-
cdef fmpz_t x_g
295-
fmpz_init(x_a)
296-
fmpz_init(x_g)
297296

298-
# Ensure that self.ctx.L has performed precomputations
299-
if not self.ctx._dlog_precomputed:
300-
self.ctx._precompute_dlog_prime()
297+
# For repeated calls to discrete_log, it's more efficient to
298+
# store x_g rather than keep computing it
299+
if not self.x_g:
300+
self.x_g = <fmpz_t *>libc.stdlib.malloc(
301+
cython.sizeof(fmpz_t)
302+
)
303+
fmpz_mod_discrete_log_pohlig_hellman_run(
304+
self.x_g[0], self.ctx.L[0], self.val
305+
)
301306

302-
if not self.base_dlog_precomputed:
303-
fmpz_mod_discrete_log_pohlig_hellman_run(self.x_g, self.ctx.L, self.val)
304-
self.base_dlog_precomputed = 1
305-
fmpz_mod_discrete_log_pohlig_hellman_run(x_a, self.ctx.L, (<fmpz_mod>a).val)
307+
# Then we need to compute x_a which will be different for each call
308+
cdef fmpz_t x_a
309+
fmpz_init(x_a)
310+
fmpz_mod_discrete_log_pohlig_hellman_run(
311+
x_a, self.ctx.L[0], (<fmpz_mod>a).val
312+
)
306313

307314
# If g is not a primative root, then x_g and pm1 will share
308315
# a common factor. We can use this to compute the order of
309316
# g.
310-
cdef fmpz_t g, g_order
317+
cdef fmpz_t g, g_order, x_g
311318
fmpz_init(g)
312319
fmpz_init(g_order)
320+
fmpz_init(x_g)
313321

314-
fmpz_gcd(g, self.x_g, self.ctx.L.pm1)
322+
fmpz_gcd(g, self.x_g[0], self.ctx.L[0].pm1)
315323
if not fmpz_is_one(g):
316-
fmpz_divexact(x_g, self.x_g, g)
324+
fmpz_divexact(x_g, self.x_g[0], g)
317325
fmpz_divexact(x_a, x_a, g)
318-
fmpz_divexact(g_order, self.ctx.L.pm1, g)
326+
fmpz_divexact(g_order, self.ctx.L[0].pm1, g)
319327
else:
320-
fmpz_set(g_order, self.ctx.L.pm1)
321-
fmpz_set(x_g, self.x_g)
328+
fmpz_set(g_order, self.ctx.L[0].pm1)
329+
fmpz_set(x_g, self.x_g[0])
322330

323-
324-
# Finally, compute output exponent
331+
# Finally, compute output exponent by computing
332+
# (x_a / x_g) mod g_order
325333
cdef fmpz x = fmpz.__new__(fmpz)
326-
327-
# Compute (x_a / x_g) mod g_order
328334
fmpz_invmod(x.val, x_g, g_order)
329335
fmpz_mul(x.val, x.val, x_a)
330336
fmpz_type_mod(x.val, x.val, g_order)

0 commit comments

Comments
 (0)