Skip to content

Commit 6791f16

Browse files
committed
Fix numerics of TruncatedNormal
1 parent d1d2aa2 commit 6791f16

File tree

4 files changed

+106
-99
lines changed

4 files changed

+106
-99
lines changed

pymc3/distributions/continuous.py

Lines changed: 76 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@
2020
from . import transforms
2121
from pymc3.util import get_variable_name
2222
from .special import log_i0
23-
from ..math import invlogit, logit
24-
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise, SplineWrapper, i0e
23+
from ..math import invlogit, logit, logdiffexp
24+
from .dist_math import (
25+
bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise,
26+
SplineWrapper, i0e, normal_lcdf, normal_lccdf
27+
)
2528
from .distribution import Continuous, draw_values, generate_samples
2629

2730
__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'TruncatedNormal', 'Beta',
@@ -51,10 +54,21 @@ def __init__(self, transform=transforms.logodds, *args, **kwargs):
5154
class BoundedContinuous(Continuous):
5255
"""Base class for bounded continuous distributions"""
5356

54-
def __init__(self, transform='interval', *args, **kwargs):
57+
def __init__(self, transform='auto', lower=None, upper=None,
58+
*args, **kwargs):
59+
60+
lower = tt.as_tensor_variable(lower) if lower is not None else None
61+
upper = tt.as_tensor_variable(upper) if upper is not None else None
5562

56-
if transform == 'interval':
57-
transform = transforms.interval(self.lower, self.upper)
63+
if transform == 'auto':
64+
if lower is None and upper is None:
65+
transform = None
66+
elif lower is not None and upper is None:
67+
transform = transforms.lowerbound(lower)
68+
elif lower is None and upper is not None:
69+
transform = transforms.upperbound(upper)
70+
else:
71+
transform = transforms.interval(lower, upper)
5872

5973
super(BoundedContinuous, self).__init__(
6074
transform=transform, *args, **kwargs)
@@ -173,7 +187,8 @@ def __init__(self, lower=0, upper=1, *args, **kwargs):
173187
self.mean = (upper + lower) / 2.
174188
self.median = self.mean
175189

176-
super(Uniform, self).__init__(*args, **kwargs)
190+
super(Uniform, self).__init__(
191+
lower=lower, upper=upper, *args, **kwargs)
177192

178193
def random(self, point=None, size=None):
179194
"""
@@ -446,7 +461,7 @@ def _repr_latex_(self, name=None, dist=None):
446461
get_variable_name(sd))
447462

448463

449-
class TruncatedNormal(Continuous):
464+
class TruncatedNormal(BoundedContinuous):
450465
R"""
451466
Univariate truncated normal log-likelihood.
452467
@@ -521,34 +536,29 @@ class TruncatedNormal(Continuous):
521536
"""
522537

523538
def __init__(self, mu=0, sd=None, tau=None, lower=None, upper=None,
524-
transform='infer', *args, **kwargs):
539+
transform='auto', *args, **kwargs):
525540
tau, sd = get_tau_sd(tau=tau, sd=sd)
526541
self.sd = tt.as_tensor_variable(sd)
527542
self.tau = tt.as_tensor_variable(tau)
528543
self.lower = tt.as_tensor_variable(lower) if lower is not None else lower
529544
self.upper = tt.as_tensor_variable(upper) if upper is not None else upper
530-
self.mu = tt.as_tensor_variable(mu)
545+
self.mu = tt.as_tensor_variable(mu)
531546

532-
# Calculate mean
533-
pdf_a, pdf_b, cdf_a, cdf_b = self._get_boundary_parameters()
534-
z = cdf_b - cdf_a
535-
self.mean = self.mu + (pdf_a+pdf_b) / z * self.sd
547+
if self.lower is None and self.upper is None:
548+
self._defaultval = mu
549+
elif self.lower is None and self.upper is not None:
550+
self._defaultval = self.upper - 1.
551+
elif self.lower is not None and self.upper is None:
552+
self._defaultval = self.lower + 1.
553+
else:
554+
self._defaultval = (self.lower + self.upper) / 2
536555

537556
assert_negative_support(sd, 'sd', 'TruncatedNormal')
538557
assert_negative_support(tau, 'tau', 'TruncatedNormal')
539558

540-
if transform == 'infer':
541-
if lower is None and upper is None:
542-
transform = None
543-
elif lower is not None and upper is not None:
544-
transform = transforms.interval(lower, upper)
545-
elif upper is not None:
546-
transform = transforms.upperbound(upper)
547-
else:
548-
transform = transforms.lowerbound(lower)
549-
550559
super(TruncatedNormal, self).__init__(
551-
transform=transform, *args, **kwargs)
560+
defaults=('_defaultval',), transform=transform,
561+
lower=lower, upper=upper, *args, **kwargs)
552562

553563
def random(self, point=None, size=None):
554564
"""
@@ -592,89 +602,57 @@ def logp(self, value):
592602
-------
593603
TensorVariable
594604
"""
595-
sd = self.sd
596-
tau = self.tau
597605
mu = self.mu
598-
a = self.lower
599-
b = self.upper
600-
601-
# If either a or b is not specified, its normalization term keeps its default (1.0 for b, 0.0 for a)
602-
# https://en.wikipedia.org/wiki/Truncated_normal_distribution
603-
norm_left, norm_right = 1.0, 0.0
604-
605-
# Define normalization
606-
if b is not None:
607-
norm_left = self._cdf((b - mu) / sd)
608-
609-
if a is not None:
610-
norm_right = self._cdf((a - mu) / sd)
611-
612-
f = self._pdf((value - mu) / sd) / sd / (norm_left - norm_right)
606+
sd = self.sd
613607

614-
return bound(tt.log(f), value >= a, value <= b, sd > 0)
608+
norm = self._normalization()
609+
logp = Normal.dist(mu=mu, sd=sd).logp(value) - norm
615610

616-
def _cdf(self, value):
617-
"""
618-
Calculate cdf of standard normal distribution
611+
bounds = [sd > 0]
612+
if self.lower is not None:
613+
bounds.append(value >= self.lower)
614+
if self.upper is not None:
615+
bounds.append(value <= self.upper)
616+
return bound(logp, *bounds)
619617

