Skip to content

Commit e8997ba

Browse files
committed
Improve truncnormal test and fix example
1 parent 6791f16 commit e8997ba

File tree

2 files changed

+35
-44
lines changed

2 files changed

+35
-44
lines changed

pymc3/examples/censored_data.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import numpy as np
2-
import matplotlib.pyplot as plt
3-
import pymc3 as pm
4-
import theano.tensor as tt
51
'''
62
Data can be left, right, or interval censored.
73
@@ -31,26 +27,11 @@
3127
3228
To establish a baseline they compare to an uncensored model of uncensored data.
3329
'''
30+
import numpy as np
31+
import matplotlib.pyplot as plt
32+
import pymc3 as pm
3433

35-
36-
# Helper functions
37-
def normal_lcdf(mu, sigma, x):
38-
z = (x - mu) / sigma
39-
return tt.switch(
40-
tt.lt(z, -1.0),
41-
tt.log(tt.erfcx(-z / tt.sqrt(2.)) / 2.) - tt.sqr(z) / 2.,
42-
tt.log1p(-tt.erfc(z / tt.sqrt(2.)) / 2.)
43-
)
44-
45-
46-
def normal_lccdf(mu, sigma, x):
47-
z = (x - mu) / sigma
48-
return tt.switch(
49-
tt.gt(z, 1.0),
50-
tt.log(tt.erfcx(z / tt.sqrt(2.)) / 2.) - tt.sqr(z) / 2.,
51-
tt.log1p(-tt.erfc(-z / tt.sqrt(2.)) / 2.)
52-
)
53-
34+
from pymc3.distributions.dist_math import normal_lcdf, normal_lccdf
5435

5536
# Produce normally distributed samples
5637
np.random.seed(123)
@@ -103,6 +84,7 @@ def censored_right_likelihood(mu, sigma, n_right_censored, upper_bound):
10384
def censored_left_likelihood(mu, sigma, n_left_censored, lower_bound):
10485
return n_left_censored * normal_lcdf(mu, sigma, lower_bound)
10586

87+
10688
with pm.Model() as unimputed_censored_model:
10789
mu = pm.Normal('mu', mu=0., sd=(high - low) / 2.)
10890
sigma = pm.HalfNormal('sigma', sd=(high - low) / 2.)
@@ -115,7 +97,7 @@ def censored_left_likelihood(mu, sigma, n_left_censored, lower_bound):
11597
right_censored = pm.Potential(
11698
'right_censored',
11799
censored_right_likelihood(mu, sigma, n_right_censored, high)
118-
)
100+
)
119101
left_censored = pm.Potential(
120102
'left_censored',
121103
censored_left_likelihood(mu, sigma, n_left_censored, low)
@@ -144,5 +126,6 @@ def run(n=1500):
144126
pm.plot_posterior(trace[-1000:], varnames=['mu', 'sigma'])
145127
plt.show()
146128

129+
147130
if __name__ == '__main__':
148131
run()

pymc3/tests/test_distributions.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,19 @@
77
from ..vartypes import continuous_types
88
from ..model import Model, Point, Potential, Deterministic
99
from ..blocking import DictToVarBijection, DictToArrayBijection, ArrayOrdering
10-
from ..distributions import (DensityDist, Categorical, Multinomial, VonMises, Dirichlet,
11-
MvStudentT, MvNormal, MatrixNormal, ZeroInflatedPoisson,
12-
ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta,
13-
BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto,
14-
InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
15-
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal, TruncatedNormal,
16-
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
17-
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull,
18-
Gumbel, Logistic, OrderedLogistic, LogitNormal, Interpolated,
19-
ZeroInflatedBinomial, HalfFlat, AR1, KroneckerNormal, Rice,
20-
Kumaraswamy)
10+
from ..distributions import (
11+
DensityDist, Categorical, Multinomial, VonMises, Dirichlet,
12+
MvStudentT, MvNormal, MatrixNormal, ZeroInflatedPoisson,
13+
ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta,
14+
BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto,
15+
InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
16+
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal, TruncatedNormal,
17+
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
18+
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull,
19+
Gumbel, Logistic, OrderedLogistic, LogitNormal, Interpolated,
20+
ZeroInflatedBinomial, HalfFlat, AR1, KroneckerNormal, Rice,
21+
Kumaraswamy
22+
)
2123

2224
from ..distributions import continuous
2325
from pymc3.theanof import floatX
@@ -535,15 +537,21 @@ def test_normal(self):
535537
)
536538

537539
def test_truncated_normal(self):
538-
# Rplusbig domain is specified for eveything, to avoid silly cases such as
539-
# {'mu': array(-2.1), 'a': array(-100.), 'b': array(0.01), 'value': array(0.), 'sd': array(0.01)}
540-
# TruncatedNormal: pdf = 0.0, logpdf = -inf
541-
# Scipy's answer: pdf = 0.0, logpdf = -22048.413!!!
542-
self.pymc3_matches_scipy(TruncatedNormal, R, {'mu': R, 'sd': Rplusbig, 'lower':-Rplusbig, 'upper':Rplusbig},
543-
lambda value, mu, sd, lower, upper: sp.truncnorm.logpdf(
544-
value, (lower-mu)/sd, (upper-mu)/sd, loc=mu, scale=sd),
545-
decimal=select_by_precision(float64=6, float32=1)
546-
)
540+
def scipy_logp(value, mu, sd, lower, upper):
541+
return sp.truncnorm.logpdf(
542+
value, (lower-mu)/sd, (upper-mu)/sd, loc=mu, scale=sd)
543+
544+
args = {'mu': array(-2.1), 'lower': array(-100.), 'upper': array(0.01),
545+
'sd': array(0.01)}
546+
val = TruncatedNormal.dist(**args).logp(0.)
547+
assert_allclose(val.eval(), scipy_logp(value=0, **args))
548+
549+
self.pymc3_matches_scipy(
550+
TruncatedNormal, R,
551+
{'mu': R, 'sd': Rplusbig, 'lower': -Rplusbig, 'upper': Rplusbig},
552+
scipy_logp,
553+
decimal=select_by_precision(float64=6, float32=1)
554+
)
547555

548556
def test_half_normal(self):
549557
self.pymc3_matches_scipy(HalfNormal, Rplus, {'sd': Rplus},

0 commit comments

Comments
 (0)