Skip to content

Commit ae6f1da

Browse files
authored
Merge pull request #1871 from domenzain/cdf_methods
Log CDF methods for several distributions
2 parents f8c015f + 6c76f1f commit ae6f1da

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

pymc3/distributions/continuous.py

+74-2
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,18 @@ def logp(self, value):
152152
return bound(-tt.log(upper - lower),
153153
value >= lower, value <= upper)
154154

155+
def logcdf(self, value):
156+
return tt.switch(
157+
tt.or_(tt.lt(value, self.lower), tt.gt(value, self.upper)),
158+
-np.inf,
159+
tt.switch(
160+
tt.eq(value, self.upper),
161+
0,
162+
tt.log((value - self.lower)) -
163+
tt.log((self.upper - self.lower))
164+
)
165+
)
166+
155167

156168
class Flat(Continuous):
157169
"""
@@ -240,10 +252,10 @@ def logcdf(self, value):
240252
return tt.switch(
241253
tt.lt(z, -1.0),
242254
tt.log(tt.erfcx(-z / tt.sqrt(2.)) / 2.) -
243-
tt.sqr(tt.abs_(z)) / 2,
255+
tt.sqr(z) / 2,
244256
tt.log1p(-tt.erfc(z / tt.sqrt(2.)) / 2.)
245257
)
246-
258+
247259

248260
class HalfNormal(PositiveContinuous):
249261
R"""
@@ -295,6 +307,15 @@ def logp(self, value):
295307
value >= 0,
296308
tau > 0, sd > 0)
297309

310+
def logcdf(self, value):
311+
sd = self.sd
312+
z = zvalue(value, mu=0, sd=sd)
313+
return tt.switch(
314+
tt.lt(z, -1.0),
315+
tt.log(tt.erfcx(-z / tt.sqrt(2.))) - tt.sqr(z),
316+
tt.log1p(-tt.erfc(z / tt.sqrt(2.)))
317+
)
318+
298319

299320
class Wald(PositiveContinuous):
300321
R"""
@@ -591,6 +612,20 @@ def logp(self, value):
591612

592613
return -tt.log(2 * b) - abs(value - mu) / b
593614

615+
def logcdf(self, value):
616+
a = self.mu
617+
b = self.b
618+
y = (value - a) / b
619+
return tt.switch(
620+
tt.le(value, a),
621+
tt.log(0.5) + y,
622+
tt.switch(
623+
tt.gt(y, 1),
624+
tt.log1p(-0.5 * tt.exp(-y)),
625+
tt.log(1 - 0.5 * tt.exp(-y))
626+
)
627+
)
628+
594629

595630
class Lognormal(PositiveContinuous):
596631
R"""
@@ -655,6 +690,22 @@ def logp(self, value):
655690
- tt.log(value),
656691
tau > 0)
657692

693+
def logcdf(self, value):
694+
mu = self.mu
695+
sd = self.sd
696+
z = zvalue(tt.log(value), mu=mu, sd=sd)
697+
698+
return tt.switch(
699+
tt.le(value, 0),
700+
-np.inf,
701+
tt.switch(
702+
tt.lt(z, -1.0),
703+
tt.log(tt.erfcx(-z / tt.sqrt(2.)) / 2.) -
704+
tt.sqr(z) / 2,
705+
tt.log1p(-tt.erfc(z / tt.sqrt(2.)) / 2.)
706+
)
707+
)
708+
658709

659710
class StudentT(Continuous):
660711
R"""
@@ -833,6 +884,9 @@ def logp(self, value):
833884
- tt.log1p(((value - alpha) / beta)**2),
834885
beta > 0)
835886

887+
def logcdf (self, value):
888+
return tt.log(0.5 + tt.arctan ((value - self.alpha) / self.beta) / np.pi)
889+
836890

837891
class HalfCauchy(PositiveContinuous):
838892
R"""
@@ -1360,3 +1414,21 @@ def logp(self, value):
13601414
tt.switch(tt.eq(value, c), tt.log(2 / (upper - lower)),
13611415
tt.switch(alltrue_elemwise([c < value, value <= upper]),
13621416
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),np.inf)))
1417+
1418+
def logcdf(self, value):
1419+
l = self.lower
1420+
u = self.upper
1421+
c = self.c
1422+
return tt.switch(
1423+
tt.le(value, l),
1424+
-np.inf,
1425+
tt.switch(
1426+
tt.le(value, c),
1427+
tt.log(((value - l) ** 2) / ((u - l) * (c - l))),
1428+
tt.switch(
1429+
tt.lt(value, u),
1430+
tt.log1p(-((u - value) ** 2) / ((u - l) * (u - c))),
1431+
0
1432+
)
1433+
)
1434+
)