620-
Parameters
621-
----------
622-
value : numeric
623-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
624-
values are desired the values must be provided in a numpy array or theano tensor
618+
def _normalization(self):
619+
mu, sd = self.mu, self.sd
625620

626-
Returns
627-
-------
628-
TensorVariable
629-
"""
630-
return 0.5 * (1.0 + tt.erf(value / tt.sqrt(2)))
621+
if self.lower is None and self.upper is None:
622+
return 0.
631623

632-
def _pdf(self, value):
633-
"""
634-
Calculate pdf of standard normal distribution
624+
if self.lower is not None and self.upper is not None:
625+
lcdf_a = normal_lcdf(mu, sd, self.lower)
626+
lcdf_b = normal_lcdf(mu, sd, self.upper)
627+
lsf_a = normal_lccdf(mu, sd, self.lower)
628+
lsf_b = normal_lccdf(mu, sd, self.upper)
635629

636-
Parameters
637-
----------
638-
value : numeric
639-
Value(s) for which log-probability is calculated. If the log probabilities for multiple
640-
values are desired the values must be provided in a numpy array or theano tensor
630+
return tt.switch(
631+
self.lower > 0,
632+
logdiffexp(lsf_a, lsf_b),
633+
logdiffexp(lcdf_b, lcdf_a),
634+
)
641635

