Skip to content

Commit b724663

Browse files
Googlercopybara-github
authored andcommitted
Add support for the use case of laplace_tail_mass being a tf.Tensor, e.g., as set by a schedule during model training.
PiperOrigin-RevId: 467336716 Change-Id: I415b914b53dc27d8d009a8bff142d3e89440dc8c
1 parent c2ae0e1 commit b724663

File tree

4 files changed

+68
-17
lines changed

4 files changed

+68
-17
lines changed

tensorflow_compression/python/entropy_models/continuous_base.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ def __init__(self,
6868
Elias gamma code embedded into the range coder.
6969
bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
7070
Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
71-
laplace_tail_mass: Float. If non-zero, will augment the prior with a
72-
`NoisyLaplace` mixture component for training stability. (experimental)
71+
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
72+
will augment the prior with a `NoisyLaplace` mixture component for
73+
training stability. (experimental)
7374
"""
7475
super().__init__()
7576
self._prior = None # This will be set by subclasses, if appropriate.
@@ -83,14 +84,12 @@ def __init__(self,
8384
if bottleneck_dtype is None:
8485
bottleneck_dtype = tf.keras.backend.floatx()
8586
self._bottleneck_dtype = tf.as_dtype(bottleneck_dtype)
86-
self._laplace_tail_mass = float(laplace_tail_mass)
87+
self._laplace_tail_mass = laplace_tail_mass
8788

8889
if self.coding_rank < 0:
8990
raise ValueError("`coding_rank` must be at least 0.")
9091
if not 0 < self.tail_mass < 1:
9192
raise ValueError("`tail_mass` must be between 0 and 1.")
92-
if not 0 <= self.laplace_tail_mass < 1:
93-
raise ValueError("`laplace_tail_mass` must be between 0 and 1.")
9493

9594
def _check_compression(self):
9695
if not self.compression:
@@ -299,23 +298,41 @@ def loop_body(i, cdf):
299298
def _log_prob(self, prior, bottleneck_perturbed):
300299
"""Evaluates prior.log_prob(bottleneck + noise)."""
301300
bottleneck_perturbed = tf.cast(bottleneck_perturbed, prior.dtype)
302-
if self.laplace_tail_mass:
301+
laplace_tail_mass = self.laplace_tail_mass
302+
303+
def mixture_log_prob_fn():
304+
tf.debugging.assert_less(
305+
laplace_tail_mass,
306+
tf.constant(1.0, prior.dtype),
307+
message="`laplace_tail_mass` must be less than 1.")
303308
laplace_prior = uniform_noise.NoisyLaplace(
304309
loc=tf.constant(0, dtype=prior.dtype),
305310
scale=tf.constant(1, dtype=prior.dtype))
306311
probs = prior.prob(bottleneck_perturbed)
307-
probs = ((1 - self.laplace_tail_mass) * probs +
308-
self.laplace_tail_mass *
312+
probs = ((1 - laplace_tail_mass) * probs +
313+
laplace_tail_mass *
309314
laplace_prior.prob(bottleneck_perturbed))
310315
probs_too_small = probs < 1e-10
311316
probs_bounded = tf.maximum(probs, 1e-10)
312317
return tf.where(
313318
probs_too_small,
314-
tf.math.log(self.laplace_tail_mass) +
319+
tf.math.log(laplace_tail_mass) +
315320
laplace_prior.log_prob(bottleneck_perturbed),
316321
tf.math.log(probs_bounded))
322+
323+
prior_log_prob_fn = lambda: prior.log_prob(bottleneck_perturbed)
324+
325+
if isinstance(laplace_tail_mass, tf.Tensor):
326+
# Do all the computation in tf (graph mode compatible).
327+
laplace_tail_mass = tf.cast(laplace_tail_mass, prior.dtype)
328+
use_laplace_tail_mass = tf.greater(laplace_tail_mass, 0.0)
329+
return tf.cond(use_laplace_tail_mass, mixture_log_prob_fn,
330+
prior_log_prob_fn)
317331
else:
318-
return prior.log_prob(bottleneck_perturbed)
332+
if laplace_tail_mass > 0:
333+
return mixture_log_prob_fn()
334+
else:
335+
return prior_log_prob_fn()
319336

320337
@abc.abstractmethod
321338
def get_config(self):
@@ -340,7 +357,7 @@ def get_config(self):
340357
tail_mass=self.tail_mass,
341358
cdf_shapes=(self.cdf.shape[0], self.cdf_offset.shape[0]),
342359
bottleneck_dtype=self.bottleneck_dtype.name,
343-
laplace_tail_mass=self.laplace_tail_mass,
360+
laplace_tail_mass=float(self.laplace_tail_mass),
344361
)
345362

346363
def get_weights(self):

tensorflow_compression/python/entropy_models/continuous_batched.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,9 @@ def __init__(self,
171171
use. If provided (not `None`), then `offset_heuristic` is ineffective.
172172
decode_sanity_check: Boolean. If `True`, an raises an error if the binary
173173
strings passed into `decompress` are not completely decoded.
174-
laplace_tail_mass: Float. If non-zero, will augment the prior with a
175-
`NoisyLaplace` mixture component for training stability. (experimental)
174+
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
175+
will augment the prior with a `NoisyLaplace` mixture component for
176+
training stability. (experimental)
176177
"""
177178
if (prior is None) == (prior_shape is None):
178179
raise ValueError("Either `prior` or `prior_shape` must be provided.")

tensorflow_compression/python/entropy_models/continuous_batched_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,37 @@ def test_small_bitcost_for_dirac_prior(self):
241241
# Quantization noise should be between -.5 and .5
242242
self.assertAllClose(x, x_decoded, rtol=0., atol=.5)
243243

244+
def test_laplace_tail_mass(self):
245+
noisy = uniform_noise.NoisyNormal(loc=0., scale=1.)
246+
em = ContinuousBatchedEntropyModel(noisy, 1, laplace_tail_mass=0.0)
247+
self.assertEqual(em.laplace_tail_mass, 0.0)
248+
em = ContinuousBatchedEntropyModel(noisy, 1,
249+
laplace_tail_mass=tf.constant(1e-3))
250+
self.assertEqual(em.laplace_tail_mass, tf.constant(1e-3))
251+
log_prob = em._log_prob(noisy, tf.constant(0.0))
252+
self.assertEqual(log_prob.dtype, tf.float32)
253+
254+
def test_laplace_tail_mass_works_in_tf_function(self):
255+
noisy = uniform_noise.NoisyNormal(loc=0., scale=1.)
256+
samples = noisy.base.sample([100])
257+
258+
# Since tf.function traces each function twice, and only allows variable
259+
# creation in the first call, we need to have a stateful object in which we
260+
# create the entropy model only the first time the function is called, and
261+
# store it for the second time.
262+
263+
class EntropyModel:
264+
265+
def log_prob(self, values):
266+
if not hasattr(self, "em"):
267+
self.em = ContinuousBatchedEntropyModel(
268+
noisy, 1, laplace_tail_mass=tf.constant(1e-3))
269+
return self.em._log_prob(noisy, values)
270+
271+
values_eager = EntropyModel().log_prob(samples)
272+
values_function = tf.function(EntropyModel().log_prob)(samples)
273+
self.assertAllEqual(values_eager, values_function)
274+
244275

245276
if __name__ == "__main__":
246277
tf.test.main()

tensorflow_compression/python/entropy_models/continuous_indexed.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,9 @@ def __init__(self,
189189
computations. Defaults to `tf.float32`.
190190
decode_sanity_check: Boolean. If `True`, an raises an error if the binary
191191
strings passed into `decompress` are not completely decoded.
192-
laplace_tail_mass: Float. If non-zero, will augment the prior with a
193-
`NoisyLaplace` mixture component for training stability. (experimental)
192+
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
193+
will augment the prior with a `NoisyLaplace` mixture component for
194+
training stability. (experimental)
194195
"""
195196
if not callable(prior_fn):
196197
raise TypeError("`prior_fn` must be a class or factory function.")
@@ -496,8 +497,9 @@ def __init__(self,
496497
Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
497498
prior_dtype: `tf.dtypes.DType`. Data type of prior and probability
498499
computations. Defaults to `tf.float32`.
499-
laplace_tail_mass: Float. If non-zero, will augment the prior with a
500-
`NoisyLaplace` mixture component for training stability. (experimental)
500+
laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
501+
will augment the prior with a `NoisyLaplace` mixture component for
502+
training stability. (experimental)
501503
"""
502504
num_scales = int(num_scales)
503505
super().__init__(

0 commit comments

Comments
 (0)