@@ -68,8 +68,9 @@ def __init__(self,
68
68
Elias gamma code embedded into the range coder.
69
69
bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
70
70
Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
71
- laplace_tail_mass: Float. If non-zero, will augment the prior with a
72
- `NoisyLaplace` mixture component for training stability. (experimental)
71
+ laplace_tail_mass: Float, or a float-valued tf.Tensor. If positive,
72
+ will augment the prior with a `NoisyLaplace` mixture component for
73
+ training stability. (experimental)
73
74
"""
74
75
super ().__init__ ()
75
76
self ._prior = None # This will be set by subclasses, if appropriate.
@@ -83,14 +84,12 @@ def __init__(self,
83
84
if bottleneck_dtype is None :
84
85
bottleneck_dtype = tf .keras .backend .floatx ()
85
86
self ._bottleneck_dtype = tf .as_dtype (bottleneck_dtype )
86
- self ._laplace_tail_mass = float ( laplace_tail_mass )
87
+ self ._laplace_tail_mass = laplace_tail_mass
87
88
88
89
if self .coding_rank < 0 :
89
90
raise ValueError ("`coding_rank` must be at least 0." )
90
91
if not 0 < self .tail_mass < 1 :
91
92
raise ValueError ("`tail_mass` must be between 0 and 1." )
92
- if not 0 <= self .laplace_tail_mass < 1 :
93
- raise ValueError ("`laplace_tail_mass` must be between 0 and 1." )
94
93
95
94
def _check_compression (self ):
96
95
if not self .compression :
@@ -299,23 +298,41 @@ def loop_body(i, cdf):
299
298
def _log_prob(self, prior, bottleneck_perturbed):
  """Evaluates prior.log_prob(bottleneck + noise).

  If `laplace_tail_mass` is positive, the prior is augmented with a
  `NoisyLaplace` mixture component (loc=0, scale=1) for training
  stability; otherwise the plain prior log-probability is returned.
  `laplace_tail_mass` may be a Python float or a `tf.Tensor`; the tensor
  case keeps the positivity branch inside the graph via `tf.cond`.

  Args:
    prior: Prior distribution object; must expose `prob`, `log_prob`, and
      a `dtype` attribute (presumably a `tfp`-style distribution — the
      exact type is declared elsewhere in the file).
    bottleneck_perturbed: Tensor of bottleneck values with noise already
      added; cast to `prior.dtype` before evaluation.

  Returns:
    A tensor of elementwise log-probabilities of `bottleneck_perturbed`
    under the (possibly Laplace-augmented) prior.
  """
  bottleneck_perturbed = tf.cast(bottleneck_perturbed, prior.dtype)
  laplace_tail_mass = self.laplace_tail_mass

  def mixture_log_prob_fn():
    # Only reached when the tail mass is positive; it must also be < 1
    # for the mixture weights below to be valid.
    tf.debugging.assert_less(
        laplace_tail_mass,
        tf.constant(1.0, prior.dtype),
        message="`laplace_tail_mass` must be less than 1.")
    laplace_prior = uniform_noise.NoisyLaplace(
        loc=tf.constant(0, dtype=prior.dtype),
        scale=tf.constant(1, dtype=prior.dtype))
    probs = prior.prob(bottleneck_perturbed)
    # Convex combination: (1 - w) * prior + w * NoisyLaplace.
    probs = ((1 - laplace_tail_mass) * probs +
             laplace_tail_mass *
             laplace_prior.prob(bottleneck_perturbed))
    # Where the mixture probability underflows (< 1e-10), take the log of
    # the Laplace term alone, computed in log space for numerical
    # stability; elsewhere, log of the floored mixture probability.
    probs_too_small = probs < 1e-10
    probs_bounded = tf.maximum(probs, 1e-10)
    return tf.where(
        probs_too_small,
        tf.math.log(laplace_tail_mass) +
        laplace_prior.log_prob(bottleneck_perturbed),
        tf.math.log(probs_bounded))

  prior_log_prob_fn = lambda: prior.log_prob(bottleneck_perturbed)

  if isinstance(laplace_tail_mass, tf.Tensor):
    # Do all the computation in tf (graph mode compatible): the rebinding
    # here is deliberate — `mixture_log_prob_fn` closes over the name
    # `laplace_tail_mass`, so when invoked by `tf.cond` it sees the cast
    # tensor, not the original value.
    laplace_tail_mass = tf.cast(laplace_tail_mass, prior.dtype)
    use_laplace_tail_mass = tf.greater(laplace_tail_mass, 0.0)
    return tf.cond(use_laplace_tail_mass, mixture_log_prob_fn,
                   prior_log_prob_fn)
  else:
    # Python-float tail mass: the branch is resolved at trace time.
    if laplace_tail_mass > 0:
      return mixture_log_prob_fn()
    else:
      return prior_log_prob_fn()
319
336
320
337
@abc .abstractmethod
321
338
def get_config (self ):
@@ -340,7 +357,7 @@ def get_config(self):
340
357
tail_mass = self .tail_mass ,
341
358
cdf_shapes = (self .cdf .shape [0 ], self .cdf_offset .shape [0 ]),
342
359
bottleneck_dtype = self .bottleneck_dtype .name ,
343
- laplace_tail_mass = self .laplace_tail_mass ,
360
+ laplace_tail_mass = float ( self .laplace_tail_mass ) ,
344
361
)
345
362
346
363
def get_weights (self ):
0 commit comments