Skip to content

Fix numerics of TruncatedNormal #3077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 6, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 76 additions & 96 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
from . import transforms
from pymc3.util import get_variable_name
from .special import log_i0
from ..math import invlogit, logit
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise, SplineWrapper, i0e
from ..math import invlogit, logit, logdiffexp
from .dist_math import (
bound, logpow, gammaln, betaln, std_cdf, alltrue_elemwise,
SplineWrapper, i0e, normal_lcdf, normal_lccdf
)
from .distribution import Continuous, draw_values, generate_samples

__all__ = ['Uniform', 'Flat', 'HalfFlat', 'Normal', 'TruncatedNormal', 'Beta',
Expand Down Expand Up @@ -51,10 +54,21 @@ def __init__(self, transform=transforms.logodds, *args, **kwargs):
class BoundedContinuous(Continuous):
"""Base class for bounded continuous distributions"""

def __init__(self, transform='interval', *args, **kwargs):
def __init__(self, transform='auto', lower=None, upper=None,
*args, **kwargs):

lower = tt.as_tensor_variable(lower) if lower is not None else None
upper = tt.as_tensor_variable(upper) if upper is not None else None

if transform == 'interval':
transform = transforms.interval(self.lower, self.upper)
if transform == 'auto':
if lower is None and upper is None:
transform = None
elif lower is not None and upper is None:
transform = transforms.lowerbound(lower)
elif lower is None and upper is not None:
transform = transforms.upperbound(upper)
else:
transform = transforms.interval(lower, upper)

super(BoundedContinuous, self).__init__(
transform=transform, *args, **kwargs)
Expand Down Expand Up @@ -173,7 +187,8 @@ def __init__(self, lower=0, upper=1, *args, **kwargs):
self.mean = (upper + lower) / 2.
self.median = self.mean

super(Uniform, self).__init__(*args, **kwargs)
super(Uniform, self).__init__(
lower=lower, upper=upper, *args, **kwargs)

def random(self, point=None, size=None):
"""
Expand Down Expand Up @@ -446,7 +461,7 @@ def _repr_latex_(self, name=None, dist=None):
get_variable_name(sd))


class TruncatedNormal(Continuous):
class TruncatedNormal(BoundedContinuous):
R"""
Univariate truncated normal log-likelihood.

Expand Down Expand Up @@ -521,34 +536,29 @@ class TruncatedNormal(Continuous):
"""

def __init__(self, mu=0, sd=None, tau=None, lower=None, upper=None,
transform='infer', *args, **kwargs):
transform='auto', *args, **kwargs):
tau, sd = get_tau_sd(tau=tau, sd=sd)
self.sd = tt.as_tensor_variable(sd)
self.tau = tt.as_tensor_variable(tau)
self.lower = tt.as_tensor_variable(lower) if lower is not None else lower
self.upper = tt.as_tensor_variable(upper) if upper is not None else upper
self.mu = tt.as_tensor_variable(mu)
self.mu = tt.as_tensor_variable(mu)

# Calculate mean
pdf_a, pdf_b, cdf_a, cdf_b = self._get_boundary_parameters()
z = cdf_b - cdf_a
self.mean = self.mu + (pdf_a+pdf_b) / z * self.sd
if self.lower is None and self.upper is None:
self._defaultval = mu
elif self.lower is None and self.upper is not None:
self._defaultval = self.upper - 1.
elif self.lower is not None and self.upper is None:
self._defaultval = self.lower + 1.
else:
self._defaultval = (self.lower + self.upper) / 2

assert_negative_support(sd, 'sd', 'TruncatedNormal')
assert_negative_support(tau, 'tau', 'TruncatedNormal')

if transform == 'infer':
if lower is None and upper is None:
transform = None
elif lower is not None and upper is not None:
transform = transforms.interval(lower, upper)
elif upper is not None:
transform = transforms.upperbound(upper)
else:
transform = transforms.lowerbound(lower)

super(TruncatedNormal, self).__init__(
transform=transform, *args, **kwargs)
defaults=('_defaultval',), transform=transform,
lower=lower, upper=upper, *args, **kwargs)

def random(self, point=None, size=None):
"""
Expand Down Expand Up @@ -592,89 +602,57 @@ def logp(self, value):
-------
TensorVariable
"""
sd = self.sd
tau = self.tau
mu = self.mu
a = self.lower
b = self.upper

# In case either a or b are not specified, normalization terms simplify to 1.0 and 0.0
# https://en.wikipedia.org/wiki/Truncated_normal_distribution
norm_left, norm_right = 1.0, 0.0

# Define normalization
if b is not None:
norm_left = self._cdf((b - mu) / sd)

if a is not None:
norm_right = self._cdf((a - mu) / sd)

f = self._pdf((value - mu) / sd) / sd / (norm_left - norm_right)
sd = self.sd

return bound(tt.log(f), value >= a, value <= b, sd > 0)
norm = self._normalization()
logp = Normal.dist(mu=mu, sd=sd).logp(value) - norm

def _cdf(self, value):
"""
Calculate cdf of standard normal distribution
bounds = [sd > 0]
if self.lower is not None:
bounds.append(value >= self.lower)
if self.upper is not None:
bounds.append(value <= self.upper)
return bound(logp, *bounds)

Parameters
----------
value : numeric
Value(s) for which log-probability is calculated. If the log probabilities for multiple
values are desired the values must be provided in a numpy array or theano tensor
def _normalization(self):
mu, sd = self.mu, self.sd

Returns
-------
TensorVariable
"""
return 0.5 * (1.0 + tt.erf(value / tt.sqrt(2)))
if self.lower is None and self.upper is None:
return 0.

def _pdf(self, value):
"""
Calculate pdf of standard normal distribution
if self.lower is not None and self.upper is not None:
lcdf_a = normal_lcdf(mu, sd, self.lower)
lcdf_b = normal_lcdf(mu, sd, self.upper)
lsf_a = normal_lccdf(mu, sd, self.lower)
lsf_b = normal_lccdf(mu, sd, self.upper)

Parameters
----------
value : numeric
Value(s) for which log-probability is calculated. If the log probabilities for multiple
values are desired the values must be provided in a numpy array or theano tensor
return tt.switch(
self.lower > 0,
logdiffexp(lsf_a, lsf_b),
logdiffexp(lcdf_b, lcdf_a),
)

Returns
-------
TensorVariable
"""
return 1.0 / tt.sqrt(2 * np.pi) * tt.exp(-0.5 * (value ** 2))
if self.lower is not None:
return normal_lccdf(mu, sd, self.lower)
else:
return normal_lcdf(mu, sd, self.upper)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
sd = dist.sd
mu = dist.mu
a = dist.a
b = dist.b
name = r'\text{%s}' % name
return r'${} \sim \text{{TruncatedNormal}}(\mathit{{mu}}={},~\mathit{{sd}}={},a={},b={})$'.format(name,
get_variable_name(mu),
get_variable_name(sd),
get_variable_name(a),
get_variable_name(b))

def _get_boundary_parameters(self):
"""
Calculate values of cdf and pdf at boundary points a and b

Returns
-------
pdf(a), pdf(b), cdf(a), cdf(b) if a,b defined, otherwise 0.0, 0.0, 0.0, 1.0
"""
# pdf = 0 at +-inf
pdf_a = self._pdf(self.lower) if not self.lower is None else 0.0
pdf_b = self._pdf(self.upper) if not self.upper is None else 0.0

# b-> inf, cdf(b) = 1.0, a->-inf, cdf(a) = 0
cdf_a = self._cdf(self.lower) if not self.lower is None else 0.0
cdf_b = self._cdf(self.upper) if not self.upper is None else 1.0
return pdf_a, pdf_b, cdf_a, cdf_b
return (
r'${} \sim \text{{TruncatedNormal}}('
'\mathit{{mu}}={},~\mathit{{sd}}={},a={},b={})$'
.format(
name,
get_variable_name(self.mu),
get_variable_name(self.sd),
get_variable_name(self.lower),
get_variable_name(self.upper),
)
)


class HalfNormal(PositiveContinuous):
Expand Down Expand Up @@ -3054,7 +3032,8 @@ def __init__(self, lower=0, upper=1, c=0.5,
self.lower = lower = tt.as_tensor_variable(lower)
self.upper = upper = tt.as_tensor_variable(upper)

super(Triangular, self).__init__(*args, **kwargs)
super(Triangular, self).__init__(lower=lower, upper=upper,
*args, **kwargs)

def random(self, point=None, size=None):
"""
Expand Down Expand Up @@ -3553,7 +3532,8 @@ def __init__(self, x_points, pdf_points, *args, **kwargs):
self.lower = lower = tt.as_tensor_variable(x_points[0])
self.upper = upper = tt.as_tensor_variable(x_points[-1])

super(Interpolated, self).__init__(*args, **kwargs)
super(Interpolated, self).__init__(lower=lower, upper=upper,
*args, **kwargs)

interp = InterpolatedUnivariateSpline(
x_points, pdf_points, k=1, ext='zeros')
Expand Down
19 changes: 19 additions & 0 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,25 @@ def std_cdf(x):
return .5 + .5 * tt.erf(x / tt.sqrt(2.))


def normal_lcdf(mu, sigma, x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""Compute the log of the cumulative distribution function of the normal."""
z = (x - mu) / sigma
return tt.switch(
tt.lt(z, -1.0),
tt.log(tt.erfcx(-z / tt.sqrt(2.)) / 2.) - tt.sqr(z) / 2.,
tt.log1p(-tt.erfc(z / tt.sqrt(2.)) / 2.)
)


def normal_lccdf(mu, sigma, x):
    """Compute the log of the complementary cumulative distribution
    function (log survival function) of the normal.

    Parameters
    ----------
    mu : numeric or tensor
        Value subtracted from ``x`` before standardizing.
    sigma : numeric or tensor
        Value the centered ``x`` is divided by.
    x : numeric or tensor
        Point(s) at which the log complementary CDF is evaluated.

    Returns
    -------
    TensorVariable
        ``log(1 - Phi(z))`` with ``z = (x - mu) / sigma``, selected
        elementwise between two algebraically equivalent expressions.
    """
    z = (x - mu) / sigma
    scaled = z / tt.sqrt(2.)

    # For z > 1: 1 - Phi(z) = erfc(z / sqrt(2)) / 2, written with the scaled
    # complementary error function erfcx(t) = exp(t**2) * erfc(t), so the
    # exp(-z**2 / 2) factor is applied in log space as -z**2 / 2.
    upper_tail = tt.log(tt.erfcx(scaled) / 2.) - tt.sqr(z) / 2.

    # Otherwise: log(1 - Phi(z)) with Phi(z) = erfc(-z / sqrt(2)) / 2.
    remainder = tt.log1p(-tt.erfc(-scaled) / 2.)

    return tt.switch(tt.gt(z, 1.0), upper_tail, remainder)


def sd2rho(sd):
"""
`sd -> rho` theano converter
Expand Down
5 changes: 4 additions & 1 deletion pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from theano import function
import theano
from ..memoize import memoize
from ..model import Model, get_named_nodes_and_relations, FreeRV, ObservedRV, MultiObservedRV
from ..model import (
Model, get_named_nodes_and_relations, FreeRV,
ObservedRV, MultiObservedRV
)
from ..vartypes import string_types

__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
Expand Down
9 changes: 7 additions & 2 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,13 @@ def logsumexp(x, axis=None):
def logaddexp(a, b):
diff = b - a
return tt.switch(diff > 0,
b + tt.log1p(tt.exp(-diff)),
a + tt.log1p(tt.exp(diff)))
b + tt.log1p(tt.exp(-diff)),
a + tt.log1p(tt.exp(diff)))


def logdiffexp(a, b):
    """Return log(exp(a) - exp(b)).

    The result is formed as ``a + log1mexp(a - b)`` rather than by
    exponentiating ``a`` and ``b`` directly.
    NOTE(review): presumably ``log1mexp(x)`` evaluates ``log(1 - exp(-x))``
    and callers pass ``a >= b`` -- confirm against the helper's definition.
    """
    gap = a - b
    return a + log1mexp(gap)


def invlogit(x, eps=sys.float_info.epsilon):
Expand Down