|
7 | 7 | from ..vartypes import continuous_types
|
8 | 8 | from ..model import Model, Point, Potential, Deterministic
|
9 | 9 | from ..blocking import DictToVarBijection, DictToArrayBijection, ArrayOrdering
|
10 |
| -from ..distributions import (DensityDist, Categorical, Multinomial, VonMises, Dirichlet, |
11 |
| - MvStudentT, MvNormal, MatrixNormal, ZeroInflatedPoisson, |
12 |
| - ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta, |
13 |
| - BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto, |
14 |
| - InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace, |
15 |
| - NegativeBinomial, Geometric, Exponential, ExGaussian, Normal, TruncatedNormal, |
16 |
| - Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform, |
17 |
| - Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull, |
18 |
| - Gumbel, Logistic, OrderedLogistic, LogitNormal, Interpolated, |
19 |
| - ZeroInflatedBinomial, HalfFlat, AR1, KroneckerNormal, Rice, |
20 |
| - Kumaraswamy) |
| 10 | +from ..distributions import ( |
| 11 | + DensityDist, Categorical, Multinomial, VonMises, Dirichlet, |
| 12 | + MvStudentT, MvNormal, MatrixNormal, ZeroInflatedPoisson, |
| 13 | + ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta, |
| 14 | + BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto, |
| 15 | + InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace, |
| 16 | + NegativeBinomial, Geometric, Exponential, ExGaussian, Normal, TruncatedNormal, |
| 17 | + Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform, |
| 18 | + Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull, |
| 19 | + Gumbel, Logistic, OrderedLogistic, LogitNormal, Interpolated, |
| 20 | + ZeroInflatedBinomial, HalfFlat, AR1, KroneckerNormal, Rice, |
| 21 | + Kumaraswamy |
| 22 | +) |
21 | 23 |
|
22 | 24 | from ..distributions import continuous
|
23 | 25 | from pymc3.theanof import floatX
|
@@ -535,15 +537,21 @@ def test_normal(self):
|
535 | 537 | )
|
536 | 538 |
|
537 | 539 | def test_truncated_normal(self):
|
538 |
| - # Rplusbig domain is specified for eveything, to avoid silly cases such as |
539 |
| - # {'mu': array(-2.1), 'a': array(-100.), 'b': array(0.01), 'value': array(0.), 'sd': array(0.01)} |
540 |
| - # TruncatedNormal: pdf = 0.0, logpdf = -inf |
541 |
| - # Scipy's answer: pdf = 0.0, logpdf = -22048.413!!! |
542 |
| - self.pymc3_matches_scipy(TruncatedNormal, R, {'mu': R, 'sd': Rplusbig, 'lower':-Rplusbig, 'upper':Rplusbig}, |
543 |
| - lambda value, mu, sd, lower, upper: sp.truncnorm.logpdf( |
544 |
| - value, (lower-mu)/sd, (upper-mu)/sd, loc=mu, scale=sd), |
545 |
| - decimal=select_by_precision(float64=6, float32=1) |
546 |
| - ) |
| 540 | + def scipy_logp(value, mu, sd, lower, upper): |
| 541 | + return sp.truncnorm.logpdf( |
| 542 | + value, (lower-mu)/sd, (upper-mu)/sd, loc=mu, scale=sd) |
| 543 | + |
| 544 | + args = {'mu': array(-2.1), 'lower': array(-100.), 'upper': array(0.01), |
| 545 | + 'sd': array(0.01)} |
| 546 | + val = TruncatedNormal.dist(**args).logp(0.) |
| 547 | + assert_allclose(val.eval(), scipy_logp(value=0, **args)) |
| 548 | + |
| 549 | + self.pymc3_matches_scipy( |
| 550 | + TruncatedNormal, R, |
| 551 | + {'mu': R, 'sd': Rplusbig, 'lower': -Rplusbig, 'upper': Rplusbig}, |
| 552 | + scipy_logp, |
| 553 | + decimal=select_by_precision(float64=6, float32=1) |
| 554 | + ) |
547 | 555 |
|
548 | 556 | def test_half_normal(self):
|
549 | 557 | self.pymc3_matches_scipy(HalfNormal, Rplus, {'sd': Rplus},
|
|
0 commit comments