pymc3/tests/test_distributions.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,15 @@ def test_uniform(self):
366366
self.pymc3_matches_scipy(
367367
Uniform, Runif, {'lower': -Rplusunif, 'upper': Rplusunif},
368368
lambda value, lower, upper: sp.uniform.logpdf(value, lower, upper - lower))
369+
self.check_logcdf(Uniform, Runif, {'lower': -Rplusunif, 'upper': Rplusunif},
370+
lambda value, lower, upper: sp.uniform.logcdf(value, lower, upper - lower))
369371

370372
def test_triangular(self):
371373
self.pymc3_matches_scipy(
372374
Triangular, Runif, {'lower': -Rplusunif, 'c': Runif, 'upper': Rplusunif},
373375
lambda value, c, lower, upper: sp.triang.logpdf(value, c-lower, lower, upper-lower))
376+
self.check_logcdf(Triangular, Runif, {'lower': -Rplusunif, 'c': Runif, 'upper': Rplusunif},
377+
lambda value, c, lower, upper: sp.triang.logcdf(value, c-lower, lower, upper-lower))
374378

375379
def test_bound_normal(self):
376380
PositiveNormal = Bound(Normal, lower=0.)
@@ -390,12 +394,14 @@ def test_flat(self):
390394
def test_normal(self):
391395
self.pymc3_matches_scipy(Normal, R, {'mu': R, 'sd': Rplus},
392396
lambda value, mu, sd: sp.norm.logpdf(value, mu, sd))
393-
self.check_logcdf(Normal, R, {'mu': R, 'sd': Rplus},
397+
self.check_logcdf(Normal, R, {'mu': R, 'sd': Rplus},
394398
lambda value, mu, sd: sp.norm.logcdf(value, mu, sd))
395399

396400
def test_half_normal(self):
397401
self.pymc3_matches_scipy(HalfNormal, Rplus, {'sd': Rplus},
398402
lambda value, sd: sp.halfnorm.logpdf(value, scale=sd))
403+
self.check_logcdf(HalfNormal, Rplus, {'sd': Rplus},
404+
lambda value, sd: sp.halfnorm.logcdf(value, scale=sd))
399405

400406
def test_chi_squared(self):
401407
self.pymc3_matches_scipy(ChiSquared, Rplus, {'nu': Rplusdunif},
@@ -452,11 +458,15 @@ def test_fun(value, mu, alpha):
452458
def test_laplace(self):
453459
self.pymc3_matches_scipy(Laplace, R, {'mu': R, 'b': Rplus},
454460
lambda value, mu, b: sp.laplace.logpdf(value, mu, b))
461+
self.check_logcdf(Laplace, R, {'mu': R, 'b': Rplus},
462+
lambda value, mu, b: sp.laplace.logcdf(value, mu, b))
455463

456464
def test_lognormal(self):
457465
self.pymc3_matches_scipy(
458466
Lognormal, Rplus, {'mu': R, 'tau': Rplusbig},
459467
lambda value, mu, tau: sp.lognorm.logpdf(value, tau**-.5, 0, np.exp(mu)))
468+
self.check_logcdf(Lognormal, Rplus, {'mu': R, 'tau': Rplusbig},
469+
lambda value, mu, tau: sp.lognorm.logcdf(value, tau**-.5, 0, np.exp(mu)))
460470

461471
def test_t(self):
462472
self.pymc3_matches_scipy(StudentT, R, {'nu': Rplus, 'mu': R, 'lam': Rplus},
@@ -465,6 +475,8 @@ def test_t(self):
465475
def test_cauchy(self):
466476
self.pymc3_matches_scipy(Cauchy, R, {'alpha': R, 'beta': Rplusbig},
467477
lambda value, alpha, beta: sp.cauchy.logpdf(value, alpha, beta))
478+
self.check_logcdf(Cauchy, R, {'alpha': R, 'beta': Rplusbig},
479+
lambda value, alpha, beta: sp.cauchy.logcdf(value, alpha, beta))
468480

469481
def test_half_cauchy(self):
470482
self.pymc3_matches_scipy(HalfCauchy, Rplus, {'beta': Rplusbig},

0 commit comments

Comments
 (0)