642-
Returns
643-
-------
644-
TensorVariable
645-
"""
646-
return 1.0 / tt.sqrt(2 * np.pi) * tt.exp(-0.5 * (value ** 2))
636+
if self.lower is not None:
637+
return normal_lccdf(mu, sd, self.lower)
638+
else:
639+
return normal_lcdf(mu, sd, self.upper)
647640

648641
def _repr_latex_(self, name=None, dist=None):
649642
if dist is None:
650643
dist = self
651-
sd = dist.sd
652-
mu = dist.mu
653-
a = dist.a
654-
b = dist.b
655644
name = r'\text{%s}' % name
656-
return r'${} \sim \text{{TruncatedNormal}}(\mathit{{mu}}={},~\mathit{{sd}}={},a={},b={})$'.format(name,
657-
get_variable_name(mu),
658-
get_variable_name(sd),
659-
get_variable_name(a),
660-
get_variable_name(b))
661-
662-
def _get_boundary_parameters(self):
663-
"""
664-
Calculate values of cdf and pdf at boundary points a and b
665-
666-
Returns
667-
-------
668-
pdf(a), pdf(b), cdf(a), cdf(b) if a,b defined, otherwise 0.0, 0.0, 0.0, 1.0
669-
"""
670-
# pdf = 0 at +-inf
671-
pdf_a = self._pdf(self.lower) if not self.lower is None else 0.0
672-
pdf_b = self._pdf(self.upper) if not self.upper is None else 0.0
673-
674-
# b-> inf, cdf(b) = 1.0, a->-inf, cdf(a) = 0
675-
cdf_a = self._cdf(self.lower) if not self.lower is None else 0.0
676-
cdf_b = self._cdf(self.upper) if not self.upper is None else 1.0
677-
return pdf_a, pdf_b, cdf_a, cdf_b
645+
return (
646+
r'${} \sim \text{{TruncatedNormal}}('
647+
'\mathit{{mu}}={},~\mathit{{sd}}={},a={},b={})$'
648+
.format(
649+
name,
650+
get_variable_name(self.mu),
651+
get_variable_name(self.sd),
652+
get_variable_name(self.lower),
653+
get_variable_name(self.upper),
654+
)
655+
)
678656

679657

680658
class HalfNormal(PositiveContinuous):
@@ -3054,7 +3032,8 @@ def __init__(self, lower=0, upper=1, c=0.5,
30543032
self.lower = lower = tt.as_tensor_variable(lower)
30553033
self.upper = upper = tt.as_tensor_variable(upper)
30563034

3057-
super(Triangular, self).__init__(*args, **kwargs)
3035+
super(Triangular, self).__init__(lower=lower, upper=upper,
3036+
*args, **kwargs)
30583037

30593038
def random(self, point=None, size=None):
30603039
"""
@@ -3553,7 +3532,8 @@ def __init__(self, x_points, pdf_points, *args, **kwargs):
35533532
self.lower = lower = tt.as_tensor_variable(x_points[0])
35543533
self.upper = upper = tt.as_tensor_variable(x_points[-1])
35553534

3556-
super(Interpolated, self).__init__(*args, **kwargs)
3535+
super(Interpolated, self).__init__(lower=lower, upper=upper,
3536+
*args, **kwargs)
35573537

35583538
interp = InterpolatedUnivariateSpline(
35593539
x_points, pdf_points, k=1, ext='zeros')

pymc3/distributions/dist_math.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,25 @@ def std_cdf(x):
8787
return .5 + .5 * tt.erf(x / tt.sqrt(2.))
8888

8989

90+
def normal_lcdf(mu, sigma, x):
91+
"""Compute the log of the cumulative density function of the normal."""
92+
z = (x - mu) / sigma
93+
return tt.switch(
94+
tt.lt(z, -1.0),
95+
tt.log(tt.erfcx(-z / tt.sqrt(2.)) / 2.) - tt.sqr(z) / 2.,
96+
tt.log1p(-tt.erfc(z / tt.sqrt(2.)) / 2.)
97+
)
98+
99+
100+
def normal_lccdf(mu, sigma, x):
101+
z = (x - mu) / sigma
102+
return tt.switch(
103+
tt.gt(z, 1.0),
104+
tt.log(tt.erfcx(z / tt.sqrt(2.)) / 2.) - tt.sqr(z) / 2.,
105+
tt.log1p(-tt.erfc(-z / tt.sqrt(2.)) / 2.)
106+
)
107+
108+
90109
def sd2rho(sd):
91110
"""
92111
`sd -> rho` theano converter

pymc3/distributions/distribution.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from theano import function
77
import theano
88
from ..memoize import memoize
9-
from ..model import Model, get_named_nodes_and_relations, FreeRV, ObservedRV, MultiObservedRV
9+
from ..model import (
10+
Model, get_named_nodes_and_relations, FreeRV,
11+
ObservedRV, MultiObservedRV
12+
)
1013
from ..vartypes import string_types
1114

1215
__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',

pymc3/math.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,13 @@ def logsumexp(x, axis=None):
114114
def logaddexp(a, b):
115115
diff = b - a
116116
return tt.switch(diff > 0,
117-
b + tt.log1p(tt.exp(-diff)),
118-
a + tt.log1p(tt.exp(diff)))
117+
b + tt.log1p(tt.exp(-diff)),
118+
a + tt.log1p(tt.exp(diff)))
119+
120+
121+
def logdiffexp(a, b):
122+
"""log(exp(a) - exp(b))"""
123+
return a + log1mexp(a - b)
119124

120125

121126
def invlogit(x, eps=sys.float_info.epsilon):

0 commit comments

Comments
 (0)